Coverage for src/deepdraw/data/cxr8/__init__.py: 67%

27 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5"""ChestX-ray8: Hospital-scale Chest X-ray Database 

6 

7The database contains a total of 112120 images. Image size for each X-ray is 

81024 x 1024. One set of mask annotations is available for all images. 

9 

10* Reference: [CXR8-2017]_ 

11* Original resolution (height x width): 1024 x 1024 

12* Configuration resolution: 256 x 256 (after rescaling) 

13* Split reference: [GAAL-2020]_ 

14* Protocol ``default``: 

15 

16 * Training samples: 78484 (including labels) 

17 * Validation samples: 11212 (including labels) 

18 * Test samples: 22424 (including labels) 

19 

20* Protocol ``idiap``: 

21 

  * Exactly the same as ``default``, except it uses the file organisation
    suitable to the Idiap Research Institute (where there is a limit of 1k
    files per directory)

25 

26""" 

27 

28import os 

29 

30import numpy as np 

31import pkg_resources 

32 

33from PIL import Image 

34 

35from ...data.dataset import JSONDataset 

36from ...utils.rc import load_rc 

37from ..loader import load_pil_rgb, make_delayed 

38 

# Protocol definition files bundled with this package (only ``default``).
_protocols = [pkg_resources.resource_filename(__name__, "default.json")]

42 

# Root of the CXR8 data, taken from the ``datadir.cxr8`` RC setting; falls
# back to the (resolved) current working directory when unset.
_rc = load_rc()
_root_path = _rc.get("datadir.cxr8", os.path.realpath(os.curdir))

# Layout detection: when the first image is found directly under ``images/``
# the flat, original layout is in use; otherwise the Idiap layout (images
# sharded into sub-directories) is assumed.
_idiap_organisation = not os.path.exists(
    os.path.join(_root_path, "images", "00000001_000.png")
)

48 

49 

def _idiap_path(relpath):
    """Map a protocol-relative path onto the Idiap sharded directory layout.

    The first path component is kept and the file is placed in a
    sub-directory named after the first 5 characters of the remainder, e.g.
    ``images/00000001_000.png`` becomes ``images/00000/00000001_000.png``.

    Parameters
    ----------
    relpath : str
        Path relative to the data root, using ``/`` as separator and
        containing at least one ``/``.

    Returns
    -------
    str
        The re-organised relative path (joined with :py:func:`os.path.join`).
    """
    head, tail = relpath.split("/", 1)
    return os.path.join(head, tail[:5], tail)


def _raw_data_loader(sample):
    """Load one CXR8 sample (image and binary mask) from disk.

    Parameters
    ----------
    sample : dict
        Protocol entry whose ``data`` and ``label`` keys hold paths relative
        to the configured data directory (``datadir.cxr8``), using ``/`` as
        separator.

    Returns
    -------
    dict
        ``data``: the image as returned by ``load_pil_rgb``;
        ``label``: a PIL image built from a boolean array that is true where
        the stored label value is neither 0 nor 255.
    """
    data_path = sample["data"]
    label_path = sample["label"]
    if _idiap_organisation:
        # each path is re-organised from its own field, so the mask is never
        # confused with the image it annotates
        data_path = _idiap_path(data_path)
        label_path = _idiap_path(label_path)

    data = load_pil_rgb(os.path.join(_root_path, data_path))

    # context manager closes the label file handle as soon as it is read
    with Image.open(os.path.join(_root_path, label_path)) as label_image:
        label = np.array(label_image)

    # values of 255 are mapped to background; every other non-zero value
    # becomes foreground
    label = np.where(label == 255, 0, label)

    return dict(data=data, label=Image.fromarray(np.array(label > 0)))

71 

72 

def _loader(context, sample):
    """Return a delayed version of ``sample`` built with ``make_delayed``.

    Delaying avoids loading every image at once. ``context`` is not used
    because the database is homogeneous: all samples load the same way.
    """
    del context  # intentionally unused (homogeneous database)
    return make_delayed(sample, _raw_data_loader)

77 

78 

# Public dataset object: each protocol entry provides ``data`` and ``label``
# fields, loaded lazily through ``_loader``.
dataset = JSONDataset(
    protocols=_protocols,
    fieldnames=("data", "label"),
    loader=_loader,
)
"""CXR8 dataset object"""