1#!/usr/bin/env python
2# coding=utf-8
3
4"""Test code for datasets"""
5
6import os
7
8import pkg_resources
9
10from ..data.dataset import CSVDataset, JSONDataset
11from ..data.sample import Sample
12
13
def _data_file(f):
    """Return the path to file ``f`` inside the ``data`` resource directory.

    The location is resolved by ``pkg_resources`` relative to the module
    named by ``__name__``.
    """
    relative_path = os.path.join("data", f)
    return pkg_resources.resource_filename(__name__, relative_path)
16
17
def _raw_data_loader(context, d):
    """Build a ``Sample`` from one raw record ``d``.

    The four measurement fields are converted to ``float`` and followed by
    the species string with its first five characters removed (presumably an
    ``Iris-`` style prefix -- confirm against the data files).  The sample key
    is the ``subset`` entry of ``context`` concatenated with the string form
    of its ``order`` entry.
    """
    measurements = [
        float(d[name])
        for name in (
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
        )
    ]
    label = d["species"][5:]
    sample_key = context["subset"] + str(context["order"])
    return Sample(data=measurements + [label], key=sample_key)
29
30
def test_csv_loading():
    """Check that a simple CSV loader for the Iris Flower dataset works.

    Builds a ``CSVDataset`` with a "train" and a "test" subset, runs its
    self-check, and verifies the size of each subset as well as the types
    and species labels of every loaded sample.
    """

    # NOTE(review): both subsets point at "iris-train.csv".  This looks like a
    # copy-paste slip ("test" presumably should use a dedicated test file),
    # but it is kept as-is because the available data files cannot be
    # confirmed from here.
    subsets = {
        "train": _data_file("iris-train.csv"),
        "test": _data_file("iris-train.csv"),
    }

    fieldnames = (
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "species",
    )

    dataset = CSVDataset(subsets, fieldnames, _raw_data_loader)
    dataset.check()

    data = dataset.subsets()

    known_species = ("setosa", "versicolor", "virginica")

    # The same checks apply to every subset, so run them in a single loop
    # instead of duplicating the assertions per subset.
    for subset_name in ("train", "test"):
        samples = data[subset_name]
        assert len(samples) == 75
        for sample in samples:
            # the first four entries are the numeric measurements
            for index in range(4):
                assert isinstance(sample.data[index], float)
            # the last entry is the species label, stripped by the loader
            assert isinstance(sample.data[4], str)
            assert sample.data[4] in known_species
            assert isinstance(sample.key, str)
66
67
def test_json_loading():
    """Check that a simple JSON loader for the Iris Flower dataset works.

    Builds a ``JSONDataset`` with a single "default" protocol, runs its
    self-check, and verifies the size of the "train" and "test" subsets as
    well as the types of every loaded sample.
    """

    protocols = {"default": _data_file("iris.json")}

    fieldnames = (
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "species",
    )

    dataset = JSONDataset(protocols, fieldnames, _raw_data_loader)
    dataset.check()

    data = dataset.subsets("default")

    # The same checks apply to every subset, so run them in a single loop
    # instead of duplicating the assertions per subset.
    for subset_name in ("train", "test"):
        samples = data[subset_name]
        assert len(samples) == 75
        for sample in samples:
            # the first four entries are the numeric measurements
            for index in range(4):
                assert isinstance(sample.data[index], float)
            # the last entry is the species label
            assert isinstance(sample.data[4], str)
            assert isinstance(sample.key, str)