一、BN层作用

批量归一化(Batch Normalization,BN)在深度学习中常放在卷积层之后,BN层有以下优点:

  • 减少了人为选择参数。在某些情况下可以取消 dropout 和 L2 正则项参数,或者采取更小的 L2 正则项约束参数;
  • 减少了对学习率的要求。现在我们可以使用初始很大的学习率或者选择了较小的学习率,算法也能够快速训练收敛;
  • 可以不再使用局部响应归一化。BN 本身就是归一化网络(局部响应归一化在 AlexNet 网络中存在);
  • 破坏原来的数据分布,一定程度上缓解过拟合(防止每批训练中某一个样本经常被挑选到,文献说这个可以提高 1% 的精度);
  • 减少梯度消失,加快收敛速度,提高训练精度

二、 BN层算法流程

下面给出的是 BN 算法在训练时的过程

输入:上一层输出结果 $ X = {x_1, x_2, …, x_m} $,学习参数 $ \gamma, \beta $。

算法流程

  1. 计算上一层输出数据的均值
\[\mu_{\beta} = \frac{1}{m} \sum_{i=1}^m(x_i) \]

其中,$ m $ 是此次训练样本 batch 的大小。

  1. 计算上一层输出数据的标准差
\[\sigma_{\beta}^2 = \frac{1}{m} \sum_{i=1}^m (x_i – \mu_{\beta})^2 \]

  1. 归一化处理,得到
\[\hat x_i = \frac{x_i – \mu_{\beta}}{\sqrt{\sigma_{\beta}^2} + \epsilon} \]

其中 $ \epsilon $ 是为了避免分母为 0 而加进去的接近于 0 的很小值

  1. 重构,对经过上面归一化处理得到的数据进行重构,得到
\[y_i = \gamma \hat x_i + \beta \]

其中,$ \gamma, \beta $ 为可学习参数。

注:上述是 BN 训练时的过程,但是当在测试阶段时,往往只是输入一个样本,没有所谓的均值 $ \mu_{\beta} $ 和标准差 $ \sigma_{\beta}^2 $。此时,均值 $ \mu_{\beta} $ 是计算所有 batch $ \mu_{\beta} $ 值的平均值得到,标准差 $ \sigma_{\beta}^2 $ 采用每个batch $ \sigma_{\beta}^2 $ 的无偏估计得到

三、推理阶段合并BN和conv的原理

如果BN层在卷积层Conv之后,那卷积和BN层可以合并成如下式子。
卷积层

\[Z = W X + B \]

BN层

\[Y = \frac{Z – \mu_{\beta}}{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta \]

合并上面两个式子可得:

\[Y = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2} + \epsilon} X + (\frac{B – \mu_{\beta} }{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta) \]

\[W^{‘} = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2} + \epsilon} \]

\[B^{‘} = \frac{B – \mu_{\beta} }{\sqrt{\sigma_{\beta}^2} + \epsilon} \gamma + \beta \]

可得

\[Y = W^{‘} X + B^{‘} \]

因此只需要更新卷积层的权值和偏置就可以达到合并卷积和BN层的效果。

三、code

import torch
import torch.nn as nn
import torchvision as tv


class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        # print("Dummy, Dummy.")
        return x


def fuse(conv, bn):
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    beta = bn.weight
    gamma = bn.bias

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean)/var_sqrt * beta + gamma
    fused_conv = nn.Conv2d(conv.in_channels,
                         conv.out_channels,
                         conv.kernel_size,
                         conv.stride,
                         conv.padding,
                         bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_conv_and_bn(conv, bn):
    # init
    fused_conv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True
    )
    # # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
    fused_conv.weight = nn.Parameter(torch.mm(w_bn, w_conv).view(fused_conv.weight.size()))
    # # prepare spatial bias
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros(conv.weight.size(0))
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused_conv.bias = nn.Parameter(torch.matmul(w_bn, b_conv) + b_bn)

    return fused_conv


def fuse_module(m):
    children = list(m.named_children())
    print("***********")
    print(children)
    print("***********")
    c = None
    cn = None

    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            # bc = fuse(c, child)
            bc = fuse_conv_and_bn(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()
            print("==> name: ", name)
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            fuse_module(child)


def test_net(m):
    p = torch.randn([1, 3, 224, 224])
    import time
    s = time.time()
    o_output = m(p)
    print("Original time: ", time.time() - s)

    fuse_module(m)

    s = time.time()
    f_output = m(p)
    print("Fused time: ", time.time() - s)

    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    assert(o_output.argmax() == f_output.argmax())
    # print(o_output[0][0].item(), f_output[0][0].item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


def test_layer():
    p = torch.randn([1, 3, 112, 112])
    conv1 = m.conv1
    bn1 = m.bn1
    o_output = bn1(conv1(p))
    fusion = fuse(conv1, bn1)
    f_output = fusion(p)
    print(o_output[0][0][0][0].item())
    print(f_output[0][0][0][0].item())
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


if __name__ == "__main__":

    m = tv.models.resnet18(True)
    m.eval()
    print("Layer level test: ")
    test_layer()

    print("============================")
    print("Module level test: ")
    test_net(m)

参考链接

https://blog.csdn.net/wfei101/article/details/78635557

https://zhuanlan.zhihu.com/p/49329030

https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html?highlight=batchnorm

https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html?highlight=batchnorm

原文地址:http://www.cnblogs.com/xiaxuexiaoab/p/16422640.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请务用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性