1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import os
5import sys
6import csv
7import time
8import shutil
9import datetime
10import contextlib
11import distutils.version
12
13import torch
14from tqdm import tqdm
15
16from ..utils.measure import SmoothedValue
17from ..utils.summary import summary
18from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log
19
20import logging
21
22logger = logging.getLogger(__name__)
23
24PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
25
26
@contextlib.contextmanager
def torch_evaluation(model):
    """Context manager to turn ON/OFF model evaluation

    This context manager will turn evaluation mode ON on entry and turn it
    OFF when exiting the ``with`` statement block, even when the block exits
    through an exception.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)


    Yields
    ------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    """

    model.eval()
    try:
        yield model
    finally:
        # restore training mode even if the with-block raised, so a failed
        # validation pass does not leave the model permanently in eval mode
        model.train()
53
54
def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    checkpointer,
    checkpoint_period: int,
    device: str,
    arguments: dict,
    output_folder: str,
    criterion_valid = None
):
    """
    Fits a CNN model using supervised learning and save it to disk.

    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. pasa)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    checkpointer : :py:class:`bob.med.tb.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : str
        device to use ``'cpu'`` or ``cuda:0``

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
        specific loss function for the validation set
    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        # ``n`` (total parameter count) is reused further below when writing
        # the static constants CSV
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        # keep a single backup ("~" suffix) of any previous run's constants
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        logdata += (("model_size", n),)
        # ``logdata`` is a tuple of (name, value) pairs - write a single row
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))

    # Log continuous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    # only back up (and restart) the training log when starting from scratch;
    # a resumed training (epoch != 0) appends to the existing log instead
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    # "a+" so resumed trainings keep extending the same CSV file
    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        # when resuming from a checkpoint, optimizer state tensors may still
        # live on a different device - move them alongside the model
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        # disable=None lets tqdm auto-disable itself on non-interactive jobs
        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            # smooth the training loss over (at most) one epoch of batches
            losses = SmoothedValue(len(data_loader))
            epoch = epoch + 1  # epochs are reported and stored 1-based
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):

                # data forwarding on the existing network
                # NOTE(review): assumes the loader yields tuples with images
                # at position 1 and labels at position 2 - confirm against
                # the dataset implementation
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                labels = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )

                # Increase labels dimension if too low
                # Allows single and multiclass usage
                if labels.ndim == 1:
                    labels = torch.reshape(labels, (labels.shape[0], 1))

                outputs = model(images)

                # loss evaluation and learning (backward step)
                # NOTE(review): labels cast to float64, presumably to match
                # the criterion's expected target dtype - verify
                loss = criterion(outputs, labels.double())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.update(loss)
                logger.debug(f"batch loss: {loss.item()}")


            # calculates the validation loss if necessary
            valid_losses = None
            if valid_loader is not None:

                # gradients disabled and model temporarily switched to
                # evaluation mode for the whole validation pass
                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        labels = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )

                        # Increase labels dimension if too low
                        # Allows single and multiclass usage
                        if labels.ndim == 1:
                            labels = torch.reshape(labels, (labels.shape[0], 1))

                        outputs = model(images)

                        # fall back to the training criterion when no
                        # dedicated validation criterion was provided
                        if criterion_valid is not None:
                            loss = criterion_valid(outputs, labels.double())
                        else:
                            loss = criterion(outputs, labels.double())
                        valid_losses.update(loss)

            # periodic checkpoint every ``checkpoint_period`` epochs
            # (disabled when checkpoint_period is 0)
            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            # checkpoint whenever a new best validation loss is observed
            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    f"Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save(f"model_lowest_valid_loss", **arguments)

            # always checkpoint the final epoch
            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(k for k in logdata))
            logfile.flush()  # keep the CSV current even if training crashes
            # echo only the first four fields (epoch/total_time/eta/avg loss)
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

    total_training_time = time.time() - start_training_time
    logger.info(
        f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
    )