叶子🍃
叶子🍃
发布于 2026-05-19 / 11 阅读
0

PyTorch 各种 Normalization

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

def test_batchnorm():
    """Check ``F.batch_norm`` in training mode against a manual computation.

    The input ``x`` has shape (b, c, t). Batch norm normalizes each channel
    ``c`` over the batch and time axes (0 and 2) using the biased variance.
    It then applies the per-channel affine ``weight`` / ``bias``.

    In training mode ``F.batch_norm`` also updates ``running_mean`` and
    ``running_var`` in place. Those updates are verified as well.

    Raises:
        AssertionError: if the output or the running statistics disagree with
            the manual reference.
    """
    # A local generator keeps the test reproducible without touching global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=gen)
    C = x.shape[1]
    eps = 1e-5
    momentum = 0.1
    # Non-identity affine parameters, so the scale/shift step is actually exercised.
    # With ones/zeros, a missing or mis-broadcast affine step would go unnoticed.
    weight = torch.randn(C, generator=gen)
    bias = torch.randn(C, generator=gen)
    running_mean = torch.zeros(C)
    running_var = torch.ones(C)
    # F.batch_norm mutates the running buffers in place, so keep the initial values.
    init_mean = running_mean.clone()
    init_var = running_var.clone()

    batch_normed = F.batch_norm(
        x, running_mean, running_var, weight, bias, training=True, momentum=momentum, eps=eps
    )

    # Manual batch normalization
    # x shape: (b, c, t) - normalize over b and t for each c
    mean = x.mean(dim=(0, 2), keepdim=True)  # (1, c, 1)
    var = x.var(dim=(0, 2), keepdim=True, unbiased=False)  # (1, c, 1)
    batch_normed_manual = (
        weight.view(1, C, 1) * (x - mean) / torch.sqrt(var + eps) + bias.view(1, C, 1)
    )

    assert torch.allclose(batch_normed, batch_normed_manual, atol=1e-5), "Batch normalization test failed!"

    # Running statistics form an exponential moving average weighted by `momentum`.
    # The running variance uses the *unbiased* batch variance, not the biased one.
    expected_mean = (1 - momentum) * init_mean + momentum * mean.flatten()
    expected_var = (1 - momentum) * init_var + momentum * x.var(dim=(0, 2), unbiased=True)
    assert torch.allclose(running_mean, expected_mean, atol=1e-5), "Batch norm running_mean update failed!"
    assert torch.allclose(running_var, expected_var, atol=1e-5), "Batch norm running_var update failed!"

def test_layernorm():
    """Check ``F.layer_norm`` over the last two axes against a manual computation.

    The input ``x`` has shape (b, c, t). With ``normalized_shape == (c, t)``,
    layer norm computes one mean and one biased variance per sample ``b``. It
    then applies an element-wise affine ``weight`` / ``bias`` of shape (c, t).

    Raises:
        AssertionError: if ``F.layer_norm`` disagrees with the manual reference.
    """
    # A local generator keeps the test reproducible without touching global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=gen)
    normalized_shape = x.shape[1:]  # (c, t), derived from x instead of hard-coded
    eps = 1e-5
    # Non-identity affine parameters, so the scale/shift step is actually exercised.
    weight = torch.randn(normalized_shape, generator=gen)
    bias = torch.randn(normalized_shape, generator=gen)

    # Pass the same affine parameters the manual computation applies below,
    # so both sides perform the identical operation.
    layer_normed = F.layer_norm(x, normalized_shape, weight=weight, bias=bias, eps=eps)

    # Manual layer normalization
    # x shape: (b, c, t) - normalize over c and t for each b
    mean = x.mean(dim=(1, 2), keepdim=True)  # (b, 1, 1)
    var = x.var(dim=(1, 2), keepdim=True, unbiased=False)  # (b, 1, 1)
    layer_normed_manual = (
        weight.unsqueeze(0) * (x - mean) / torch.sqrt(var + eps) + bias.unsqueeze(0)
    )

    assert torch.allclose(layer_normed, layer_normed_manual, atol=1e-5), "Layer normalization test failed!"

def test_instancenorm():
    """Check ``F.instance_norm`` against a manual computation.

    The input ``x`` has shape (b, c, t). Instance norm computes a mean and a
    biased variance over ``t`` separately for every (b, c) pair. It then
    applies the per-channel affine ``weight`` / ``bias``.

    Raises:
        AssertionError: if ``F.instance_norm`` disagrees with the manual reference.
    """
    # A local generator keeps the test reproducible without touching global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=gen)
    C = x.shape[1]
    eps = 1e-5
    # Non-identity affine parameters, so the per-channel scale/shift broadcast is exercised.
    weight = torch.randn(C, generator=gen)
    bias = torch.randn(C, generator=gen)

    instance_normed = F.instance_norm(x, weight=weight, bias=bias, eps=eps)

    # Manual instance normalization
    # x shape: (b, c, t) - normalize over t for each (b, c)
    mean = x.mean(dim=2, keepdim=True)  # (b, c, 1)
    var = x.var(dim=2, keepdim=True, unbiased=False)  # (b, c, 1)
    instance_normed_manual = (
        weight.view(1, C, 1) * (x - mean) / torch.sqrt(var + eps) + bias.view(1, C, 1)
    )

    assert torch.allclose(instance_normed, instance_normed_manual, atol=1e-5), "Instance normalization test failed!"

if __name__ == '__main__':
    # Run every normalization check in order.
    # Any failure raises AssertionError before the success message prints.
    for run_check in (test_batchnorm, test_layernorm, test_instancenorm):
        run_check()
    print("All normalization tests passed!")