1#!/usr/bin/env python
2# coding=utf-8
3
4import logging
5import multiprocessing
6import sys
7
8import torch
9
10from torch.utils.data import DataLoader
11
12from ..utils.checkpointer import Checkpointer
13from .common import set_seeds, setup_pytorch_device
14
15logger = logging.getLogger(__name__)
16
17
def _resolve_dataset_splits(dataset):
    """Resolve the training/validation splits from the ``dataset`` argument.

    Returns a 3-tuple ``(train, valid, extra_valids)`` where ``valid`` may be
    ``None`` and ``extra_valids`` is a (possibly empty) list.  If ``dataset``
    is not a dict, it is used directly as the training set.
    """

    if not isinstance(dataset, dict):
        return dataset, None, []

    if "__train__" in dataset:
        logger.info("Found (dedicated) '__train__' set for training")
        train = dataset["__train__"]
    else:
        train = dataset["train"]

    valid = None
    if "__valid__" in dataset:
        logger.info("Found (dedicated) '__valid__' set for validation")
        logger.info("Will checkpoint lowest loss model on validation set")
        valid = dataset["__valid__"]

    extra_valids = []
    if "__extra_valid__" in dataset:
        if not isinstance(dataset["__extra_valid__"], list):
            raise RuntimeError(
                f"If present, dataset['__extra_valid__'] must be a list, "
                f"but you passed a {type(dataset['__extra_valid__'])}, "
                f"which is invalid."
            )
        logger.info(
            f"Found {len(dataset['__extra_valid__'])} extra validation "
            f"set(s) to be tracked during training"
        )
        logger.info(
            "Extra validation sets are NOT used for model checkpointing!"
        )
        extra_valids = dataset["__extra_valid__"]

    return train, valid, extra_valids


def _multiproc_settings(parallel):
    """Translate the ``parallel`` CLI value into DataLoader worker kwargs.

    ``parallel < 0`` disables multiprocessing (0 workers); ``parallel == 0``
    uses one worker per CPU; any positive value is used as-is.  On macOS a
    "spawn" multiprocessing context is forced, since the default "fork" start
    method is unsafe there.
    """

    settings = dict()
    if parallel < 0:
        settings["num_workers"] = 0
    else:
        settings["num_workers"] = parallel or multiprocessing.cpu_count()

    if settings["num_workers"] > 0 and sys.platform == "darwin":
        settings["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    return settings


def base_train(
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    detection,
    verbose,
    **kwargs,
):
    """Create base function for training segmentation / detection task.

    Builds the training/validation data loaders, restores any existing
    checkpoint from ``output_folder``, and runs the appropriate training
    engine (detection or binary segmentation, selected by ``detection``).

    Parameters
    ----------
    model : torch.nn.Module
        Network to be trained.
    optimizer : torch.optim.Optimizer
        Optimizer driving the training.
    scheduler :
        Learning-rate scheduler (checkpointed alongside the optimizer).
    output_folder : str
        Folder where checkpoints and logs are written.
    epochs : int
        Maximum number of epochs to train for.
    batch_size : int
        Logical batch size; must be divisible by ``batch_chunk_count``.
    batch_chunk_count : int
        Number of chunks each batch is split into (gradient accumulation).
    drop_incomplete_batch : bool
        If ``True``, drops the last (incomplete) training batch.
    criterion :
        Loss function passed through to the training engine.
    dataset : dict or torch.utils.data.Dataset
        Either a dataset used directly for training, or a dict with keys
        ``train`` (or ``__train__``), and optionally ``__valid__`` and
        ``__extra_valid__`` (a list).
    checkpoint_period : int
        Save a checkpoint every this many epochs.
    device : str
        Device specification (resolved via ``setup_pytorch_device``).
    seed : int
        Random seed for reproducibility.
    parallel : int
        DataLoader worker setting; see ``_multiproc_settings``.
    monitoring_interval : int or float
        Resource-monitoring interval, passed to the engine.
    detection : bool
        If ``True``, use the detection engine (with a tuple-zipping collate
        function); otherwise the binary-segmentation engine.
    verbose : int
        Verbosity level (currently unused here; kept for CLI compatibility).

    Raises
    ------
    RuntimeError
        If ``batch_size`` is not divisible by ``batch_chunk_count``, or if
        ``dataset['__extra_valid__']`` is present but not a list.
    """

    def _collate_fn(batch):
        # Detection samples are (image, target) pairs; re-group them into
        # parallel tuples as expected by the detection engine.
        return tuple(zip(*batch))

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    (
        use_dataset,
        validation_dataset,
        extra_validation_datasets,
    ) = _resolve_dataset_splits(dataset)

    if batch_size % batch_chunk_count != 0:
        # batch_size must be divisible by batch_chunk_count.
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    batch_chunk_size = batch_size // batch_chunk_count

    # Kwargs shared by all three data loaders; only the detection engine
    # needs the custom collate function.
    loader_kwargs = dict(
        batch_size=batch_chunk_size,
        pin_memory=torch.cuda.is_available(),
        **_multiproc_settings(parallel),
    )

    if detection:
        from ...detect.engine.trainer import run

        loader_kwargs["collate_fn"] = _collate_fn
    else:
        from ...binseg.engine.trainer import run

    data_loader = DataLoader(
        dataset=use_dataset,
        shuffle=True,
        drop_last=drop_incomplete_batch,
        **loader_kwargs,
    )

    valid_loader = None
    if validation_dataset is not None:
        valid_loader = DataLoader(
            dataset=validation_dataset,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )

    extra_valid_loaders = [
        DataLoader(
            dataset=k,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )
        for k in extra_validation_datasets
    ]

    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)

    # Resume from an existing checkpoint if one is found in output_folder;
    # otherwise start from epoch 0.
    arguments = {}
    arguments["epoch"] = 0
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model,
        data_loader,
        valid_loader,
        extra_valid_loaders,
        optimizer,
        criterion,
        scheduler,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
        monitoring_interval,
        batch_chunk_count,
    )