1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
5from functools import reduce
6
7from torch.nn.modules.module import _addindent
8
9
def summary(model):
    """Counts the number of parameters in each model layer

    Builds a representation laid out like ``torch.nn.Module.__repr__``, where
    every module entry is suffixed with ``", <count> params"``. The count
    includes the module's own parameters plus those of all its sub-modules.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize

    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters (returned together with ``repr`` as a tuple)

    Notes
    -----

    Counting is done per module over each module's own registered
    parameters. A parameter registered in more than one module (e.g. tied
    weights) is therefore counted once per registering module.

    """

    def _repr(module):
        # Treat the extra repr like the sub-modules: one item per line.
        # An empty string would split into [''], so only split non-empty ones.
        extra_repr = module.extra_repr()
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # Placeholder slot (e.g. ``add_module(name, None)``): nothing
                # to recurse into or count, render it as torch does.
                child_lines.append("(" + key + "): None")
                continue
            child_str, child_params = _repr(child)
            child_lines.append("(" + key + "): " + _addindent(child_str, 2))
            total_params += child_params

        # Parameters owned directly by this module; unset slots hold ``None``.
        # ``numel()`` also handles 0-d parameters, whose empty ``shape`` made
        # an initializer-less ``reduce`` raise TypeError.
        total_params += sum(
            p.numel() for p in module._parameters.values() if p is not None
        )

        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                # simple one-liner info, which most builtin Modules will use
                main_str += extra_lines[0]
            else:
                # two-space indent, matching ``_addindent(..., 2)`` above
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        main_str += ", {:,} params".format(total_params)
        return main_str, total_params

    return _repr(model)