Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/models/normalizer.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

18 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4"""A network model that prefixes a z-normalization step to any other module""" 

5 

6 

7import torch 

8import torch.nn 

9 

10 

class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies a per-channel z-normalization

    The module computes ``(inputs - mean) / std``.  ``mean`` and ``std`` are
    stored as buffers (not parameters), so this module does not learn, but the
    statistics are saved in the ``state_dict`` and follow the module across
    ``.to()`` calls.  Until :py:meth:`set_mean_std` is called, the mean is
    all zeros and the standard deviation all ones, i.e. the module is an
    identity.

    Parameters
    ----------

    nb_channels : :py:class:`int`, Optional
        Number of images channels fed to the model
    """

    def __init__(self, nb_channels=3):
        super().__init__()
        # Statistics are stored with shape (1, nb_channels, 1, 1) so they
        # broadcast against the input in ``forward``.
        # NOTE(review): this assumes inputs are laid out as
        # (batch, channels, height, width) - confirm against callers.
        mean = torch.zeros(nb_channels)[None, :, None, None]
        std = torch.ones(nb_channels)[None, :, None, None]
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)
        self.name = "torchvision-normalizer"

    def _as_stat_buffer(self, values):
        """Converts per-channel ``values`` to a ``(1, C, 1, 1)`` tensor

        The result uses the dtype and device of the current ``mean`` buffer.
        Without this, float64 statistics (e.g. coming from NumPy) would
        silently promote the output of ``forward`` to float64, and statistics
        set after the module was moved to another device would stay on the
        CPU and make ``forward`` fail.
        """
        reference = self.mean
        tensor = torch.as_tensor(
            values, dtype=reference.dtype, device=reference.device
        )
        return tensor[None, :, None, None]

    def set_mean_std(self, mean, std):
        """Replaces the normalization statistics

        Parameters
        ----------

        mean : sequence or :py:class:`torch.Tensor`
            One-dimensional collection with one mean value per channel

        std : sequence or :py:class:`torch.Tensor`
            One-dimensional collection with one standard deviation value per
            channel
        """
        # Convert both before touching any state, so a failure on ``std``
        # does not leave the module with only ``mean`` updated.
        new_mean = self._as_stat_buffer(mean)
        new_std = self._as_stat_buffer(std)
        # Registering an existing buffer name replaces the stored tensor.
        self.register_buffer('mean', new_mean)
        self.register_buffer('std', new_std)

    def forward(self, inputs):
        """Returns ``inputs`` with the mean subtracted, then divided by std

        Parameters
        ----------

        inputs : :py:class:`torch.Tensor`
            Tensor to normalize; must broadcast against the
            ``(1, C, 1, 1)`` statistics buffers

        Returns
        -------

        normalized : :py:class:`torch.Tensor`
            A new tensor (``inputs`` is not modified in place)
        """
        return inputs.sub(self.mean).div(self.std)