#!/usr/bin/env python
# coding=utf-8

import logging
import multiprocessing
import sys

import torch

from torch.utils.data import DataLoader

from ..utils.checkpointer import Checkpointer
from .common import set_seeds, setup_pytorch_device

logger = logging.getLogger(__name__)


def base_train(
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    detection,
    verbose,
    **kwargs,
):
    """Run the base training loop for a segmentation or detection task."""

    def _collate_fn(batch):
        # Transpose a list of per-sample tuples into per-field tuples, as
        # required by detection models whose targets vary in size per image.
        return tuple(zip(*batch))

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    use_dataset = dataset
    validation_dataset = None
    extra_validation_datasets = []
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

        if "__extra_valid__" in dataset:
            if not isinstance(dataset["__extra_valid__"], list):
                raise RuntimeError(
                    f"If present, dataset['__extra_valid__'] must be a list, "
                    f"but you passed a {type(dataset['__extra_valid__'])}, "
                    f"which is invalid."
                )
            logger.info(
                f"Found {len(dataset['__extra_valid__'])} extra validation "
                f"set(s) to be tracked during training"
            )
            logger.info(
                "Extra validation sets are NOT used for model checkpointing!"
            )
            extra_validation_datasets = dataset["__extra_valid__"]

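    # Illustration (not in the original source): a dataset dict handled by
    # the block above could look like the following, where every value is a
    # hypothetical torch.utils.data.Dataset instance:
    #
    #     dataset = {
    #         "__train__": train_set_augmented,  # preferred training set
    #         "train": train_set,                # fallback training set
    #         "__valid__": valid_set,            # drives model checkpointing
    #         "__extra_valid__": [extra_set],    # tracked, never checkpointed
    #     }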

    # PyTorch dataloader: translate ``parallel`` into DataLoader
    # multiprocessing options.
    multiproc_kwargs = dict()
    if parallel < 0:
        # Negative: load data in the main process, without worker processes.
        multiproc_kwargs["num_workers"] = 0
    else:
        # Zero: one worker per CPU core; positive: exactly that many workers.
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        # Force the "spawn" start method on macOS, where fork-based
        # DataLoader workers are unsafe.
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    batch_chunk_size = batch_size
    if batch_size % batch_chunk_count != 0:
        # batch_size must be divisible by batch_chunk_count.
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    else:
        batch_chunk_size = batch_size // batch_chunk_count

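    # Illustration (not in the original source): with --batch-size 16 and
    # --batch-chunk-count 4, each DataLoader batch holds 16 / 4 = 4 samples,
    # and the trainer is expected to accumulate gradients over 4 chunks
    # before each optimizer step, emulating an effective batch size of 16.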

    # The two branches below differ only in the trainer implementation that
    # is imported and, for detection, in the use of _collate_fn to batch
    # variable-size targets.
    if detection:
        from ...detect.engine.trainer import run

        data_loader = DataLoader(
            dataset=use_dataset,
            batch_size=batch_chunk_size,
            shuffle=True,
            drop_last=drop_incomplete_batch,
            pin_memory=torch.cuda.is_available(),
            collate_fn=_collate_fn,
            **multiproc_kwargs,
        )

        valid_loader = None
        if validation_dataset is not None:
            valid_loader = DataLoader(
                dataset=validation_dataset,
                batch_size=batch_chunk_size,
                shuffle=False,
                drop_last=False,
                pin_memory=torch.cuda.is_available(),
                collate_fn=_collate_fn,
                **multiproc_kwargs,
            )

        extra_valid_loaders = [
            DataLoader(
                dataset=k,
                batch_size=batch_chunk_size,
                shuffle=False,
                drop_last=False,
                pin_memory=torch.cuda.is_available(),
                collate_fn=_collate_fn,
                **multiproc_kwargs,
            )
            for k in extra_validation_datasets
        ]
    else:
        from ...binseg.engine.trainer import run

        data_loader = DataLoader(
            dataset=use_dataset,
            batch_size=batch_chunk_size,
            shuffle=True,
            drop_last=drop_incomplete_batch,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )

        valid_loader = None
        if validation_dataset is not None:
            valid_loader = DataLoader(
                dataset=validation_dataset,
                batch_size=batch_chunk_size,
                shuffle=False,
                drop_last=False,
                pin_memory=torch.cuda.is_available(),
                **multiproc_kwargs,
            )

        extra_valid_loaders = [
            DataLoader(
                dataset=k,
                batch_size=batch_chunk_size,
                shuffle=False,
                drop_last=False,
                pin_memory=torch.cuda.is_available(),
                **multiproc_kwargs,
            )
            for k in extra_validation_datasets
        ]

    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)

    # Resume from the last checkpoint in output_folder, if one exists:
    # checkpointer.load() restores model/optimizer/scheduler state and
    # returns extra bookkeeping data, such as the epoch to restart from.
    arguments = {}
    arguments["epoch"] = 0
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    run(
        model,
        data_loader,
        valid_loader,
        extra_valid_loaders,
        optimizer,
        criterion,
        scheduler,
        checkpointer,
        checkpoint_period,
        device,
        arguments,
        output_folder,
        monitoring_interval,
        batch_chunk_count,
    )
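
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how a CLI command might
# invoke base_train().  ``my_model``, ``my_optimizer``, ``my_scheduler``,
# ``my_criterion`` and ``my_dataset`` are hypothetical placeholders.
#
#     base_train(
#         model=my_model,
#         optimizer=my_optimizer,
#         scheduler=my_scheduler,
#         output_folder="results",
#         epochs=100,
#         batch_size=16,
#         batch_chunk_count=1,
#         drop_incomplete_batch=False,
#         criterion=my_criterion,
#         dataset=my_dataset,          # Dataset or dict (see note above)
#         checkpoint_period=10,
#         device="cpu",
#         seed=42,
#         parallel=-1,                 # load data in the main process
#         monitoring_interval=5.0,
#         detection=False,             # binary segmentation task
#         verbose=0,
#     )
# ---------------------------------------------------------------------------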