Coverage for src/deepdraw/data/cxr8/__init__.py: 67%
27 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""ChestX-ray8: Hospital-scale Chest X-ray Database

The database contains a total of 112120 images. Image size for each X-ray is
1024 x 1024. One set of mask annotations is available for all images.

* Reference: [CXR8-2017]_
* Original resolution (height x width): 1024 x 1024
* Configuration resolution: 256 x 256 (after rescaling)
* Split reference: [GAAL-2020]_
* Protocol ``default``:

  * Training samples: 78484 (including labels)
  * Validation samples: 11212 (including labels)
  * Test samples: 22424 (including labels)

* Protocol ``idiap``:

  * Exactly the same as ``default``, except it uses the file organisation
    suitable to the Idiap Research Institute (where there is a limit of 1k
    files per directory)

26"""
28import os
30import numpy as np
31import pkg_resources
33from PIL import Image
35from ...data.dataset import JSONDataset
36from ...utils.rc import load_rc
37from ..loader import load_pil_rgb, make_delayed
# JSON protocol definition(s) shipped as package resources next to this module.
_protocols = [pkg_resources.resource_filename(__name__, "default.json")]

# Base directory of the CXR8 files, looked up under the ``datadir.cxr8`` key of
# the user configuration; the resolved current directory is passed as the
# fallback value.
_rc = load_rc()
_root_path = _rc.get("datadir.cxr8", os.path.realpath(os.curdir))

# Probe for a known file directly below ``images/``.  When it is present the
# relative paths of the samples are used verbatim; otherwise the Idiap file
# organisation (extra sub-directory level) is assumed.
_idiap_organisation = not os.path.exists(
    os.path.join(_root_path, "images", "00000001_000.png")
)
def _idiap_path(relative_path):
    """Rewrite ``<dir>/<file>`` as ``<dir>/<first 5 chars of file>/<file>``."""
    parts = relative_path.split("/", 1)
    return os.path.join(parts[0], parts[1][:5], parts[1])


def _raw_data_loader(sample):
    """Load the image and the binary mask of a single sample from disk.

    Parameters
    ----------
    sample
        Mapping with the keys ``"data"`` and ``"label"``, each holding a path
        relative to ``_root_path``.

    Returns
    -------
    dict
        ``"data"`` holds whatever ``load_pil_rgb`` returns for the image file,
        ``"label"`` holds a PIL image built from a boolean array.
    """
    if _idiap_organisation:
        sample_path = _idiap_path(sample["data"])
        # The mask location must be derived from the *label* entry; deriving
        # it from ``sample["data"]`` would re-open the X-ray image as its own
        # mask.
        label_path = _idiap_path(sample["label"])
    else:
        sample_path = sample["data"]
        label_path = sample["label"]

    data = load_pil_rgb(os.path.join(_root_path, sample_path))

    # Read the mask inside a context manager so the underlying file handle is
    # released as soon as the pixels have been copied into the array.
    with Image.open(os.path.join(_root_path, label_path)) as label_image:
        label = np.array(label_image)

    # Pixels equal to 255 are reset to 0 before thresholding; every remaining
    # non-zero pixel becomes foreground in the boolean mask.
    # NOTE(review): 255 presumably marks a "don't care" region - confirm.
    label = np.where(label == 255, 0, label)

    return dict(data=data, label=Image.fromarray(label > 0))
def _loader(context, sample):
    """Wrap ``sample`` so that its files are only read when needed.

    ``context`` is ignored: the database is homogeneous.  A delayed sample is
    returned to avoid loading all images at once.
    """
    delayed = make_delayed(sample, _raw_data_loader)
    return delayed
# Each entry of the protocol files carries two fields, "data" and "label";
# entries are handed to ``_loader`` for loading.
dataset = JSONDataset(
    protocols=_protocols,
    fieldnames=("data", "label"),
    loader=_loader,
)
"""CXR8 dataset object"""