1#!/usr/bin/env python
2# coding=utf-8
3
4"""Tests for our CLI applications"""
5
6import os
7import re
8import contextlib
9from bob.extension import rc
10import pkg_resources
11
12from click.testing import CliRunner
13
14from . import mock_dataset
15
# Download the test data (if not yet available locally) and get its location
17montgomery_datadir = mock_dataset()
18
19_pasa_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_fpasa_checkpoint.pth"
20_signstotb_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_signstotb_checkpoint.pth"
21_logreg_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_logreg_checkpoint.pth"
22#_densenetrs_checkpoint_URL = "http://www.idiap.ch/software/bob/data/bob/bob.med.tb/master/_test_densenetrs_checkpoint.pth"
23
24
@contextlib.contextmanager
def rc_context(**new_config):
    """Temporarily overlay ``new_config`` onto bob's global ``rc`` configuration.

    The previous configuration is restored on exit, even if the body raises.
    """
    saved = rc.copy()
    rc.update(new_config)
    try:
        yield
    finally:
        # restore exactly the configuration that was active before entry
        rc.clear()
        rc.update(saved)
34
35
@contextlib.contextmanager
def stdout_logging():
    """Capture messages sent to the ``bob`` logger into a string buffer.

    Yields an ``io.StringIO`` object that accumulates every INFO-or-above
    message emitted through the ``bob`` logger while the context is active.

    The handler is always detached on exit — including when the body raises —
    so a failing test cannot leak the handler and make later captures count
    the same messages multiple times (the previous implementation skipped
    ``removeHandler`` on exceptions).
    """
    import logging
    import io

    buf = io.StringIO()
    handler = logging.StreamHandler(buf)
    handler.setFormatter(logging.Formatter("%(message)s"))
    handler.setLevel(logging.INFO)
    logger = logging.getLogger("bob")
    logger.addHandler(handler)
    try:
        yield buf
    finally:
        logger.removeHandler(handler)
52
53
54def _assert_exit_0(result):
55
56 assert (
57 result.exit_code == 0
58 ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
59
60
def _data_file(f):
    """Return the path of packaged test-data file ``f`` (under ``data/``)."""
    relpath = os.path.join("data", f)
    return pkg_resources.resource_filename(__name__, relpath)
63
64
def _check_help(entry_point):
    """Run ``entry_point --help`` and check it exits cleanly with a usage line."""
    result = CliRunner().invoke(entry_point, ["--help"])
    _assert_exit_0(result)
    assert result.output.startswith("Usage:")
71
72
def test_config_help():
    """The ``config`` command group responds to ``--help``."""
    from ..scripts.config import config as entry

    _check_help(entry)
77
78
def test_config_list_help():
    """The ``config list`` command responds to ``--help``."""
    # alias avoids shadowing the ``list`` builtin in this scope
    from ..scripts.config import list as config_list

    _check_help(config_list)
83
84
def test_config_list():
    """``config list`` enumerates dataset and model configuration modules."""
    from ..scripts.config import list as config_list

    result = CliRunner().invoke(config_list)
    _assert_exit_0(result)
    for fragment in (
        "module: bob.med.tb.configs.datasets",
        "module: bob.med.tb.configs.models",
    ):
        assert fragment in result.output
93
94
def test_config_list_v():
    """``config list --verbose`` still enumerates the configuration modules."""
    from ..scripts.config import list as config_list

    result = CliRunner().invoke(config_list, ["--verbose"])
    _assert_exit_0(result)
    for fragment in (
        "module: bob.med.tb.configs.datasets",
        "module: bob.med.tb.configs.models",
    ):
        assert fragment in result.output
102
103
def test_config_describe_help():
    """The ``config describe`` command responds to ``--help``."""
    from ..scripts.config import describe as entry

    _check_help(entry)
108
109
def test_config_describe_montgomery():
    """``config describe montgomery`` prints the dataset's description."""
    from ..scripts.config import describe

    result = CliRunner().invoke(describe, ["montgomery"])
    _assert_exit_0(result)
    assert "Montgomery dataset for TB detection" in result.output
117
118
def test_dataset_help():
    """The ``dataset`` command group responds to ``--help``."""
    from ..scripts.dataset import dataset as entry

    _check_help(entry)
123
124
def test_dataset_list_help():
    """The ``dataset list`` command responds to ``--help``."""
    # alias avoids shadowing the ``list`` builtin in this scope
    from ..scripts.dataset import list as dataset_list

    _check_help(dataset_list)
129
130
def test_dataset_list():
    """``dataset list`` prints the supported datasets header."""
    from ..scripts.dataset import list as dataset_list

    result = CliRunner().invoke(dataset_list)
    _assert_exit_0(result)
    assert result.output.startswith("Supported datasets:")
138
139
def test_dataset_check_help():
    """The ``dataset check`` command responds to ``--help``."""
    from ..scripts.dataset import check as entry

    _check_help(entry)
144
145
def test_dataset_check():
    """``dataset check --limit=2`` runs cleanly on the available datasets."""
    from ..scripts.dataset import check

    result = CliRunner().invoke(check, ["--verbose", "--limit=2"])
    _assert_exit_0(result)
152
153
def test_main_help():
    """The top-level ``tb`` command responds to ``--help``."""
    from ..scripts.tb import tb as entry

    _check_help(entry)
158
159
def test_train_help():
    """The ``train`` command responds to ``--help``."""
    from ..scripts.train import train as entry

    _check_help(entry)
164
165
166def _str_counter(substr, s):
167 return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))
168
169
def test_predict_help():
    """The ``predict`` command responds to ``--help``."""
    from ..scripts.predict import predict as entry

    _check_help(entry)
174
175
def test_predtojson_help():
    """The ``predtojson`` command responds to ``--help``."""
    from ..scripts.predtojson import predtojson as entry

    _check_help(entry)
180
181
def test_aggregpred_help():
    """The ``aggregpred`` command responds to ``--help``."""
    from ..scripts.aggregpred import aggregpred as entry

    _check_help(entry)
186
187
def test_evaluate_help():
    """The ``evaluate`` command responds to ``--help``."""
    from ..scripts.evaluate import evaluate as entry

    _check_help(entry)
192
193
def test_compare_help():
    """The ``compare`` command responds to ``--help``."""
    from ..scripts.compare import compare as entry

    _check_help(entry)
198
199
def test_train_pasa_montgomery():
    """Train the pasa model on Montgomery for one epoch; check artifacts/logs."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.train import train

        with stdout_logging() as buf:

            output_folder = "results"
            result = CliRunner().invoke(
                train,
                [
                    "pasa",
                    "montgomery",
                    "-vv",
                    "--epochs=1",
                    "--batch-size=1",
                    "--normalization=current",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # every artifact a training run is expected to produce
            for artifact in (
                "model_final.pth",
                "model_lowest_valid_loss.pth",
                "last_checkpoint",
                "constants.csv",
                "trainlog.csv",
                "model_summary.txt",
            ):
                assert os.path.exists(os.path.join(output_folder, artifact))

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"^Found \(dedicated\) '__train__' set for training$": 1,
                r"^Found \(dedicated\) '__valid__' set for validation$": 1,
                r"^Continuing from epoch 0$": 1,
                r"^Saving model summary at.*$": 1,
                r"^Model has.*$": 1,
                r"^Saving checkpoint": 2,
                r"^Total training time:": 1,
                r"^Z-normalization with mean": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
261
262
def test_predict_pasa_montgomery():
    """Run pasa predictions on Montgomery; check CSV outputs and log counts."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.predict import predict

        with stdout_logging() as buf:

            output_folder = "predictions"
            result = CliRunner().invoke(
                predict,
                [
                    "pasa",
                    "montgomery",
                    "-vv",
                    "--batch-size=1",
                    "--relevance-analysis",
                    f"--weight={_pasa_checkpoint_URL}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # a predictions CSV must exist for every subset
            for subset in ("train", "validation", "test"):
                assert os.path.exists(
                    os.path.join(output_folder, subset, "predictions.csv")
                )

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"^Loading checkpoint from.*$": 1,
                r"^Total time:.*$": 3,
                r"^Relevance analysis.*$": 3,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
318
319
def test_predtojson():
    """Convert prediction CSVs into a JSON dataset; check file and log counts."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.predtojson import predtojson

        with stdout_logging() as buf:

            predictions = _data_file("test_predictions.csv")
            output_folder = "pred_to_json"
            result = CliRunner().invoke(
                predtojson,
                [
                    "-vv",
                    "train",
                    f"{predictions}",
                    "test",
                    f"{predictions}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # the aggregated JSON dataset must have been written
            assert os.path.exists(os.path.join(output_folder, "dataset.json"))

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"Output folder: pred_to_json": 1,
                r"Saving JSON file...": 1,
                r"^Loading predictions from.*$": 2,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
364
365
def test_evaluate_pasa_montgomery():
    """Evaluate previously-generated predictions; check reports and log counts."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.evaluate import evaluate

        with stdout_logging() as buf:

            prediction_folder = "predictions"
            output_folder = "evaluations"
            result = CliRunner().invoke(
                evaluate,
                [
                    "-vv",
                    "montgomery",
                    f"--predictions-folder={prediction_folder}",
                    f"--output-folder={output_folder}",
                    "--threshold=train",
                    "--steps=2000",
                ],
            )
            _assert_exit_0(result)

            # evaluation tables and score plots for both subsets
            for report in (
                "test.csv",
                "train.csv",
                "test_score_table.pdf",
                "train_score_table.pdf",
            ):
                assert os.path.exists(os.path.join(output_folder, report))

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"^Skipping dataset '__train__'": 1,
                r"^Evaluating threshold on.*$": 1,
                r"^Maximum F1-score of.*$": 4,
                r"^Set --f1_threshold=.*$": 1,
                r"^Set --eer_threshold=.*$": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
419
420
def test_compare_pasa_montgomery():
    """Compare train/test predictions; check figure, table and log counts."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.compare import compare

        with stdout_logging() as buf:

            predictions_folder = "predictions"
            output_folder = "comparisons"
            result = CliRunner().invoke(
                compare,
                [
                    "-vv",
                    "train",
                    f"{predictions_folder}/train/predictions.csv",
                    "test",
                    f"{predictions_folder}/test/predictions.csv",
                    f"--output-figure={output_folder}/compare.pdf",
                    f"--output-table={output_folder}/table.txt",
                    "--threshold=0.5",
                ],
            )
            _assert_exit_0(result)

            # the comparison figure and summary table must exist
            for artifact in ("compare.pdf", "table.txt"):
                assert os.path.exists(os.path.join(output_folder, artifact))

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"^Dataset '\*': threshold =.*$": 1,
                r"^Loading predictions from.*$": 2,
                r"^Tabulating performance summary...": 1,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
468
469
def test_train_signstotb_montgomery_rs():
    """Train signs_to_tb on montgomery_rs for one epoch; check artifacts/logs."""

    from ..scripts.train import train

    with stdout_logging() as buf:

        output_folder = "results"
        result = CliRunner().invoke(
            train,
            [
                "signs_to_tb",
                "montgomery_rs",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--weight={_signstotb_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # every artifact a training run is expected to produce
        for artifact in (
            "model_final.pth",
            "model_lowest_valid_loss.pth",
            "last_checkpoint",
            "constants.csv",
            "trainlog.csv",
            "model_summary.txt",
        ):
            assert os.path.exists(os.path.join(output_folder, artifact))

        # expected log-message patterns and their occurrence counts
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 0$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for pattern, expected in keywords.items():
            actual = _str_counter(pattern, logging_output)
            assert actual == expected, (
                f"Count for string '{pattern}' appeared "
                f"({actual}) "
                f"instead of the expected {expected}:\nOutput:\n{logging_output}"
            )
520
521
def test_predict_signstotb_montgomery_rs():
    """Predict with signs_to_tb plus relevance analysis; check outputs/logs."""

    from ..scripts.predict import predict

    with stdout_logging() as buf:

        output_folder = "predictions"
        result = CliRunner().invoke(
            predict,
            [
                "signs_to_tb",
                "montgomery_rs",
                "-vv",
                "--batch-size=1",
                "--relevance-analysis",
                f"--weight={_signstotb_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # predictions CSV plus one relevance-analysis plot per subset
        assert os.path.exists(
            os.path.join(output_folder, "train/predictions.csv")
        )
        for plot in ("train_RA.pdf", "validation_RA.pdf", "test_RA.pdf"):
            assert os.path.exists(os.path.join(output_folder, plot))

        # expected log-message patterns and their occurrence counts
        # (one "Total time" per pass: 3 subsets x 15 relevance passes)
        keywords = {
            r"^Loading checkpoint from.*$": 1,
            r"^Total time:.*$": 3 * 15,
            r"^Starting relevance analysis for subset.*$": 3,
            r"^Creating and saving plot at.*$": 3,
        }
        buf.seek(0)
        logging_output = buf.read()

        for pattern, expected in keywords.items():
            actual = _str_counter(pattern, logging_output)
            assert actual == expected, (
                f"Count for string '{pattern}' appeared "
                f"({actual}) "
                f"instead of the expected {expected}:\nOutput:\n{logging_output}"
            )
570
571
def test_train_logreg_montgomery_rs():
    """Train logistic_regression on montgomery_rs; check artifacts and logs."""

    from ..scripts.train import train

    with stdout_logging() as buf:

        output_folder = "results"
        result = CliRunner().invoke(
            train,
            [
                "logistic_regression",
                "montgomery_rs",
                "-vv",
                "--epochs=1",
                "--batch-size=1",
                f"--weight={_logreg_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # every artifact a training run is expected to produce
        for artifact in (
            "model_final.pth",
            "model_lowest_valid_loss.pth",
            "last_checkpoint",
            "constants.csv",
            "trainlog.csv",
            "model_summary.txt",
        ):
            assert os.path.exists(os.path.join(output_folder, artifact))

        # expected log-message patterns and their occurrence counts
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 0$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for pattern, expected in keywords.items():
            actual = _str_counter(pattern, logging_output)
            assert actual == expected, (
                f"Count for string '{pattern}' appeared "
                f"({actual}) "
                f"instead of the expected {expected}:\nOutput:\n{logging_output}"
            )
622
623
def test_predict_logreg_montgomery_rs():
    """Predict with logistic_regression; check predictions, weights plot, logs."""

    from ..scripts.predict import predict

    with stdout_logging() as buf:

        output_folder = "predictions"
        result = CliRunner().invoke(
            predict,
            [
                "logistic_regression",
                "montgomery_rs",
                "-vv",
                "--batch-size=1",
                f"--weight={_logreg_checkpoint_URL}",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # predictions CSV and the logistic-regression weights plot
        assert os.path.exists(
            os.path.join(output_folder, "train/predictions.csv")
        )
        assert os.path.exists(
            os.path.join(output_folder, "LogReg_Weights.pdf")
        )

        # expected log-message patterns and their occurrence counts
        keywords = {
            r"^Loading checkpoint from.*$": 1,
            r"^Total time:.*$": 3,
            r"^Logistic regression identified: saving model weights.*$": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for pattern, expected in keywords.items():
            actual = _str_counter(pattern, logging_output)
            assert actual == expected, (
                f"Count for string '{pattern}' appeared "
                f"({actual}) "
                f"instead of the expected {expected}:\nOutput:\n{logging_output}"
            )
666
667
def test_aggregpred():
    """Aggregate two prediction CSVs into one file; check output and logs."""

    # Temporarily point the Montgomery datadir at the mock dataset
    with rc_context(**{"bob.med.tb.montgomery.datadir": montgomery_datadir}):

        from ..scripts.aggregpred import aggregpred

        with stdout_logging() as buf:

            predictions = "predictions/train/predictions.csv"
            output_folder = "aggregpred"
            result = CliRunner().invoke(
                aggregpred,
                [
                    "-vv",
                    f"{predictions}",
                    f"{predictions}",
                    f"--output-folder={output_folder}",
                ],
            )
            _assert_exit_0(result)

            # the aggregated CSV must have been written
            assert os.path.exists(os.path.join(output_folder, "aggregpred.csv"))

            # expected log-message patterns and their occurrence counts
            keywords = {
                r"Output folder: aggregpred": 1,
                r"Saving aggregated CSV file...": 1,
                r"^Loading predictions from.*$": 2,
            }
            buf.seek(0)
            logging_output = buf.read()

            for pattern, expected in keywords.items():
                actual = _str_counter(pattern, logging_output)
                assert actual == expected, (
                    f"Count for string '{pattern}' appeared "
                    f"({actual}) "
                    f"instead of the expected {expected}:\nOutput:\n{logging_output}"
                )
710
711
712# Not enough RAM available to do this test
713# def test_predict_densenetrs_montgomery():
714
715# # Temporarily modify Montgomery datadir
716# new_value = {"bob.med.tb.montgomery.datadir": montgomery_datadir}
717# with rc_context(**new_value):
718
719# from ..scripts.predict import predict
720
721# runner = CliRunner()
722
723# with stdout_logging() as buf:
724
725# output_folder = "predictions"
726# result = runner.invoke(
727# predict,
728# [
729# "densenet_rs",
730# "montgomery_f0_rgb",
731# "-vv",
732# "--batch-size=1",
733# f"--weight={_densenetrs_checkpoint_URL}",
734# f"--output-folder={output_folder}",
735# "--grad-cams"
736# ],
737# )
738# _assert_exit_0(result)
739
740# # check predictions are there
741# predictions_file1 = os.path.join(output_folder, "train/predictions.csv")
742# predictions_file2 = os.path.join(output_folder, "validation/predictions.csv")
743# predictions_file3 = os.path.join(output_folder, "test/predictions.csv")
744# assert os.path.exists(predictions_file1)
745# assert os.path.exists(predictions_file2)
746# assert os.path.exists(predictions_file3)
747# # check some grad cams are there
748# cam1 = os.path.join(output_folder, "train/cams/MCUCXR_0002_0_cam.png")
749# cam2 = os.path.join(output_folder, "train/cams/MCUCXR_0126_1_cam.png")
750# cam3 = os.path.join(output_folder, "train/cams/MCUCXR_0275_1_cam.png")
751# cam4 = os.path.join(output_folder, "validation/cams/MCUCXR_0399_1_cam.png")
752# cam5 = os.path.join(output_folder, "validation/cams/MCUCXR_0113_1_cam.png")
753# cam6 = os.path.join(output_folder, "validation/cams/MCUCXR_0013_0_cam.png")
754# cam7 = os.path.join(output_folder, "test/cams/MCUCXR_0027_0_cam.png")
755# cam8 = os.path.join(output_folder, "test/cams/MCUCXR_0094_0_cam.png")
756# cam9 = os.path.join(output_folder, "test/cams/MCUCXR_0375_1_cam.png")
757# assert os.path.exists(cam1)
758# assert os.path.exists(cam2)
759# assert os.path.exists(cam3)
760# assert os.path.exists(cam4)
761# assert os.path.exists(cam5)
762# assert os.path.exists(cam6)
763# assert os.path.exists(cam7)
764# assert os.path.exists(cam8)
765# assert os.path.exists(cam9)
766
767# keywords = {
768# r"^Loading checkpoint from.*$": 1,
769# r"^Total time:.*$": 3,
770# r"^Grad cams folder:.*$": 3,
771# }
772# buf.seek(0)
773# logging_output = buf.read()
774
775# for k, v in keywords.items():
776# assert _str_counter(k, logging_output) == v, (
777# f"Count for string '{k}' appeared "
778# f"({_str_counter(k, logging_output)}) "
779# f"instead of the expected {v}:\nOutput:\n{logging_output}"
780# )