#!/usr/bin/env python
# coding=utf-8

import logging
import os

import click
import matplotlib.pyplot as plt
import numpy
import pandas

from matplotlib.backends.backend_pdf import PdfPages

logger = logging.getLogger(__name__)

from bob.extension.scripts.click_helper import (
    ConfigCommand,
    ResourceOption,
    verbosity_option,
)

def _loss_evolution(df):
    """Plots the loss evolution over time (epochs).

    Parameters
    ----------

    df : pandas.DataFrame
        dataframe containing the training logs


    Returns
    -------

    figure : matplotlib.figure.Figure
        figure to be displayed or saved to file

    """

    figure = plt.figure()
    axes = figure.gca()

    axes.plot(df.epoch.values, df.loss.values, label="Training")
    if "validation_loss" in df.columns:
        axes.plot(
            df.epoch.values, df.validation_loss.values, label="Validation"
        )
        # marks the epoch with the lowest validation loss with a magenta dot
        lowest_index = numpy.argmin(df["validation_loss"])

        axes.plot(
            df.epoch.values[lowest_index],
            df.validation_loss[lowest_index],
            "mo",
            label=f"Lowest validation ({df.validation_loss[lowest_index]:.3f}@{df.epoch[lowest_index]})",
        )

    if "extra_validation_losses" in df.columns:
        # These losses are stored as strings of space-separated values
        # (e.g. "[0.1 0.2]").  We parse each row into a 1D array, stack the
        # rows into a 2D array and transpose it, so we can iterate over each
        # extra validation set and plot its loss individually.  Curves are
        # numbered from 1.
        df["extra_validation_losses"] = df["extra_validation_losses"].apply(
            lambda x: numpy.fromstring(x.strip("[]"), sep=" ")
        )
        losses = numpy.vstack(df.extra_validation_losses.values).T
        for n, k in enumerate(losses):
            axes.plot(df.epoch.values, k, label=f"Extra validation {n+1}")

    axes.set_title("Loss over time")
    axes.set_xlabel("Epoch")
    axes.set_ylabel("Loss")

    axes.legend(loc="best")
    axes.grid(alpha=0.3)
    figure.set_layout_engine("tight")

    return figure
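# A minimal usage sketch, assuming a hypothetical "trainlog.csv" that carries
# at least the "epoch" and "loss" columns read above:
#
#   df = pandas.read_csv("trainlog.csv")
#   figure = _loss_evolution(df)
#   figure.savefig("loss.png")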

def _hardware_utilisation(df, const):
    """Plots the hardware (CPU/GPU) utilisation over time (epochs).

    Parameters
    ----------

    df : pandas.DataFrame
        dataframe containing the training logs

    const : dict
        training and hardware constants


    Returns
    -------

    figure : matplotlib.figure.Figure
        figure to be displayed or saved to file

    """
    figure = plt.figure()
    axes = figure.gca()

    cpu_percent = df.cpu_percent.values / const["cpu_count"]
    cpu_memory = 100 * df.cpu_rss / const["cpu_memory_total"]

    axes.plot(
        df.epoch.values,
        cpu_percent,
        label=f"CPU usage (cores: {const['cpu_count']})",
    )
    axes.plot(
        df.epoch.values,
        cpu_memory,
        label=f"CPU memory (total: {const['cpu_memory_total']:.1f} GB)",
    )
    if "gpu_percent" in df:
        axes.plot(
            df.epoch.values,
            df.gpu_percent.values,
            label=f"GPU usage (type: {const['gpu_name']})",
        )
    if "gpu_memory_percent" in df:
        axes.plot(
            df.epoch.values,
            df.gpu_memory_percent.values,
            label=f"GPU memory (total: {const['gpu_memory_total']:.1f} GB)",
        )
    axes.set_title("Hardware utilisation over time")
    axes.set_xlabel("Epoch")
    axes.set_ylabel("Relative utilisation (%)")
    axes.set_ylim([0, 100])

    axes.legend(loc="best")
    axes.grid(alpha=0.3)
    figure.set_layout_engine("tight")

    return figure
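# A minimal sketch of the "const" mapping consumed above; the key names are
# taken from the lookups in _hardware_utilisation(), while the values and the
# GPU entries are hypothetical (GPU keys are only needed when the log also
# carries "gpu_percent"/"gpu_memory_percent" columns):
#
#   const = {
#       "cpu_count": 8,                  # number of CPU cores
#       "cpu_memory_total": 62.5,        # total RAM, in gigabytes
#       "gpu_name": "GeForce GTX 1080",  # GPU model (hypothetical)
#       "gpu_memory_total": 8.0,         # total GPU memory, in gigabytes
#   }
#   df = pandas.read_csv("trainlog.csv")  # hypothetical training log
#   figure = _hardware_utilisation(df, const)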

def base_analysis(log, constants, output_pdf, verbose, **kwargs):
    """Create base train_analysis function."""

@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Analyzes a training log and produces various plots:

       $ bob binseg train-analysis -vv log.csv constants.csv

""",
)
@click.argument(
    "log",
    type=click.Path(dir_okay=False, exists=True),
)
@click.argument(
    "constants",
    type=click.Path(dir_okay=False, exists=True),
)
@click.option(
    "--output-pdf",
    "-o",
    help="Name of the output file to dump",
    required=True,
    show_default=True,
    default="trainlog.pdf",
)
@verbosity_option(cls=ResourceOption)
def train_analysis(
    log,
    constants,
    output_pdf,
    verbose,
    **kwargs,
):
    """Analyze the training logs for loss evolution and resource utilisation."""

    constants = pandas.read_csv(constants)
    constants = dict(zip(constants.keys(), constants.values[0]))
    data = pandas.read_csv(log)

    # makes sure the directory where the output PDF will be saved exists
    dirname = os.path.dirname(os.path.realpath(output_pdf))
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    # now, do the analysis
    with PdfPages(output_pdf) as pdf:

        figure = _loss_evolution(data)
        pdf.savefig(figure)
        plt.close(figure)

        figure = _hardware_utilisation(data, constants)
        pdf.savefig(figure)
        plt.close(figure)
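# A minimal sketch of the inputs train_analysis() expects, inferred from the
# parsing code above (the exact headers of "constants.csv" are assumptions
# based on the keys read by _hardware_utilisation()):
#
#   constants.csv -- a header row plus a single data row, e.g.:
#
#       cpu_count,cpu_memory_total,gpu_name,gpu_memory_total
#       8,62.5,GeForce GTX 1080,8.0
#
#   log.csv -- one row per epoch, with at least the "epoch" and "loss" columns,
#   plus the optional "validation_loss", "extra_validation_losses",
#   "cpu_percent", "cpu_rss", "gpu_percent" and "gpu_memory_percent" columns
#   plotted above.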