Coverage for src/deepdraw/utils/summary.py: 100%
28 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
6from functools import reduce
8from torch.nn.modules.module import _addindent
def summary(model):
    """Counts the number of parameters in each model layer.

    Builds a multiline, ``print(model)``-like representation of ``model`` in
    which every module line is suffixed with the number of parameters held by
    that module and all of its sub-modules.  Sub-modules (and therefore their
    parameters) registered more than once are counted once per occurrence.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize

    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters

    Both values are returned together as a ``(repr, nparam)`` tuple.
    """

    def _repr(module):
        """Returns ``(text, nparam)`` for ``module`` and its sub-modules."""
        # We treat the extra repr like the sub-modules, one item per line.
        extra_repr = module.extra_repr()
        # an empty string would be split into [''], so guard against it
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # a slot registered with ``None`` (e.g. via ``add_module``)
                # has neither ``extra_repr()`` nor parameters; print it the
                # same way ``torch.nn.Module.__repr__`` does
                child_lines.append(f"({key}): None")
                continue
            child_str, child_params = _repr(child)
            child_str = _addindent(child_str, 2)
            child_lines.append(f"({key}): {child_str}")
            total_params += child_params

        # Parameters owned directly by this module.  Optional ones (e.g. a
        # disabled ``bias``) are registered as ``None`` and skipped.
        # ``numel()`` also handles 0-d (scalar) parameters, whose empty
        # ``shape`` made the previous product-via-reduce raise TypeError.
        for param in module._parameters.values():
            if param is not None:
                total_params += param.numel()

        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                # simple one-liner info, which most builtin Modules will use
                main_str += extra_lines[0]
            else:
                # indent by the same 2 spaces ``_addindent`` applies to nested
                # lines, so each closing parenthesis lines up with its opener
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        main_str += f", {total_params:,} params"
        return main_str, total_params

    return _repr(model)