Coverage for src/deepdraw/models/normalizer.py: 92%
13 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""A network model that prefixes a z-normalization step to any other module."""
8import torch
9import torch.nn
class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies a per-channel z-normalization.

    This module does not learn: ``mean`` and ``std`` are registered as buffers
    (not parameters), so they follow the module across devices and are stored
    in its ``state_dict``, but are never updated by an optimizer.

    By default, the values applied in this "prefix" operator are the standard
    torchvision ones, documented at
    https://pytorch.org/vision/stable/models.html, and are as follows:

    * ``mean``: ``[0.485, 0.456, 0.406]``,
    * ``std``: ``[0.229, 0.224, 0.225]``

    Parameters
    ----------
    mean
        Per-channel values subtracted from the input (one value per channel).
        Defaults to the torchvision values listed above.
    std
        Per-channel values the mean-subtracted input is divided by (one value
        per channel, none of them zero).  Defaults to the torchvision values
        listed above.

    Raises
    ------
    ValueError
        If ``mean`` or ``std`` is not one-dimensional, if they do not have the
        same number of entries, or if ``std`` contains a zero.
    """

    def __init__(
        self,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ) -> None:
        super().__init__()

        # Use the default floating-point dtype explicitly so that integer
        # inputs (e.g. ``mean=[0, 0, 0]``) still yield floating-point buffers.
        dtype = torch.get_default_dtype()
        mean_values = torch.as_tensor(mean, dtype=dtype)
        std_values = torch.as_tensor(std, dtype=dtype)

        if mean_values.ndim != 1 or std_values.ndim != 1:
            raise ValueError(
                "mean and std must be one-dimensional sequences "
                "(one value per channel)"
            )
        if mean_values.shape != std_values.shape:
            raise ValueError(
                f"mean and std must have the same length, got "
                f"{mean_values.numel()} and {std_values.numel()}"
            )
        if bool((std_values == 0).any()):
            raise ValueError("std must not contain zeros")

        # Reshape to (1, C, 1, 1) so the statistics broadcast over every
        # dimension of the input except the channel one (dimension 1).
        self.register_buffer("mean", mean_values[None, :, None, None])
        self.register_buffer("std", std_values[None, :, None, None])
        self.name = "torchvision-normalizer"

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``(inputs - mean) / std``, broadcast against the buffers.

        Parameters
        ----------
        inputs
            Tensor to normalize.  It must be broadcastable against the
            ``(1, C, 1, 1)`` buffers.

        Returns
        -------
        torch.Tensor
            The normalized tensor (a new tensor; ``inputs`` is not modified).
        """
        return inputs.sub(self.mean).div(self.std)