Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/utils/summary.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

29 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 

5import sys 

6from functools import reduce 

7 

8from torch.nn.modules.module import _addindent 

9 

10 

def summary(model):
    """Counts the number of parameters in each model layer

    Walks ``model`` recursively and builds a description in which every
    (sub-)module is followed by the number of parameters it holds, including
    those of all its children.  A module reachable through several parents is
    visited (and counted) once per occurrence.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize


    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters

    """

    def _describe(module):
        """Returns the ``(text, nparam)`` pair for ``module`` and its children"""

        # We treat the extra repr like the sub-module, one item per line.
        # An empty string would be split into [''], hence the guard.
        extra_repr = module.extra_repr()
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            child_str, child_params = _describe(child)
            child_lines.append("(" + key + "): " + _addindent(child_str, 2))
            total_params += child_params

        # Only parameters registered directly on this module are added here;
        # those of the children were accumulated by the recursion above.
        for p in module._parameters.values():
            # entries may be None (e.g. a bias that was disabled)
            if p is not None:
                # numel() also handles 0-dim (scalar) parameters, for which a
                # product over an empty shape via reduce() would raise
                total_params += p.numel()

        lines = extra_lines + child_lines

        main_str = module._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n " + "\n ".join(lines) + "\n"

        main_str += ")"
        main_str += ", {:,} params".format(total_params)
        return main_str, total_params

    return _describe(model)