Coverage for /scratch/builds/bob/bob.ip.binseg/miniconda/conda-bld/bob.ip.binseg_1673966692152/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_p/lib/python3.10/site-packages/bob/ip/common/script/experiment.py: 81%

27 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import logging 

5import os 

6import shutil 

7 

8import click 

9 

10from .common import save_sh_command 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@click.pass_context
def base_experiment(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    second_annotator,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    overlayed,
    steps,
    plot_limits,
    detection,
    verbose,
    **kwargs,
):
    """Chain training, training-log analysis and model analysis.

    Before anything else, ``save_sh_command`` is called on
    ``<output_folder>/command.sh``; a ``command.sh`` already present there is
    first moved to ``command.sh~`` (replacing any older ``command.sh~``).

    Three commands are then invoked, in order, through the click context:

    1. ``base_train`` with ``<output_folder>/model`` as its output folder;
    2. ``base_train_analysis`` with the paths of ``trainlog.csv``,
       ``constants.csv`` and ``trainlog.pdf`` inside that same folder passed
       as ``log``, ``constants`` and ``output_pdf``;
    3. ``base_analyze`` with ``output_folder`` as its output folder and, as
       ``weight``, ``model_lowest_valid_loss.pth`` when that file exists in
       the training folder, ``model_final_epoch.pth`` otherwise.

    ``ctx`` is the click context injected by ``click.pass_context``.  All
    other named parameters are forwarded unchanged to the commands above:
    ``second_annotator``, ``overlayed``, ``steps`` and ``plot_limits`` go to
    ``base_analyze`` only; ``optimizer``, ``scheduler``, ``epochs``,
    ``batch_chunk_count``, ``drop_incomplete_batch``, ``criterion``,
    ``checkpoint_period``, ``seed`` and ``monitoring_interval`` go to
    ``base_train`` only.  Extra keyword arguments (``**kwargs``) are accepted
    and not used.
    """
    script_path = os.path.join(output_folder, "command.sh")
    if os.path.exists(script_path):
        # keep exactly one previous copy of the script around
        previous_script = f"{script_path}~"
        if os.path.exists(previous_script):
            os.unlink(previous_script)
        shutil.move(script_path, previous_script)
    save_sh_command(script_path)

    # stage 1: training
    logger.info("Started training")

    from .train import base_train

    training_dir = os.path.join(output_folder, "model")
    training_options = dict(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        output_folder=training_dir,
        epochs=epochs,
        batch_size=batch_size,
        batch_chunk_count=batch_chunk_count,
        drop_incomplete_batch=drop_incomplete_batch,
        criterion=criterion,
        dataset=dataset,
        checkpoint_period=checkpoint_period,
        device=device,
        seed=seed,
        parallel=parallel,
        monitoring_interval=monitoring_interval,
        detection=detection,
        verbose=verbose,
    )
    ctx.invoke(base_train, **training_options)
    logger.info("Ended training")

    # stage 2: analysis of the training logs
    from .train_analysis import base_train_analysis

    ctx.invoke(
        base_train_analysis,
        log=os.path.join(training_dir, "trainlog.csv"),
        constants=os.path.join(training_dir, "constants.csv"),
        output_pdf=os.path.join(training_dir, "trainlog.pdf"),
        verbose=verbose,
    )

    # stage 3: analysis of the trained model
    from .analyze import base_analyze

    # the checkpoint with the lowest validation loss is preferred; when it
    # is absent, the checkpoint of the final epoch is used instead
    best_checkpoint = os.path.join(training_dir, "model_lowest_valid_loss.pth")
    last_checkpoint = os.path.join(training_dir, "model_final_epoch.pth")
    weights = (
        best_checkpoint if os.path.exists(best_checkpoint) else last_checkpoint
    )

    analysis_options = dict(
        model=model,
        output_folder=output_folder,
        batch_size=batch_size,
        dataset=dataset,
        second_annotator=second_annotator,
        device=device,
        overlayed=overlayed,
        weight=weights,
        steps=steps,
        parallel=parallel,
        plot_limits=plot_limits,
        detection=detection,
        verbose=verbose,
    )
    ctx.invoke(base_analyze, **analysis_options)