1#!/usr/bin/env python
2# coding=utf-8
3
4"""A network model that prefixes a z-normalization step to any other module"""
5
6
7import torch
8import torch.nn
9
10
class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies the standard torchvision normalization

    Inputs are normalized per channel as ``(inputs - mean) / std``.  The mean
    and standard deviation are kept as registered buffers of shape
    ``(1, nb_channels, 1, 1)``.  They are therefore part of the module's
    ``state_dict`` and follow the module on ``.to()``/``.half()``-style
    conversions.  They start as zeros and ones respectively, which makes the
    module an identity transform until :py:meth:`set_mean_std` is called.

    This module does not learn.

    Parameters
    ----------

    nb_channels : :py:class:`int`, Optional
        Number of image channels fed to the model
    """

    def __init__(self, nb_channels=3):
        super().__init__()
        self.register_buffer("mean", torch.zeros(1, nb_channels, 1, 1))
        self.register_buffer("std", torch.ones(1, nb_channels, 1, 1))
        self.name = "torchvision-normalizer"

    def set_mean_std(self, mean, std):
        """Sets the per-channel mean and standard deviation used by forward

        The new values are converted to the dtype and device the current
        buffers already have, so calling this after moving or casting the
        module neither moves the buffers back to the CPU nor changes the
        output dtype.  The values are copied, so later changes to the
        caller's objects do not affect this module.

        Parameters
        ----------

        mean : scalar, sequence, :py:class:`numpy.ndarray` or :py:class:`torch.Tensor`
            Per-channel mean (a scalar applies to all channels)

        std : scalar, sequence, :py:class:`numpy.ndarray` or :py:class:`torch.Tensor`
            Per-channel standard deviation (a scalar applies to all channels)
        """
        # ``mean``/``std`` already live in ``self._buffers``, so re-registering
        # simply replaces them (and keeps them in the state_dict).
        self.register_buffer("mean", self._as_buffer(mean, self.mean))
        self.register_buffer("std", self._as_buffer(std, self.std))

    @staticmethod
    def _as_buffer(values, like):
        """Converts ``values`` to a private ``(1, C, 1, 1)`` tensor matching ``like``"""
        tensor = torch.as_tensor(values, dtype=like.dtype, device=like.device)
        # reshape handles scalars (0-d) as well as 1-d per-channel vectors;
        # clone so the buffer never aliases memory owned by the caller.
        return tensor.reshape(1, -1, 1, 1).clone()

    def forward(self, inputs):
        """Returns ``(inputs - mean) / std``, broadcast over the channel axis

        ``inputs`` must be broadcastable against the ``(1, C, 1, 1)``
        buffers, e.g. a batch shaped ``(N, C, H, W)``.
        """
        return inputs.sub(self.mean).div(self.std)