import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def test_batchnorm():
    """Check ``F.batch_norm`` in training mode against a manual computation.

    Verifies two things:
      * the normalized output, using batch statistics over dims (b, t) per
        channel c and a non-trivial per-channel affine scale/shift;
      * the in-place update of ``running_mean`` / ``running_var`` that
        ``training=True`` performs.
    """
    # Local, seeded generator: reproducible inputs without touching the
    # global RNG state.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=g)  # (b, c, t)
    C = x.shape[1]
    eps = 1e-5
    momentum = 0.1
    # Random (not ones/zeros) affine parameters, so an implementation that
    # ignored weight/bias could not pass.
    weight = torch.randn(C, generator=g)
    bias = torch.randn(C, generator=g)
    running_mean = torch.zeros(C)
    running_var = torch.ones(C)
    # F.batch_norm mutates the running buffers in place when training=True;
    # keep the initial values to build the expected update.
    init_mean = running_mean.clone()
    init_var = running_var.clone()
    batch_normed = F.batch_norm(
        x, running_mean, running_var, weight, bias,
        training=True, momentum=momentum, eps=eps,
    )
    # Manual batch normalization: normalize over b and t for each channel c,
    # using the biased variance.
    mean = x.mean(dim=(0, 2), keepdim=True)  # (1, c, 1)
    var = x.var(dim=(0, 2), keepdim=True, unbiased=False)  # (1, c, 1)
    batch_normed_manual = (
        weight[None, :, None] * (x - mean) / torch.sqrt(var + eps)
        + bias[None, :, None]
    )
    assert torch.allclose(batch_normed, batch_normed_manual, atol=1e-5), "Batch normalization test failed!"
    # Running statistics are an exponential moving average:
    #   new = (1 - momentum) * old + momentum * batch_stat
    # where the stored variance uses the *unbiased* estimator, unlike the
    # biased one used to normalize the output.
    expected_mean = (1 - momentum) * init_mean + momentum * mean.flatten()
    expected_var = (1 - momentum) * init_var + momentum * x.var(dim=(0, 2), unbiased=True)
    assert torch.allclose(running_mean, expected_mean, atol=1e-5), "Batch norm running_mean update mismatch!"
    assert torch.allclose(running_var, expected_var, atol=1e-5), "Batch norm running_var update mismatch!"
def test_layernorm():
    """Check ``F.layer_norm`` (with elementwise affine) against a manual computation.

    Normalizes each sample over its trailing (c, t) dims, then applies the
    same elementwise weight/bias on both sides of the comparison.
    """
    # Local, seeded generator: reproducible inputs without touching the
    # global RNG state.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=g)  # (b, c, t)
    normalized_shape = tuple(x.shape[1:])  # (c, t)
    eps = 1e-5
    # Random (not ones/zeros) affine parameters so the scale/shift is tested.
    weight = torch.randn(*normalized_shape, generator=g)
    bias = torch.randn(*normalized_shape, generator=g)
    # Pass the same weight/bias that the manual formula uses below; passing
    # None here would leave the affine path uncompared.
    layer_normed = F.layer_norm(x, normalized_shape, weight=weight, bias=bias, eps=eps)
    # Manual layer normalization: normalize over c and t for each b,
    # using the biased variance.
    mean = x.mean(dim=(1, 2), keepdim=True)  # (b, 1, 1)
    var = x.var(dim=(1, 2), keepdim=True, unbiased=False)  # (b, 1, 1)
    layer_normed_manual = weight[None] * (x - mean) / torch.sqrt(var + eps) + bias[None]
    assert torch.allclose(layer_normed, layer_normed_manual, atol=1e-5), "Layer normalization test failed!"
def test_instancenorm():
    """Check ``F.instance_norm`` (with per-channel affine) against a manual computation.

    Normalizes over t independently for every (b, c) pair, then applies a
    per-channel weight/bias.
    """
    # Local, seeded generator: reproducible inputs without touching the
    # global RNG state.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(10, 20, 30, generator=g)  # (b, c, t)
    C = x.shape[1]
    eps = 1e-5
    # Random (not ones/zeros) affine parameters so the scale/shift is tested.
    weight = torch.randn(C, generator=g)
    bias = torch.randn(C, generator=g)
    instance_normed = F.instance_norm(x, weight=weight, bias=bias, eps=eps)
    # Manual instance normalization: normalize over t for each (b, c),
    # using the biased variance.
    mean = x.mean(dim=2, keepdim=True)  # (b, c, 1)
    var = x.var(dim=2, keepdim=True, unbiased=False)  # (b, c, 1)
    instance_normed_manual = (
        weight[None, :, None] * (x - mean) / torch.sqrt(var + eps)
        + bias[None, :, None]
    )
    assert torch.allclose(instance_normed, instance_normed_manual, atol=1e-5), "Instance normalization test failed!"
if __name__ == '__main__':
    # Run each normalization check in order; any failing assert aborts the run
    # before the success message is printed.
    for check in (test_batchnorm, test_layernorm, test_instancenorm):
        check()
    print("All normalization tests passed!")