Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/utils/summary.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 

5from functools import reduce 

6 

7from torch.nn.modules.module import _addindent 

8 

9 

def summary(model):
    """Counts the number of parameters in each model layer

    The module tree is walked recursively through ``_modules``.  Every
    module is rendered as ``Name(<extra_repr and children>), N params``,
    where ``N`` is the number of elements of the parameters held directly
    by that module (``_parameters``) plus the totals of all its children.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize

    Returns
    -------

    repr : str
        a multiline string representation of the network

    nparam : int
        number of parameters

    """

    # Single indentation width shared by the continuation lines produced by
    # ``_addindent`` and by the first line of every child entry, so nested
    # descriptions stay aligned.
    indent = 2

    def _describe(module):
        """Returns ``(description, parameter count)`` for ``module``."""

        # We treat the extra repr like the sub-module, one item per line.
        # An empty string would be split into [''], hence the guard.
        extra_repr = module.extra_repr()
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            child_str, child_params = _describe(child)
            child_str = _addindent(child_str, indent)
            child_lines.append("(" + key + "): " + child_str)
            total_params += child_params

        for p in module._parameters.values():
            # Entries without a ``dtype`` (e.g. parameters registered as
            # ``None``) carry no elements and are skipped.
            if hasattr(p, "dtype"):
                # The initial value 1 makes the product well defined for
                # 0-dim (scalar) parameters, whose shape is empty; without
                # it ``reduce`` raises TypeError on an empty sequence.
                total_params += reduce(lambda x, y: x * y, p.shape, 1)

        lines = extra_lines + child_lines
        main_str = module._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                pad = "\n" + " " * indent
                main_str += pad + pad.join(lines) + "\n"

        main_str += ")"
        main_str += ", {:,} params".format(total_params)
        return main_str, total_params

    return _describe(model)