1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import csv
5import datetime
6import logging
7import os
8import shutil
9import sys
10import time
11
12import numpy
13import torch
14
15from tqdm import tqdm
16
17from ..utils.measure import SmoothedValue
18from ..utils.resources import cpu_constants, cpu_log, gpu_constants, gpu_log
19from ..utils.summary import summary
20from .trainer import PYTORCH_GE_110, torch_evaluation
21
22logger = logging.getLogger(__name__)
23
24
def sharpen(x, T):
    """Sharpens a probability distribution by temperature scaling.

    Raises every entry to the power ``1/T`` and renormalizes along
    dimension 1 so each row sums to one (lower ``T`` means sharper).

    Parameters
    ----------

    x : :py:class:`torch.Tensor`
        probabilities to sharpen

    T : float
        sharpening temperature

    Returns
    -------

    :py:class:`torch.Tensor`
        the sharpened distribution
    """
    powered = x ** (1 / T)
    return powered / powered.sum(dim=1, keepdim=True)
28
29
30def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
31 """Applies mix up as described in [MIXMATCH_19].
32
33 Parameters
34 ----------
35 alpha : float
36
37 input : :py:class:`torch.Tensor`
38
39 target : :py:class:`torch.Tensor`
40
41 unlabelled_input : :py:class:`torch.Tensor`
42
43 unlabled_target : :py:class:`torch.Tensor`
44
45
46 Returns
47 -------
48
49 list
50
51 """
52
53 with torch.no_grad():
54 l = numpy.random.beta(alpha, alpha) # Eq (8)
55 l = max(l, 1 - l) # Eq (9)
56 # Shuffle and concat. Alg. 1 Line: 12
57 w_inputs = torch.cat([input, unlabelled_input], 0)
58 w_targets = torch.cat([target, unlabled_target], 0)
59 idx = torch.randperm(w_inputs.size(0)) # get random index
60
61 # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
62 input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
63 target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]
64
65 # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
66 unlabelled_input_mixedup = (
67 l * unlabelled_input
68 + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
69 )
70 unlabled_target_mixedup = (
71 l * unlabled_target
72 + (1 - l) * w_targets[idx[: len(unlabled_target)]]
73 )
74 return (
75 input_mixedup,
76 target_mixedup,
77 unlabelled_input_mixedup,
78 unlabled_target_mixedup,
79 )
80
81
def square_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u`` with a quadratic schedule

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp up factor, clipped to ``[0.0, 1.0]``
    """

    if rampup_length == 0:
        # no ramp-up requested: full weight immediately
        return 1.0
    fraction = (current / float(rampup_length)) ** 2
    return float(numpy.clip(fraction, 0.0, 1.0))
106
107
def linear_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u`` with a linear schedule

    Parameters
    ----------
    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor: float
        ramp up factor, clipped to ``[0.0, 1.0]``

    """
    if rampup_length == 0:
        # no ramp-up requested: full weight immediately
        return 1.0
    return float(numpy.clip(current / rampup_length, 0.0, 1.0))
131
132
def guess_labels(unlabelled_images, model):
    """
    Calculate the average predictions over 2 flip augmentations

    Runs the model on the original batch and on two flipped versions
    (one per spatial axis), undoes each flip on the corresponding
    output, and averages the three sigmoid probability maps.

    Parameters
    ----------

    unlabelled_images : :py:class:`torch.Tensor`
        ``[n,c,h,w]``

    model : :py:class:`torch.nn.Module`
        network used to produce the predictions

    Returns
    -------

    shape : :py:class:`torch.Tensor`
        averaged guess, ``[n,c,h,w]``

    """
    with torch.no_grad():
        # plain prediction; unsqueeze adds a leading "augmentation" axis
        plain = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
        # flip along input dim 2, predict, then flip the prediction back
        # (dim 3 of the unsqueezed 5-d tensor is input dim 2)
        unflipped_a = (
            torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0).flip(3)
        )
        # flip along input dim 3, predict, then flip the prediction back
        # (dim 4 of the unsqueezed 5-d tensor is input dim 3)
        unflipped_b = (
            torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0).flip(4)
        )
        # average over the augmentation axis
        return torch.cat([plain, unflipped_a, unflipped_b], 0).mean(0)
164
165
def run(
    model,
    data_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    rampup_length,
):
    """
    Fits an FCN model using semi-supervised learning and saves it to disk.


    This method supports periodic checkpointing and the output of a
    CSV-formatted log with the evolution of some figures during training.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to train the model

    valid_loader : :py:class:`torch.utils.data.DataLoader`
        To be used to validate the model and enable automatic checkpointing.
        If set to ``None``, then do not validate it.

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
        checkpointer implementation

    checkpoint_period : int
        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
        not save intermediary checkpoints

    device : str
        device to use ``'cpu'`` or ``cuda:0``

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    rampup_length : int
        rampup epochs

    """

    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    if device != "cpu":
        # asserts we do have a GPU
        assert bool(gpu_constants()), (
            f"Device set to '{device}', but cannot "
            f"find a GPU (maybe nvidia-smi is not installed?)"
        )

    os.makedirs(output_folder, exist_ok=True)

    # Save model summary
    summary_path = os.path.join(output_folder, "model_summary.txt")
    logger.info(f"Saving model summary at {summary_path}...")
    with open(summary_path, "wt") as f:
        r, n = summary(model)
        logger.info(f"Model has {n} parameters...")
        f.write(r)

    # write static information to a CSV file, backing up any previous copy
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    if os.path.exists(static_logfile_name):
        backup = static_logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)
    with open(static_logfile_name, "w", newline="") as f:
        logdata = cpu_constants()
        if device != "cpu":
            logdata += gpu_constants()
        # "n" is the parameter count computed by summary() above
        logdata += (("model_size", n),)
        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
        logwriter.writeheader()
        logwriter.writerow(dict(k for k in logdata))

    # Log continous information to (another) file
    logfile_name = os.path.join(output_folder, "trainlog.csv")

    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    logfile_fields = (
        "epoch",
        "total_time",
        "eta",
        "average_loss",
        "median_loss",
        "labelled_median_loss",
        "unlabelled_median_loss",
        "learning_rate",
    )
    if valid_loader is not None:
        logfile_fields += ("validation_average_loss", "validation_median_loss")
    logfile_fields += tuple([k[0] for k in cpu_log()])
    if device != "cpu":
        logfile_fields += tuple([k[0] for k in gpu_log()])

    # the lowest validation loss obtained so far - this value is updated only
    # if a validation set is available
    lowest_validation_loss = sys.float_info.max

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        if arguments["epoch"] == 0:
            logwriter.writeheader()

        model.train()  # set training mode

        model.to(device)  # set/cast parameters to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Total training timer
        start_training_time = time.time()

        for epoch in tqdm(
            range(start_epoch, max_epoch),
            desc="epoch",
            leave=False,
            disable=None,
        ):
            # older PyTorch requires stepping the scheduler *before* the
            # optimizer; newer versions (>= 1.1.0) step it after
            if not PYTORCH_GE_110:
                scheduler.step()
            losses = SmoothedValue(len(data_loader))
            labelled_loss = SmoothedValue(len(data_loader))
            unlabelled_loss = SmoothedValue(len(data_loader))
            epoch = epoch + 1
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            # progress bar only on interactive jobs
            for samples in tqdm(
                data_loader, desc="batch", leave=False, disable=None
            ):

                # data forwarding on the existing network

                # labelled
                images = samples[1].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                ground_truths = samples[2].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                unlabelled_images = samples[4].to(
                    device=device, non_blocking=torch.cuda.is_available()
                )
                # labelled outputs
                outputs = model(images)
                unlabelled_outputs = model(unlabelled_images)
                # guessed unlabelled outputs
                unlabelled_ground_truths = guess_labels(
                    unlabelled_images, model
                )

                # loss evaluation and learning (backward step); the
                # unlabelled term is weighted by a ramp-up factor that
                # grows with the epoch number
                ramp_up_factor = square_rampup(
                    epoch, rampup_length=rampup_length
                )

                # note: no support for masks...
                loss, ll, ul = criterion(
                    outputs,
                    ground_truths,
                    unlabelled_outputs,
                    unlabelled_ground_truths,
                    ramp_up_factor,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.update(loss)
                labelled_loss.update(ll)
                unlabelled_loss.update(ul)
                logger.debug(f"batch loss: {loss.item()}")

            if PYTORCH_GE_110:
                scheduler.step()

            # calculates the validation loss if necessary
            # note: validation does not comprise "unlabelled" losses
            valid_losses = None
            if valid_loader is not None:

                with torch.no_grad(), torch_evaluation(model):

                    valid_losses = SmoothedValue(len(valid_loader))
                    for samples in tqdm(
                        valid_loader, desc="valid", leave=False, disable=None
                    ):
                        # data forwarding on the existing network
                        images = samples[1].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        ground_truths = samples[2].to(
                            device=device,
                            non_blocking=torch.cuda.is_available(),
                        )
                        # use an all-ones mask when the sample carries none
                        masks = (
                            torch.ones_like(ground_truths)
                            if len(samples) < 4
                            else samples[3].to(
                                device=device,
                                non_blocking=torch.cuda.is_available(),
                            )
                        )

                        outputs = model(images)
                        # NOTE(review): criterion is called here with 3
                        # arguments vs. 5 during training -- confirm the
                        # loss object supports both signatures
                        loss = criterion(outputs, ground_truths, masks)
                        valid_losses.update(loss)

            if checkpoint_period and (epoch % checkpoint_period == 0):
                checkpointer.save(f"model_{epoch:03d}", **arguments)

            if (
                valid_losses is not None
                and valid_losses.avg < lowest_validation_loss
            ):
                lowest_validation_loss = valid_losses.avg
                logger.info(
                    f"Found new low on validation set:"
                    f" {lowest_validation_loss:.6f}"
                )
                checkpointer.save("model_lowest_valid_loss", **arguments)

            if epoch >= max_epoch:
                checkpointer.save("model_final", **arguments)

            # computes ETA (estimated time-of-arrival; end of training) taking
            # into consideration previous epoch performance
            epoch_time = time.time() - start_epoch_time
            eta_seconds = epoch_time * (max_epoch - epoch)
            current_time = time.time() - start_training_time

            logdata = (
                ("epoch", f"{epoch}"),
                (
                    "total_time",
                    f"{datetime.timedelta(seconds=int(current_time))}",
                ),
                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
                ("average_loss", f"{losses.avg:.6f}"),
                ("median_loss", f"{losses.median:.6f}"),
                ("labelled_median_loss", f"{labelled_loss.median:.6f}"),
                ("unlabelled_median_loss", f"{unlabelled_loss.median:.6f}"),
                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
            )
            if valid_losses is not None:
                logdata += (
                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
                )
            logdata += cpu_log()
            # bugfix: gpu_log() was previously appended twice here; the
            # duplicate entries produced keys absent from logfile_fields,
            # making DictWriter.writerow() raise ValueError on GPU runs
            if device != "cpu":
                logdata += gpu_log()

            logwriter.writerow(dict(k for k in logdata))
            logfile.flush()
            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

    total_training_time = time.time() - start_training_time
    logger.info(
        f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
    )