1#!/usr/bin/env python
2# coding=utf-8
3
4from torch.utils.data.dataset import ConcatDataset
5
6def _maker(protocol):
7
8 if protocol == "default":
9 from ..montgomery_RS import default as mc
10 from ..shenzhen_RS import default as ch
11 elif protocol == "fold_0":
12 from ..montgomery_RS import fold_0 as mc
13 from ..shenzhen_RS import fold_0 as ch
14 elif protocol == "fold_1":
15 from ..montgomery_RS import fold_1 as mc
16 from ..shenzhen_RS import fold_1 as ch
17 elif protocol == "fold_2":
18 from ..montgomery_RS import fold_2 as mc
19 from ..shenzhen_RS import fold_2 as ch
20 elif protocol == "fold_3":
21 from ..montgomery_RS import fold_3 as mc
22 from ..shenzhen_RS import fold_3 as ch
23 elif protocol == "fold_4":
24 from ..montgomery_RS import fold_4 as mc
25 from ..shenzhen_RS import fold_4 as ch
26 elif protocol == "fold_5":
27 from ..montgomery_RS import fold_5 as mc
28 from ..shenzhen_RS import fold_5 as ch
29 elif protocol == "fold_6":
30 from ..montgomery_RS import fold_6 as mc
31 from ..shenzhen_RS import fold_6 as ch
32 elif protocol == "fold_7":
33 from ..montgomery_RS import fold_7 as mc
34 from ..shenzhen_RS import fold_7 as ch
35 elif protocol == "fold_8":
36 from ..montgomery_RS import fold_8 as mc
37 from ..shenzhen_RS import fold_8 as ch
38 elif protocol == "fold_9":
39 from ..montgomery_RS import fold_9 as mc
40 from ..shenzhen_RS import fold_9 as ch
41
42 mc = mc.dataset
43 ch = ch.dataset
44
45 dataset = {}
46 dataset['__train__'] = ConcatDataset([mc["__train__"], ch["__train__"]])
47 dataset['train'] = ConcatDataset([mc["train"], ch["train"]])
48 dataset['__valid__'] = ConcatDataset([mc["__valid__"], ch["__valid__"]])
49 dataset['validation'] = ConcatDataset([mc["validation"], ch["validation"]])
50 dataset['test'] = ConcatDataset([mc["test"], ch["test"]])
51
52 return dataset