#!/usr/bin/env python
# coding=utf-8

import logging
import os

import click

from .common import save_sh_command

logger = logging.getLogger(__name__)


@click.pass_context
def base_analyze(
    ctx,
    model,
    output_folder,
    batch_size,
    dataset,
    second_annotator,
    device,
    overlayed,
    weight,
    steps,
    parallel,
    plot_limits,
    verbose,
    detection,
    **kwargs,
):
    """Run a complete analysis (prediction, evaluation and comparison)
    for segmentation or detection tasks."""
    command_sh = os.path.join(output_folder, "command.sh")
    if not os.path.exists(command_sh):
        # only save if the experiment has not already recorded its
        # command-line
        save_sh_command(command_sh)

    # Prediction
    logger.info("Started prediction")

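    # imported locally, so the prediction machinery is only loaded when
    # this function actually runs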
    from .predict import base_predict

    predictions_folder = os.path.join(output_folder, "predictions")
    overlayed_folder = (
        os.path.join(output_folder, "overlayed", "predictions")
        if overlayed
        else None
    )

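    # chain into the prediction step, forwarding the relevant options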
    ctx.invoke(
        base_predict,
        output_folder=predictions_folder,
        model=model,
        dataset=dataset,
        batch_size=batch_size,
        device=device,
        weight=weight,
        overlayed=overlayed_folder,
        parallel=parallel,
        detection=detection,
        verbose=verbose,
    )
    logger.info("Ended prediction")

    # Evaluation
    logger.info("Started evaluation")

    from .evaluate import base_evaluate

    overlayed_folder = (
        os.path.join(output_folder, "overlayed", "analysis")
        if overlayed
        else None
    )

    # choose the threshold for evaluation and comparison: use the
    # validation split if available, else the training split, else a
    # fixed value of 0.5
    if "validation" in dataset:
        threshold = "validation"
    elif "train" in dataset:
        threshold = "train"
    else:
        threshold = 0.5
    logger.info(f"Setting --threshold={threshold}...")

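    # the per-split CSV tables produced here feed the comparison step below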
    analysis_folder = os.path.join(output_folder, "analysis")
    ctx.invoke(
        base_evaluate,
        output_folder=analysis_folder,
        predictions_folder=predictions_folder,
        dataset=dataset,
        second_annotator=second_annotator,
        overlayed=overlayed_folder,
        threshold=threshold,
        steps=steps,
        parallel=parallel,
        detection=detection,
        verbose=verbose,
    )

    logger.info("Ended evaluation")

    # Comparison
    logger.info("Started comparison")

    # compare performances on the various sets
    from .compare import base_compare

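    # ``systems`` interleaves display labels and CSV paths, matching the
    # ``label_path`` argument of ``base_compare``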
    systems = []
    for k in dataset:
        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be compared)")
            continue
        candidate = os.path.join(analysis_folder, f"{k}.csv")
        if not os.path.exists(candidate):
            logger.error(
                f"Skipping dataset '{k}' "
                f"(candidate CSV file `{candidate}` does not exist!)"
            )
            continue
        systems += [k, candidate]
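    # also include second-annotator analyses in the comparison, when
    # available and consistent with the corresponding dataset split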
    if second_annotator is not None:
        for k, v in second_annotator.items():
            if k.startswith("_"):
                logger.info(
                    f"Skipping second-annotator '{k}' (not to be compared)"
                )
                continue
            if k not in dataset:
                logger.info(
                    f"Skipping second-annotator '{k}' "
                    f"(no equivalent `dataset[{k}]`)"
                )
                continue
            if not dataset[k].all_keys_match(v):
                logger.warning(
                    f"Skipping second-annotator '{k}' "
                    f"(keys do not match `dataset[{k}]`?)"
                )
                continue
            candidate = os.path.join(
                analysis_folder, "second-annotator", f"{k}.csv"
            )
            if not os.path.exists(candidate):
                logger.error(
                    f"Skipping second-annotator '{k}' "
                    f"(candidate CSV file `{candidate}` does not exist!)"
                )
                continue
            systems += [f"{k} (2nd. annot.)", candidate]

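    # the comparison produces a figure (PDF) and a table (reStructuredText)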
    output_figure = os.path.join(output_folder, "comparison.pdf")
    output_table = os.path.join(output_folder, "comparison.rst")

    ctx.invoke(
        base_compare,
        label_path=systems,
        output_figure=output_figure,
        output_table=output_table,
        threshold=threshold,
        plot_limits=plot_limits,
        detection=detection,
        verbose=verbose,
    )

    logger.info("Ended comparison")
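

# Hypothetical usage sketch (an assumption, not part of this module): a
# concrete CLI command would declare matching click options and forward
# them here.  Since ``base_analyze`` is decorated with
# ``click.pass_context``, the active context is injected automatically
# when it is called from within another command:
#
#     @click.command()
#     @click.option("--output-folder", required=True)
#     # ... one option per ``base_analyze`` parameter ...
#     def analyze(**kwargs):
#         base_analyze(**kwargs)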