# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Example CSV-based custom filelist dataset.

In case you have your own dataset that is organized on your filesystem (or
elsewhere), this configuration shows an example setup so you can feed such data
(potentially including any ground-truth you may have) to train, predict or
evaluate one of the available network models.

You must write a CSV-based file (e.g., using commas as separators) that
describes the data (and ground-truth) locations for each sample in your
dataset. So, for example, if you have a file structure like this:

.. code-block:: text

   ├── images
       ├── image_1.png
       ├── ...
       └── image_n.png
   └── ground-truth
       ├── gt_1.png
       ├── ...
       └── gt_n.png

Then create one or more files, each containing a subset of your dataset:

.. code-block:: text

   images/image_1.png,ground-truth/gt_1.png
   ...,...
   images/image_n.png,ground-truth/gt_n.png

To create a subset without ground-truth (e.g., for prediction purposes), omit
the second column of the CSV file.
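
In that case, each line contains a single path, for example:

.. code-block:: text

   images/image_1.png
   ...
   images/image_n.png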

Use the path leading to the CSV file and carefully read the comments in this
configuration. **Copy it locally to make changes**:

.. code-block:: sh

   $ deepdraw config copy csv-dataset-example mydataset.py
   # edit mydataset.py as explained here, follow the comments

Finally, the only object this file needs to provide is one named ``dataset``,
and it should contain a dictionary mapping a name, such as ``train``, ``dev``,
or ``test``, to objects of type :py:class:`torch.utils.data.Dataset`. As you
will see in this example, we provide boilerplate code to do so.
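
For illustration, the mapping is expected to look roughly like this
(``train_dataset`` and ``test_dataset`` are hypothetical placeholders for
:py:class:`torch.utils.data.Dataset` objects; the boilerplate below builds an
equivalent mapping for you):

.. code-block:: py

   dataset = {
       "__train__": train_dataset,  # training copy, with data augmentation
       "train": train_dataset,  # evaluation copy, no data augmentation
       "test": test_dataset,  # evaluation
   }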

More information:

* :py:class:`deepdraw.data.dataset.CSVDataset` for operational details.
* :py:class:`deepdraw.data.dataset.JSONDataset` for an alternative for
  multi-protocol datasets (all of our supported raw datasets are implemented
  using this).
* :py:func:`deepdraw.configs.datasets.__init__.make_dataset` for extra
  information on the sample-list-to-pytorch connector.
"""

import os

from deepdraw.data.dataset import CSVDataset
from deepdraw.data.loader import load_pil_1, load_pil_rgb
from deepdraw.data.sample import Sample

# How we use the loaders - "sample" is a dictionary whose keys are defined
# below and map to the columns of the CSV files you input. This one is
# configured to load images and labels using PIL.


def _loader(context, sample):
    # "context" is ignored in this case - the database is homogeneous. It is
    # a dictionary carrying contextual information, e.g. the name of the
    # subset being loaded, so you can make loading decisions based on it.

    # Using a root path for the various data files stored on disk allows the
    # CSV file to contain only relative paths and is, therefore, more compact.
    # Of course, you can also make those paths absolute and then simplify
    # this function.
    root_path = "/path/where/raw/files/sit"

    data = load_pil_rgb(os.path.join(root_path, sample["data"]))
    label = load_pil_1(os.path.join(root_path, sample["label"]))

    # You may also return a DelayedSample to avoid data loading from taking
    # place when the sample object itself is created. Take a look at our own
    # datasets for examples.
    return Sample(
        key=os.path.splitext(sample["data"])[0],
        data=dict(data=data, label=label),
    )
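
# The comment above mentions DelayedSample. Below is a hypothetical sketch of
# a delayed variant of _loader(). It assumes deepdraw.data.loader also
# provides a "make_delayed()" helper (as used by the packaged dataset
# configurations) - verify this against your installed version before use:
#
# from deepdraw.data.loader import make_delayed
#
# def _raw_data_loader(sample):
#     root_path = "/path/where/raw/files/sit"
#     return dict(
#         data=load_pil_rgb(os.path.join(root_path, sample["data"])),
#         label=load_pil_1(os.path.join(root_path, sample["label"])),
#     )
#
# def _delayed_loader(context, sample):
#     # data is only read from disk when the returned sample is accessed
#     return make_delayed(sample, _raw_data_loader)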


# This object puts everything together: the CSV file, how to load each sample
# defined in the dataset, and names for the various columns of the CSV file.
# Once created, this object can be called to generate sample lists.

_raw_dataset = CSVDataset(
    # path to the CSV file(s) - you may add as many subsets as you want:
    # * "__train__" is used for training a model (stock data augmentation is
    #   applied via our "make_dataset()" connector)
    # * anything else can be used for prediction and/or evaluation (if labels
    #   are also provided in such a set). Data augmentation is NOT applied
    #   by our "make_dataset()" connector in this case.
    subsets={
        "__train__": "<path/to/train.csv>",  # applies data augmentation
        "train": "<path/to/train.csv>",  # no data augmentation, evaluate it
        "test": "<path/to/test.csv>",  # no data augmentation, evaluate it
    },
    fieldnames=("data", "label"),  # these are the column names
    loader=_loader,
)
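
# Sketch of how the sample lists are generated (based on the call at the
# bottom of this file): calling subsets() returns a dictionary mapping each
# subset name above to its list of samples, e.g.:
#
# subsets = _raw_dataset.subsets()
# print(len(subsets["train"]))  # number of samples in the "train" subset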

# Finally, we build a connector that passes our dataset to the pytorch
# framework so we can, for example, train and evaluate a pytorch model. The
# connector only converts the sample lists into a standard tuple
# (data[, label[, mask]]) that is expected by our engines, after applying the
# (optional) transformations you define.

# from deepdraw.configs.datasets import make_dataset as _maker

# Add/tune your (optional) transforms below - these are just examples
# compatible with a model that requires image inputs of 544 x 544 pixels.
# from deepdraw.data.transforms import CenterCrop

# dataset = _maker(_raw_dataset.subsets(), [CenterCrop((544, 544))])