Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4"""A network model that prefixes a z-normalization step to any other module""" 

5 

6 

7import torch 

8import torch.nn 

9 

10 

class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies the standard torchvision normalization.

    This module does not learn: ``mean`` and ``std`` are registered as
    buffers (not parameters), so they are never updated by an optimizer.

    The default values applied in this "prefix" operator are defined at
    https://pytorch.org/docs/stable/torchvision/models.html, and are as
    follows:

    * ``mean``: ``[0.485, 0.456, 0.406]``,
    * ``std``: ``[0.229, 0.224, 0.225]``

    Parameters
    ----------
    mean : sequence of float, optional
        Per-channel values subtracted from the input. Defaults to the
        torchvision values listed above.
    std : sequence of float, optional
        Per-channel values the centred input is divided by. Defaults to the
        torchvision values listed above.

    Raises
    ------
    ValueError
        If ``mean`` and ``std`` are not 1-D sequences of the same length.
    """

    def __init__(
        self,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        super().__init__()
        # Use the default float dtype so the buffers match what the original
        # hard-coded float lists produced, even if integers are passed in.
        mean = torch.as_tensor(mean, dtype=torch.get_default_dtype())
        std = torch.as_tensor(std, dtype=torch.get_default_dtype())
        # Fail early with a clear message instead of an opaque broadcasting
        # error later in ``forward``.
        if mean.ndim != 1 or mean.shape != std.shape:
            raise ValueError(
                "mean and std must be 1-D with the same length, got shapes "
                f"{tuple(mean.shape)} and {tuple(std.shape)}"
            )
        # Reshape to (1, C, 1, 1) so the statistics broadcast along
        # dimension 1 of the input (the channel axis in an (N, C, H, W) batch).
        self.register_buffer("mean", mean[None, :, None, None])
        self.register_buffer("std", std[None, :, None, None])
        self.name = "torchvision-normalizer"

    def forward(self, inputs):
        """Return ``(inputs - mean) / std`` computed per channel.

        Parameters
        ----------
        inputs : torch.Tensor
            Presumably a batch in (N, C, H, W) layout, with C equal to the
            number of channels in ``mean``/``std``. The statistics broadcast
            along dimension 1. Per the torchvision documentation, values are
            expected in [0, 1]; confirm this against callers.

        Returns
        -------
        torch.Tensor
            The normalized tensor. It keeps the input's shape when the input
            is 4-D.
        """
        return inputs.sub(self.mean).div(self.std)