1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
5import sys
6from functools import reduce
7
8from torch.nn.modules.module import _addindent
9
10
def summary(model):
    """Count the parameters of a model and render a per-layer summary

    The text mirrors the layout of :py:meth:`torch.nn.Module.__repr__`, but
    every module line is suffixed with ``, N params``, where ``N`` is the
    number of parameter elements held by that module and all of its
    sub-modules.  Every non-``None`` entry of each module's ``_parameters``
    is counted (regardless of ``requires_grad``); buffers are not counted.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize

    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters

    """

    # Spaces per nesting level.  The same width must be used for the child
    # lines of a module and for ``_addindent`` on nested representations,
    # otherwise the closing parentheses of sub-modules end up misaligned.
    indent = 2

    def _summarize(module):
        """Returns ``(text, nparams)`` for ``module``, recursing into children"""

        # We treat the extra repr like the sub-modules, one item per line
        extra_repr = module.extra_repr()
        # an empty string would be split into [''], so skip it entirely
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # a slot registered as None holds no module and no parameters
                child_lines.append("(" + key + "): None")
                continue
            child_str, child_params = _summarize(child)
            child_lines.append("(" + key + "): " + _addindent(child_str, indent))
            total_params += child_params

        for param in module._parameters.values():
            # parameter slots may be registered as None; they hold no data
            if param is not None:
                # numel() is 1 for a 0-dim (scalar) parameter, whereas
                # reducing over its empty shape without an initializer
                # raises TypeError
                total_params += param.numel()

        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                sep = "\n" + " " * indent
                main_str += sep + sep.join(lines) + "\n"

        main_str += ")"
        main_str += ", {:,} params".format(total_params)
        return main_str, total_params

    return _summarize(model)