Coverage for src/deepdraw/utils/summary.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 

6from functools import reduce 

7 

8from torch.nn.modules.module import _addindent 

9 

10 

def summary(model):
    """Counts the number of parameters in each model layer.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        model to summarize

    Returns
    -------

    repr : str
        a multiline string representation of the network, laid out like
        :py:meth:`torch.nn.Module.__repr__`, where every (sub-)module is
        suffixed with the number of parameters it owns, including those of
        its children

    nparam : int
        number of parameters
    """

    def _repr(module):
        # Returns ``(text, count)`` for ``module``: its repr string and the
        # total number of parameter elements owned by it and its children.

        # We treat the extra repr like the sub-modules, one item per line.
        # An empty string would be split into [''], hence the guard.
        extra_repr = module.extra_repr()
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            if child is None:
                # ``register_module(name, None)`` is legal; render it the
                # way torch's own ``__repr__`` does, contributing no params.
                child_lines.append("(" + key + "): None")
                continue
            child_str, child_params = _repr(child)
            child_str = _addindent(child_str, 2)
            child_lines.append("(" + key + "): " + child_str)
            total_params += child_params

        # Parameters registered directly on this module.  Entries may be
        # ``None`` (e.g. ``bias=False``).  ``numel()`` is 1 for 0-dim
        # tensors, whereas multiplying over an empty ``shape`` without an
        # initial value would raise.
        total_params += sum(
            p.numel() for p in module._parameters.values() if p is not None
        )

        main_str = module._get_name() + "("
        if extra_lines or child_lines:
            if len(extra_lines) == 1 and not child_lines:
                # simple one-liner info, which most builtin Modules will use
                main_str += extra_lines[0]
            else:
                lines = extra_lines + child_lines
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        main_str += f", {total_params:,} params"
        return main_str, total_params

    return _repr(model)