Coverage for src/deepdraw/models/normalizer.py: 92%

13 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5"""A network model that prefixes a z-normalization step to any other module.""" 

6 

7 

8import torch 

9import torch.nn 

10 

11 

class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies the standard torchvision normalization.

    This module does not learn: the per-channel ``mean`` and ``std`` are
    registered as buffers (not parameters), so they follow the module across
    devices and are stored in its ``state_dict``, but no optimizer updates
    them.

    By default, the values applied in this "prefix" operator are the ones
    used by the torchvision pre-trained models (see
    https://pytorch.org/vision/stable/models.html), and are as follows:

    * ``mean``: ``[0.485, 0.456, 0.406]``,
    * ``std``: ``[0.229, 0.224, 0.225]``

    Parameters
    ----------
    mean
        Optional 1-D sequence (or tensor) of per-channel means to subtract.
        If ``None`` (default), the torchvision values listed above are used.
    std
        Optional 1-D sequence (or tensor) of per-channel standard deviations
        to divide by.  It must have the same length as ``mean`` and contain
        no zeros.  If ``None`` (default), the torchvision values listed
        above are used.

    Raises
    ------
    ValueError
        If ``mean`` or ``std`` is empty or not 1-D, if their lengths differ,
        or if ``std`` contains a zero.
    """

    _DEFAULT_MEAN = (0.485, 0.456, 0.406)
    _DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(self, mean=None, std=None):
        super().__init__()
        mean = self._to_channel_tensor(
            self._DEFAULT_MEAN if mean is None else mean, "mean"
        )
        std = self._to_channel_tensor(
            self._DEFAULT_STD if std is None else std, "std"
        )
        if mean.shape != std.shape:
            raise ValueError(
                "mean and std must have the same number of channels, got "
                f"{mean.shape[1]} and {std.shape[1]}"
            )
        # A zero standard deviation would silently turn outputs into inf/nan.
        if bool((std == 0).any()):
            raise ValueError("std must not contain zeros")
        # Buffer names are kept as "mean"/"std" so existing state_dicts load.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.name = "torchvision-normalizer"

    @staticmethod
    def _to_channel_tensor(values, what):
        """Convert 1-D ``values`` into a ``(1, C, 1, 1)`` floating-point tensor.

        ``what`` names the argument being converted, for error messages.
        """
        # Uses the default floating dtype, matching what ``torch.as_tensor``
        # produces for the built-in float defaults.
        tensor = torch.as_tensor(values, dtype=torch.get_default_dtype())
        if tensor.ndim != 1 or tensor.numel() == 0:
            raise ValueError(
                f"{what} must be a non-empty 1-D sequence, got shape "
                f"{tuple(tensor.shape)}"
            )
        # (C,) -> (1, C, 1, 1): broadcasts along dimension 1 of the input.
        return tensor[None, :, None, None]

    def forward(self, inputs):
        """Normalize ``inputs`` channel-wise as ``(inputs - mean) / std``.

        The buffers have shape ``(1, C, 1, 1)``, so ``inputs`` is presumably
        an ``(N, C, H, W)`` batch whose channel count matches ``C`` — TODO
        confirm against callers.  Standard PyTorch broadcasting applies.
        """
        return inputs.sub(self.mean).div(self.std)