博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
深度学习----Xavier初始化方法
阅读量:5092 次
发布时间:2019-06-13

本文共 1788 字,大约阅读时间需要 5 分钟。

“Xavier”初始化方法是一种很有效的神经网络初始化方法,方法来源于2010年的一篇论文,可惜直到近两年,这个方法才逐渐得到更多人的应用和认可。

为了使得网络中信息更好的流动,每一层输出的方差应该尽量相等。

基于这个目标,现在我们就去推导一下:每一层的权重应该满足哪种条件。

文章先假设的是线性激活函数,而且满足0点处导数为1,即 

这里写图片描述

现在我们先来分析一层卷积: 

这里写图片描述 
其中ni表示输入个数。

根据概率统计知识我们有下面的方差公式: 

这里写图片描述

特别的,当我们假设输入和权重都是0均值时(目前有了BN之后,这一点也较容易满足),上式可以简化为: 

这里写图片描述

进一步假设输入x和权重w独立同分布,则有: 

这里写图片描述

于是,为了保证输入与输出方差一致,则应该有: 

这里写图片描述

对于一个多层的网络,某一层的方差可以用累积的形式表达: 

这里写图片描述

特别的,反向传播计算梯度时同样具有类似的形式: 

这里写图片描述

综上,为了保证前向传播和反向传播时每一层的方差一致,应满足:

这里写图片描述

但是,实际当中输入与输出的个数往往不相等,于是为了均衡考量,最终我们的权重方差应满足

——————————————————————————————————————— 

这里写图片描述 
———————————————————————————————————————

学过概率统计的都知道 [a,b] 间的均匀分布的方差为: 

这里写图片描述

因此,Xavier初始化的实现就是下面的均匀分布:

—————————————————————————————————————————— 

这里写图片描述 
———————————————————————————————————————————

下面,我们来看一下caffe中具体是怎样实现的,代码位于include/caffe/filler.hpp文件中。

template 
class XavierFiller : public Filler
{ public: explicit XavierFiller(const FillerParameter& param) : Filler
(param) {} virtual void Fill(Blob
* blob) { CHECK(blob->count()); int fan_in = blob->count() / blob->num(); int fan_out = blob->count() / blob->channels(); Dtype n = fan_in; // default to fan_in if (this->filler_param_.variance_norm() == FillerParameter_VarianceNorm_AVERAGE) { n = (fan_in + fan_out) / Dtype(2); } else if (this->filler_param_.variance_norm() == FillerParameter_VarianceNorm_FAN_OUT) { n = fan_out; } Dtype scale = sqrt(Dtype(3) / n); caffe_rng_uniform
(blob->count(), -scale, scale, blob->mutable_cpu_data()); CHECK_EQ(this->filler_param_.sparse(), -1) << "Sparsity not supported by this Filler."; } };

 

由上面可以看出,caffe的Xavier实现有三种选择

(1) 默认情况,方差只考虑输入个数: 

这里写图片描述

(2) FillerParameter_VarianceNorm_FAN_OUT,方差只考虑输出个数: 

这里写图片描述

(3) FillerParameter_VarianceNorm_AVERAGE,方差同时考虑输入和输出个数: 

这里写图片描述

之所以默认只考虑输入,我个人觉得是因为前向信息的传播更重要一些

转载于:https://www.cnblogs.com/guohaoyu110/p/7487290.html

你可能感兴趣的文章
第九章 前后查找
查看>>
Python学习资料
查看>>
jQuery 自定义函数
查看>>
jquery datagrid 后台获取datatable处理成正确的json字符串
查看>>
ActiveMQ与spring整合
查看>>
web服务器
查看>>
第一阶段冲刺06
查看>>
EOS生产区块:解析插件producer_plugin
查看>>
JS取得绝对路径
查看>>
排球积分程序(三)——模型类的设计
查看>>
HDU 4635 Strongly connected
查看>>
格式化输出数字和时间
查看>>
页面中公用的全选按钮,单选按钮组件的编写
查看>>
java笔记--用ThreadLocal管理线程,Callable<V>接口实现有返回值的线程
查看>>
(旧笔记搬家)struts.xml中单独页面跳转的配置
查看>>
不定期周末福利:数据结构与算法学习书单
查看>>
strlen函数
查看>>
关于TFS2010使用常见问题
查看>>
软件工程团队作业3
查看>>
火狐、谷歌、IE关于document.body.scrollTop和document.documentElement.scrollTop 以及值为0的问题...
查看>>