1#!/usr/bin/env python
2# coding=utf-8
3
4import logging
5import os
6import shutil
7
8import click
9
10from .common import save_sh_command
11
12logger = logging.getLogger(__name__)
13
14
@click.pass_context
def base_experiment(
    ctx,
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    dataset,
    second_annotator,
    checkpoint_period,
    device,
    seed,
    parallel,
    monitoring_interval,
    overlayed,
    steps,
    plot_limits,
    detection,
    verbose,
    **kwargs,
):
    """Run a full experiment: train, analyze the training log, analyze the model.

    The function chains three sub-commands through ``ctx.invoke``:

    1. ``base_train`` -- writes its results to ``<output_folder>/model``.
    2. ``base_train_analysis`` -- reads ``trainlog.csv`` and ``constants.csv``
       from that folder and is asked to write ``trainlog.pdf`` next to them.
    3. ``base_analyze`` -- runs on ``output_folder`` using the weights file
       selected below.

    Before anything else, ``save_sh_command`` is called on
    ``<output_folder>/command.sh``; if such a file already exists it is first
    moved to ``command.sh~`` (replacing any older backup).

    Parameters
    ----------
    ctx
        The click context injected by ``click.pass_context``; used to invoke
        the sub-commands.
    model, batch_size, dataset, device, parallel, detection, verbose
        Forwarded to both ``base_train`` and ``base_analyze`` (``verbose`` is
        also forwarded to ``base_train_analysis``).
    optimizer, scheduler, epochs, batch_chunk_count, drop_incomplete_batch,
    criterion, checkpoint_period, seed, monitoring_interval
        Forwarded to ``base_train`` only.
    second_annotator, overlayed, steps, plot_limits
        Forwarded to ``base_analyze`` only.
    output_folder
        Root folder of the experiment.
    **kwargs
        Accepted but not used by this function.
    """
    # Keep a single backup generation of any previously saved command script.
    script_path = os.path.join(output_folder, "command.sh")
    if os.path.exists(script_path):
        previous_script = script_path + "~"
        if os.path.exists(previous_script):
            os.unlink(previous_script)
        shutil.move(script_path, previous_script)
    save_sh_command(script_path)

    # --- step 1: training ---------------------------------------------------
    logger.info("Started training")

    from .train import base_train

    model_dir = os.path.join(output_folder, "model")
    training_options = {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "output_folder": model_dir,
        "epochs": epochs,
        "batch_size": batch_size,
        "batch_chunk_count": batch_chunk_count,
        "drop_incomplete_batch": drop_incomplete_batch,
        "criterion": criterion,
        "dataset": dataset,
        "checkpoint_period": checkpoint_period,
        "device": device,
        "seed": seed,
        "parallel": parallel,
        "monitoring_interval": monitoring_interval,
        "detection": detection,
        "verbose": verbose,
    }
    ctx.invoke(base_train, **training_options)
    logger.info("Ended training")

    # --- step 2: analysis of the training log -------------------------------
    from .train_analysis import base_train_analysis

    log_analysis_options = {
        "log": os.path.join(model_dir, "trainlog.csv"),
        "constants": os.path.join(model_dir, "constants.csv"),
        "output_pdf": os.path.join(model_dir, "trainlog.pdf"),
        "verbose": verbose,
    }
    ctx.invoke(base_train_analysis, **log_analysis_options)

    # --- step 3: analysis of the trained model ------------------------------
    from .analyze import base_analyze

    # Prefer the checkpoint with the lowest validation loss; when that file is
    # absent, fall back to the checkpoint saved at the final epoch.
    best_weights = os.path.join(model_dir, "model_lowest_valid_loss.pth")
    final_weights = os.path.join(model_dir, "model_final_epoch.pth")
    weights = best_weights if os.path.exists(best_weights) else final_weights

    analysis_options = {
        "model": model,
        "output_folder": output_folder,
        "batch_size": batch_size,
        "dataset": dataset,
        "second_annotator": second_annotator,
        "device": device,
        "overlayed": overlayed,
        "weight": weights,
        "steps": steps,
        "parallel": parallel,
        "plot_limits": plot_limits,
        "detection": detection,
        "verbose": verbose,
    }
    ctx.invoke(base_analyze, **analysis_options)