1#!/usr/bin/env python
2# coding=utf-8
3
import gzip
import logging
import os
import pickle
import sys

import numpy
9
10from ...binseg.engine.evaluator import run as run_evaluation
11from ...binseg.engine.significance import (
12 PERFORMANCE_FIGURES,
13 index_of_outliers,
14 sliding_window_performances,
15 visual_performances,
16 write_analysis_figures,
17 write_analysis_text,
18)
19
# Module-level logger; all nested helpers in this module log through it.
logger = logging.getLogger(__name__)
21
22
def base_significance(
    names,
    predictions,
    dataset,
    threshold,
    evaluate,
    steps,
    size,
    stride,
    figure,
    output_folder,
    remove_outliers,
    remove_zeros,
    parallel,
    checkpoint_folder,
    **kwargs,
):
    """Evaluates how significantly different are two models on the same dataset

    This application calculates the significance of results of two models
    operating on the same dataset, and subject to a priori threshold tunning.

    Parameters
    ----------

    names : tuple
        Names (two strings) of the systems being compared

    predictions : tuple
        Two root paths to the predictions of each system (matching ``names``)

    dataset : dict
        A dictionary mapping string keys to
        :py:class:`torch.utils.data.dataset.Dataset` instances

    threshold : :py:class:`float`, :py:class:`str` or ``None``
        Either a floating-point number within [0.0, 1.0] used directly, the
        name of a dataset in ``dataset`` to use for a priori threshold
        tunning, or ``None`` (implies 0.5)

    evaluate : str
        Name of the dataset key to use from ``dataset`` to evaluate
        (typically, ``test``)

    steps : int
        The number of threshold steps to consider when evaluating the highest
        possible F1-score on train/test data

    size : tuple
        Height and width of the sliding windows

    stride : tuple
        Height and width strides of the sliding windows

    figure : str
        The name of a performance figure (e.g. ``f1_score``, ``jaccard``, or
        ``accuracy``) to use when comparing performances

    output_folder : str
        Where to store visualizations and the analysis summary; if ``None``,
        nothing is written to disk

    remove_outliers : bool
        If set, iteratively removes outlier performance differences until
        none remain, before running the analysis

    remove_zeros : bool
        If set, removes windows on which both systems scored exactly zero

    parallel : int
        Sets the number of parallel processes to use when running using
        multiprocessing. A value of zero uses all reported cores. A value of
        ``1`` avoids completely the use of multiprocessing and runs all chores
        in the current processing context.

    checkpoint_folder : str
        If set to a string (instead of ``None``), caches intermediate sliding
        window performances on disk

    kwargs : dict
        Ignored.  NOTE(review): presumably extra options forwarded by the
        CLI layer — confirm against callers.
    """

    # Sentinel distinguishing "no checkpoint available" from a legitimately
    # cached value (which could, in principle, be any picklable object).
    _MISSING = object()

    def _validate_threshold(t, dataset):
        """Validate the user threshold selection. Returns parsed threshold."""
        if t is None:
            return 0.5

        try:
            # we try to convert it to float first
            t = float(t)
        except ValueError:
            # it is a bit of text - assert dataset with name is available
            if not isinstance(dataset, dict):
                raise ValueError(
                    "Threshold should be a floating-point number "
                    "if you provide only a single dataset for evaluation"
                )
            if t not in dataset:
                raise ValueError(
                    f"Text thresholds should match dataset names, "
                    f"but {t} is not available among the datasets provided "
                    f"({', '.join(dataset.keys())})"
                )
        else:
            # the range check lives OUTSIDE the try block so its ValueError
            # propagates to the user instead of being mistaken for "t is a
            # dataset name" by the except clause above
            if t < 0.0 or t > 1.0:
                raise ValueError(
                    "Float thresholds must be within range [0.0, 1.0]"
                )

        return t

    def _checkpoint_path(checkpointdir, basename):
        """Returns the checkpoint path for ``basename``, or ``None``.

        Creates the directory leading up to the returned path.  Returns
        ``None`` when ``checkpointdir`` is ``None`` (caching disabled).
        """
        if checkpointdir is None:
            return None
        retval = os.path.join(checkpointdir, basename)
        os.makedirs(os.path.dirname(retval), exist_ok=True)
        return retval

    def _load_checkpoint(chkpt_fname):
        """Loads a cached value from ``chkpt_fname``, if possible.

        Returns ``_MISSING`` when caching is disabled (``None`` filename),
        the checkpoint is absent, or it is truncated — in which case the
        caller must recalculate.
        """
        if chkpt_fname is None:
            return _MISSING
        if not os.path.exists(chkpt_fname):
            logger.debug(
                f"Checkpoint not available at {chkpt_fname}. "
                f"Calculating..."
            )
            return _MISSING
        logger.info(f"Loading checkpoint from {chkpt_fname}...")
        try:
            with gzip.GzipFile(chkpt_fname, "r") as f:
                return pickle.load(f)
        except EOFError as e:
            # a truncated (e.g. partially written) checkpoint: recalculate
            logger.warning(
                f"Could not load sliding window performance "
                f"from {chkpt_fname}: {e}. Calculating..."
            )
            return _MISSING

    def _save_checkpoint(chkpt_fname, value):
        """Caches ``value`` at ``chkpt_fname`` (no-op when it is ``None``)."""
        if chkpt_fname is None:
            return
        logger.debug(f"Storing checkpoint at {chkpt_fname}...")
        with gzip.GzipFile(chkpt_fname, "w") as f:
            pickle.dump(value, f)

    def _eval_sliding_windows(
        system_name,
        threshold,
        evaluate,
        preddir,
        dataset,
        steps,
        size,
        stride,
        outdir,
        figure,
        nproc,
        checkpointdir,
    ):
        """Calculates the sliding window performances on a dataset


        Parameters
        ==========

        system_name : str
            The name of the current system being analyzed

        threshold : :py:class:`float`, :py:class:`str`
            This number is used to define positives and negatives from probability
            maps, and report F1-scores (a priori). By default, we expect a set
            named 'validation' to be available at the input data. If that is not
            the case, we use 'train', if available. You may provide the name of
            another dataset to be used for threshold tunning otherwise. If not
            set, or a string is input, threshold tunning is done per system,
            individually. Optionally, you may also provide a floating-point number
            between [0.0, 1.0] as the threshold to use for both systems.

        evaluate : str
            Name of the dataset key to use from ``dataset`` to evaluate (typically,
            ``test``)

        preddir : str
            Root path to the predictions generated by system ``system_name``. The
            final subpath inside ``preddir`` that will be used will have the value
            of this variable suffixed with the value of ``evaluate``. We will
            search for ``<preddir>/<evaluate>/<stems>.hdf5``.

        dataset : dict
            A dictionary mapping string keys to
            :py:class:`torch.utils.data.dataset.Dataset` instances

        steps : int
            The number of threshold steps to consider when evaluating the highest
            possible F1-score on train/test data.

        size : tuple
            Two values indicating the size of windows to be used for the sliding
            window analysis. The values represent height and width respectively

        stride : tuple
            Two values indicating the stride of windows to be used for the sliding
            window analysis. The values represent height and width respectively

        outdir : str
            Path where to store visualizations. If set to ``None``, then do not
            store performance visualizations.

        figure : str
            The name of a performance figure (e.g. ``f1_score``, ``jaccard``, or
            ``accuracy``) to use when comparing performances

        nproc : int
            Sets the number of parallel processes to use when running using
            multiprocessing. A value of zero uses all reported cores. A value of
            ``1`` avoids completely the use of multiprocessing and runs all chores
            in the current processing context.

        checkpointdir : str
            If set to a string (instead of ``None``), then stores a cached version
            of the sliding window performances on disk, for a particular system.


        Returns
        =======

        d : dict
            A dictionary in which keys are filename stems and values are
            dictionaries with the following contents:

            ``winperf``: numpy.ndarray
                A dataframe with all the sliding window performances aggregated,
                for all input images.

            ``n`` : numpy.ndarray
                A 2D numpy array containing the number of performance scores for
                every pixel in the original image

            ``avg`` : numpy.ndarray
                A 2D numpy array containing the average performances for every
                pixel on the input image considering the sliding window sizes and
                strides applied to the image

            ``std`` : numpy.ndarray
                A 2D numpy array containing the (unbiased) standard deviations for
                the provided performance figure, for every pixel on the input image
                considering the sliding window sizes and strides applied to the
                image

        """

        # NB: the checkpoint key embeds the *incoming* threshold (a float or
        # a dataset name), not the a posteriori tuned value
        chkpt_fname = _checkpoint_path(
            checkpointdir,
            f"{system_name}-{evaluate}-{threshold}-"
            f"{size[0]}x{size[1]}+{stride[0]}x{stride[1]}-{figure}.pkl.gz",
        )
        cached = _load_checkpoint(chkpt_fname)
        if cached is not _MISSING:
            return cached

        if not isinstance(threshold, float):

            assert threshold in dataset, f"No dataset named '{threshold}'"

            logger.info(
                f"Evaluating threshold on '{threshold}' set for "
                f"'{system_name}' using {steps} steps"
            )
            threshold = run_evaluation(
                dataset[threshold], threshold, preddir, steps=steps
            )
            logger.info(f"Set --threshold={threshold:.5f} for '{system_name}'")

        # for a given threshold on each system, calculate sliding window performances
        logger.info(
            f"Evaluating sliding window '{figure}' on '{evaluate}' set for "
            f"'{system_name}' using windows of size {size} and stride {stride}"
        )

        retval = sliding_window_performances(
            dataset,
            evaluate,
            preddir,
            threshold,
            size,
            stride,
            figure,
            nproc,
            outdir,
        )

        # cache sliding window performance for later use, if necessary
        _save_checkpoint(chkpt_fname, retval)

        return retval

    def _eval_differences(
        names,
        perfs,
        evaluate,
        dataset,
        size,
        stride,
        outdir,
        figure,
        nproc,
        checkpointdir,
    ):
        """Evaluate differences in the performance sliding windows between two systems

        Parameters
        ----------

        names : :py:class:`tuple` of :py:class:`str`
            Names of the first and second systems

        perfs : :py:class:`tuple` of :py:class:`dict`
            Dictionaries for the sliding window performances of each system, as
            returned by :py:func:`_eval_sliding_windows`

        evaluate : str
            Name of the dataset key to use from ``dataset`` to evaluate (typically,
            ``test``)

        dataset : dict
            A dictionary mapping string keys to
            :py:class:`torch.utils.data.dataset.Dataset` instances

        size : tuple
            Two values indicating the size of windows to be used for sliding window
            analysis. The values represent height and width respectively

        stride : tuple
            Two values indicating the stride of windows to be used for sliding
            window analysis. The values represent height and width respectively

        outdir : str
            If set to ``None``, then do not output performance visualizations.
            Otherwise, in directory ``outdir``, dumps the visualizations for the
            performance differences between both systems.

        figure : str
            The name of a performance figure (e.g. ``f1_score``, or ``jaccard``) to
            use when comparing performances

        nproc : int
            Sets the number of parallel processes to use when running using
            multiprocessing. A value of zero uses all reported cores. A value of
            ``1`` avoids completely the use of multiprocessing and runs all chores
            in the current processing context.

        checkpointdir : str
            If set to a string (instead of ``None``), then stores a cached version
            of the sliding window performances on disk, for a particular difference
            between systems.


        Returns
        -------

        d : dict
            A dictionary representing sliding window performance differences across
            all files and sliding windows. The format of this is similar to the
            individual inputs ``perf1`` and ``perf2``.

        """

        chkpt_fname = _checkpoint_path(
            checkpointdir,
            f"{names[0]}-{names[1]}-{evaluate}-"
            f"{size[0]}x{size[1]}+{stride[0]}x{stride[1]}-{figure}.pkl.gz",
        )
        cached = _load_checkpoint(chkpt_fname)
        if cached is not _MISSING:
            return cached

        # per-stem difference of the raw sliding window performances
        perf_diff = {
            k: perfs[0][k]["winperf"] - perfs[1][k]["winperf"]
            for k in perfs[0]
        }

        # for a given threshold on each system, calculate sliding window performances
        logger.info(
            f"Evaluating sliding window '{figure}' differences on '{evaluate}' "
            f"set on '{names[0]}-{names[1]}' using windows of size {size} and "
            f"stride {stride}"
        )

        retval = visual_performances(
            dataset,
            evaluate,
            perf_diff,
            size,
            stride,
            figure,
            nproc,
            outdir,
        )

        # cache sliding window performance for later use, if necessary
        _save_checkpoint(chkpt_fname, retval)

        return retval

    # minimal validation to startup
    threshold = _validate_threshold(threshold, dataset)
    assert evaluate in dataset, f"No dataset named '{evaluate}'"

    perf1 = _eval_sliding_windows(
        names[0],
        threshold,
        evaluate,
        predictions[0],
        dataset,
        steps,
        size,
        stride,
        (
            output_folder
            if output_folder is None
            else os.path.join(output_folder, names[0])
        ),
        figure,
        parallel,
        checkpoint_folder,
    )

    perf2 = _eval_sliding_windows(
        names[1],
        threshold,
        evaluate,
        predictions[1],
        dataset,
        steps,
        size,
        stride,
        (
            output_folder
            if output_folder is None
            else os.path.join(output_folder, names[1])
        ),
        figure,
        parallel,
        checkpoint_folder,
    )

    # NOTE(review): the per-window difference visualization is intentionally
    # disabled; re-enable by uncommenting the call below.
    # perf_diff = _eval_differences(
    #     names,
    #     (perf1, perf2),
    #     evaluate,
    #     dataset,
    #     size,
    #     stride,
    #     (
    #         output_folder
    #         if output_folder is None
    #         else os.path.join(output_folder, "diff")
    #     ),
    #     figure,
    #     parallel,
    #     checkpoint_folder,
    # )

    # loads all figures for the given threshold
    stems = list(perf1.keys())
    figindex = PERFORMANCE_FIGURES.index(figure)
    da = numpy.array([perf1[k]["winperf"][figindex] for k in stems]).flatten()
    db = numpy.array([perf2[k]["winperf"][figindex] for k in stems]).flatten()
    diff = da - db

    # iteratively strip outliers until none remain (no-op if the flag is off)
    while remove_outliers:
        outliers_diff = index_of_outliers(diff)
        if sum(outliers_diff) == 0:
            break
        diff = diff[~outliers_diff]
        da = da[~outliers_diff]
        db = db[~outliers_diff]

    if remove_zeros:
        # drop windows where BOTH systems scored exactly zero
        both_zero = (da == 0) & (db == 0)
        diff = diff[~both_zero]
        da = da[~both_zero]
        db = db[~both_zero]

    if output_folder is not None:
        fname = os.path.join(output_folder, "analysis.pdf")
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        logger.info(f"Writing analysis figures to {fname} (multipage PDF)...")
        write_analysis_figures(names, da, db, fname)

    if output_folder is not None:
        fname = os.path.join(output_folder, "analysis.txt")
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        logger.info(f"Writing analysis summary to {fname}...")
        with open(fname, "wt") as f:
            write_analysis_text(names, da, db, f)
    write_analysis_text(names, da, db, sys.stdout)