Coverage for src/mednet/config/data/nih_cxr14/datamodule.py: 79%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-04-30 11:44 +0200

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4"""NIH CXR14 (relabeled) DataModule for computer-aided diagnosis. 

5 

6Database reference: [NIH-CXR14-2017]_ 

7""" 

8 

9import importlib.resources 

10import os 

11import pathlib 

12 

13import PIL.Image 

14from torchvision.transforms.functional import to_tensor 

15 

16from ....data.datamodule import CachingDataModule 

17from ....data.split import JSONDatabaseSplit 

18from ....data.typing import DatabaseSplit, Sample 

19from ....data.typing import RawDataLoader as _BaseRawDataLoader 

20from ....utils.rc import load_rc 

21 

CONFIGURATION_KEY_DATADIR = f"datadir.{__name__.rsplit('.', 2)[-2]}"
"""Configuration-file key holding the root directory of this database.

The key is ``datadir.`` followed by the name of the package containing this
module.
"""

CONFIGURATION_KEY_IDIAP_FILESTRUCTURE = (
    f"{__name__.rsplit('.', 2)[-2]}.idiap_folder_structure"
)
"""Configuration-file key selecting between the standard and the Idiap-based
file organisation structure.

When enabled, the internal loader searches for files in a slightly different
folder structure, that was adapted to Idiap's requirements (number of files
per folder to be less than 10k).
"""

37 

38 

class RawDataLoader(_BaseRawDataLoader):
    """A specialized raw-data-loader for the NIH CXR-14 dataset.

    The dataset root directory and the file-organisation flag are read once,
    at construction time, from the configuration returned by ``load_rc()``.
    """

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    idiap_file_organisation: bool
    """If should use the Idiap's filesystem organisation when looking up data.

    This variable holds the value of the configuration parameter named by
    ``CONFIGURATION_KEY_IDIAP_FILESTRUCTURE`` (i.e.
    ``<database-name>.idiap_folder_structure``) in the global configuration
    file, or ``False`` if it is not set.  When true, it will cause the
    internal loader to search for files in a slightly different folder
    structure, that was adapted to Idiap's requirements (number of files per
    folder to be less than 10k).
    """

    def __init__(self):
        rc = load_rc()
        # Fall back to the current working directory if no data directory was
        # configured for this database.
        self.datadir = pathlib.Path(
            rc.get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)),
        )
        self.idiap_file_organisation = rc.get(
            CONFIGURATION_KEY_IDIAP_FILESTRUCTURE,
            False,
        )

    def sample(self, sample: tuple[str, list[int]]) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.

        Returns
        -------
            The sample representation: the image, converted to grayscale and
            then to a tensor, and a dictionary with the entries ``label``
            (the labels, as given) and ``name`` (the path suffix, as given).
        """

        file_path = pathlib.Path(sample[0])  # default
        if self.idiap_file_organisation:
            # for folder lookup efficiency, data is split into subfolders
            # each original file is on the subfolder `f[:5]/f`, where f
            # is the original file basename
            file_path = file_path.parent / file_path.name[:5] / file_path.name

        # N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha
        # channel.  Most are grayscale PNGs, so everything is converted to
        # mode "L".  The context manager guarantees the underlying file is
        # closed as soon as the pixel data has been converted.
        with PIL.Image.open(self.datadir / file_path) as image:
            tensor = to_tensor(image.convert("L"))

        # to inspect a generated image while debugging:
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(tensor).show()

        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]

    def label(self, sample: tuple[str, list[int]]) -> list[int]:
        """Return the labels of a single sample, without loading its image.

        Parameters
        ----------
        sample
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.

        Returns
        -------
        list[int]
            The integer labels associated with the sample.
        """

        return sample[1]

118 

119 

def make_split(basename: str) -> DatabaseSplit:
    """Build the database split stored in a .json file next to this module.

    Parameters
    ----------
    basename
        Name of the .json file containing the split to load.

    Returns
    -------
        An instance of DatabaseSplit.
    """

    # The split files are looked up as resources of the package that
    # contains this module (module name without its last component).
    package_name = __name__.rsplit(".", 1)[0]
    split_resource = importlib.resources.files(package_name).joinpath(basename)
    return JSONDatabaseSplit(split_resource)

138 

139 

class DataModule(CachingDataModule):
    """NIH CXR14 (relabeled) DataModule for computer-aided diagnosis.

    This dataset was extracted from the clinical PACS database at the National
    Institutes of Health Clinical Center (USA) and represents 60% of all their
    radiographs. It contains labels for 14 common radiological signs in this
    order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
    nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
    edema and consolidation. This is the relabeled version created in the
    CheXNeXt study.

    * Reference: [NIH-CXR14-2017]_
    * Raw data input (on disk):

      * PNG 8-bit depth images (mostly grayscale; some are encoded as color
        PNGs with an alpha channel)
      * Resolution: 1024 x 1024 pixels

    * Labels: [CHEXNEXT-2018]_
    * Split reference: [CHEXNEXT-2018]_
    * Output image:

      * Transforms:

        * Load raw PNG with :py:mod:`PIL`
        * Convert to grayscale (PIL mode ``L``)
        * Convert to torch tensor

      * Final specifications:

        * Grayscale, encoded as a 1-plane tensor, 32-bit floats, square
          (1024x1024 px)
        * Labels in order:

          * cardiomegaly
          * emphysema
          * effusion
          * hernia
          * infiltration
          * mass
          * nodule
          * atelectasis
          * pneumothorax
          * pleural thickening
          * pneumonia
          * fibrosis
          * edema
          * consolidation

    Parameters
    ----------
    split_filename
        Name of the .json file containing the split to load.
    """

    def __init__(self, split_filename: str):
        super().__init__(
            database_split=make_split(split_filename),
            raw_data_loader=RawDataLoader(),
            # the database is named after the package holding this module
            database_name=__package__.split(".")[-1],
            # the split is named after the .json file, without its extension
            split_name=pathlib.Path(split_filename).stem,
        )