#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from PIL import Image
import json
from pathlib import Path
import pkg_resources
import itertools
from .models import Sample, FundusImage, GroundTruth
import logging
class FileList:
    """
    FileList object that loads the protocol as defined in a json file.
    Provides a ``__getitem__`` interface.

    The decoded JSON document is used as a mapping from split name to the
    list stored under that name.

    Parameters
    ----------
    db_json : str or path-like
        path to the json file that defines the protocol
    """

    def __init__(self, db_json):
        # JSON text is UTF-8 by specification, so do not depend on the
        # platform's default encoding
        with open(db_json, 'r', encoding='utf-8') as in_file:
            self._filelist = json.load(in_file)

    def __getitem__(self, key):
        """
        Return the entries stored under a split.

        Parameters
        ----------
        key : str or None
            name of the split; any key that is not present in the protocol
            (e.g. ``None``) selects the 'train' entries followed by the
            'test' entries

        Returns
        -------
        entries : list
            the list stored under ``key``, or 'train' + 'test' as fallback

        Raises
        ------
        KeyError
            if the fallback is needed but 'train' or 'test' is missing
        """
        if key in self._filelist:
            return self._filelist[key]
        # if no valid split is passed, return all paths of train and test;
        # the concatenation is only built when it is actually needed, so a
        # valid split can be served from a protocol lacking 'train'/'test'
        return self._filelist['train'] + self._filelist['test']
class Database:
    """
    A low level database interface to be used with PyTorch or other deep learning frameworks.

    Attributes
    ----------
    protocol : str
        protocol defining the train-test split.
    threshold : int or float
        value handed to ``GroundTruth`` as ``threshold``; selected from the
        second underscore-separated field of ``protocol``.
    """

    def __init__(self, protocol='default_od'):
        self.protocol = protocol
        # the protocol file 'refuge_db_<protocol>.json' is looked up in the
        # directory pkg_resources reports for this module's package
        root = Path(pkg_resources.resource_filename(__name__, ''))
        db_json = root.joinpath('refuge_db_' + self.protocol + '.json')
        # initialize filelist
        self._filelist = FileList(db_json)
        # set threshold for gt
        self.threshold = self._threshold_for(self.protocol)

    @staticmethod
    def _threshold_for(protocol):
        """
        Select the ground-truth threshold for a protocol name.

        Parameters
        ----------
        protocol : str
            protocol name; its second underscore-separated field decides
            the threshold

        Returns
        -------
        threshold : int or float
            1 for 'od', 0.5 for 'cup', and 1 (after logging a warning) for
            anything else, including names without an underscore
        """
        fields = protocol.split('_')
        # a name without an underscore has no second field; treat it as an
        # unknown type instead of failing with IndexError
        kind = fields[1] if len(fields) > 1 else ''
        if kind == 'od':
            return 1
        if kind == 'cup':
            return 0.5
        logging.warning(f'Unknown protocol type "{kind}", setting threshold to 1')
        return 1

    @property
    def paths(self):
        """
        Returns
        -------
        paths : list
            list of all paths of all samples
        """
        # every filelist entry is itself a sequence of paths; flatten them
        # into one list, keeping the order of the entries
        return list(itertools.chain.from_iterable(self._filelist[None]))

    def _make_sample(self, img_path, gt_path):
        """
        Make a single sample object

        Parameters
        ----------
        img_path : str
            relative path to image
        gt_path : str
            relative path to ground truth

        Returns
        -------
        sample : Sample
        """
        img = FundusImage(img_path)
        # threshold to get binary od or cup mask
        gt = GroundTruth(gt_path, threshold=self.threshold)
        return Sample(img, gt)

    def samples(self, split=None):
        """
        Given a split, returns a list of Sample objects.

        Parameters
        ----------
        split : str
            'train', 'test' or None (returns all samples)

        Returns
        -------
        samples : list
            list of Sample objects
        """
        # each entry holds the image path first and the ground-truth path
        # second
        return [
            self._make_sample(entry[0], entry[1])
            for entry in self._filelist[split]
        ]