Source code for bob.ip.common.utils.summary

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
from functools import reduce

from torch.nn.modules.module import _addindent


def summary(model):
    """Counts the number of parameters in each model layer

    Builds a string representation of ``model`` in the style of
    :py:class:`torch.nn.Module`'s ``repr``, where every module (at any depth)
    is annotated with the number of parameters it holds, including those of
    all its sub-modules.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize


    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters

    """

    def _repr(module):
        # Returns ``(text, nparams)`` for ``module`` and its whole sub-tree.

        # We treat the extra repr like the sub-module, one item per line
        extra_repr = module.extra_repr()
        # an empty string would be split into the list [''], so guard it
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            child_str, child_params = _repr(child)
            child_str = _addindent(child_str, 2)
            child_lines.append("(" + key + "): " + child_str)
            total_params += child_params
        lines = extra_lines + child_lines

        for p in module._parameters.values():
            # skip entries that are not tensors (e.g. parameters registered
            # as ``None``)
            if hasattr(p, "dtype"):
                # initializer 1: a 0-dim (scalar) parameter has an empty
                # shape and must count as 1 element instead of raising
                total_params += reduce(lambda x, y: x * y, p.shape, 1)

        main_str = module._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                # 2-space indent, consistent with _addindent(..., 2) above
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        main_str += ", {:,} params".format(total_params)
        return main_str, total_params

    return _repr(model)