Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#!/usr/bin/env python
# coding=utf-8

"""A network module that applies a fixed z-normalization step, meant to prefix any other module"""
7import torch
8import torch.nn
class TorchVisionNormalizer(torch.nn.Module):
    """A simple normalizer that applies the standard torchvision normalization.

    This module does not learn: ``mean`` and ``std`` are registered as
    buffers (not parameters), so they follow the module on ``.to()`` calls
    and are stored in its ``state_dict``, but never receive gradients.

    By default the values applied in this "prefix" operator are the ImageNet
    statistics documented at https://pytorch.org/vision/stable/models.html,
    namely:

    * ``mean``: ``[0.485, 0.456, 0.406]``,
    * ``std``: ``[0.229, 0.224, 0.225]``

    Other statistics (e.g. for single-channel images) can be supplied through
    the constructor.

    Parameters
    ----------
    mean : sequence of float or tensor, optional
        Per-channel mean subtracted from the input.
    std : sequence of float or tensor, optional
        Per-channel standard deviation the centred input is divided by.  Must
        have the same number of elements as ``mean`` and contain no zeros.

    Raises
    ------
    ValueError
        If ``mean`` and ``std`` have a different number of elements, or if
        ``std`` contains a zero (which would yield infinities in ``forward``).
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        # Use the default floating dtype so the default construction matches
        # torch.as_tensor() applied to a plain list of Python floats.
        dtype = torch.get_default_dtype()
        mean = torch.as_tensor(mean, dtype=dtype).reshape(-1)
        std = torch.as_tensor(std, dtype=dtype).reshape(-1)
        if mean.numel() != std.numel():
            raise ValueError(
                f"mean and std must have the same number of elements, "
                f"got {mean.numel()} and {std.numel()}"
            )
        if bool((std == 0).any()):
            raise ValueError("std must not contain zeros")
        # Shape (1, C, 1, 1) so the statistics broadcast over an (N, C, H, W)
        # input, one value per channel.
        self.register_buffer("mean", mean.reshape(1, -1, 1, 1))
        self.register_buffer("std", std.reshape(1, -1, 1, 1))
        self.name = "torchvision-normalizer"

    def forward(self, inputs):
        """Return ``(inputs - mean) / std``, computed per channel.

        Parameters
        ----------
        inputs : torch.Tensor
            Batch to normalize.  Given the (1, C, 1, 1) buffer shape, this is
            expected to be (N, C, H, W) with C matching the statistics.

        Returns
        -------
        torch.Tensor
            The z-normalized batch (a new tensor; ``inputs`` is not modified).
        """
        return inputs.sub(self.mean).div(self.std)