什么是MDN
对于一输入对一的输出的任务,传统的神经网络可以很好地进行预测/回归/分类。然而,对于一对多的问题,传统的网络无法胜任,此时MDN派上用场。
MDN的工作机制
与传统的NN不一样的是,MDN的预测输出是特定分布的参数值,例如指定三个正态分布进行叠加模拟,需要6个\mu,\sigma参数,那么网络输出为6个参数,输入还是常见的输入,得到正常的6个参数之后,采用f(x)=\sum_{i=0}^{3}\frac{1}{\sqrt{2\pi}\sigma_{i}}e^{\frac{(x-\mu_{i})^2}{2\sigma_{i}^2}}进行叠加。最后网络训练的loss采用最大似然损失来衡量,也即要使得我们训练数据的标签y_train在得到的组合分布下出现的概率最大。那么网络在输出的时候需要输出各个分布的概率值才能在后期计算最大似然。
MDN的应用场景
目前在光学领域有论文采用MDN对器件的逆向设计以及优化。此外有使用MDN进行舞蹈图片和手写图片的生成。
A Deep Convolutional Mixture Density Network for Inverse Design of Layered Photonic Structures