1#!/usr/bin/env python
2# coding=utf-8
3
4"""Test code for datasets"""
5
6import os
7import pkg_resources
8
9from ..data.dataset import CSVDataset, JSONDataset
10from ..data.sample import Sample
11
12
def _data_file(f):
    """Return the path to the test resource named ``f``.

    The name is joined under a ``data`` directory and resolved through
    ``pkg_resources.resource_filename`` relative to this module.
    """
    relative_path = os.path.join("data", f)
    return pkg_resources.resource_filename(__name__, relative_path)
15
16
def _raw_data_loader(context, d):
    """Turn one raw Iris record into a ``Sample``.

    ``context`` must provide the ``"subset"`` and ``"order"`` entries; the
    sample key is ``context["subset"]`` followed by ``str(context["order"])``.

    ``d`` must provide the four measurement columns (each converted with
    ``float()``) and a ``"species"`` entry.

    The resulting ``Sample`` holds the four measurements followed by the
    species label in ``data``.
    """
    measurements = [
        float(d[column])
        for column in (
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
        )
    ]
    # The first five characters of the species label are dropped --
    # presumably an "Iris-" prefix; confirm against the data files.
    label = d["species"][5:]
    sample_key = context["subset"] + str(context["order"])
    return Sample(data=measurements + [label], key=sample_key)
28
29
def test_csv_loading():
    """Check that a simple CSV loader can be built for the Iris Flower dataset.

    Every sample of the "train" and "test" subsets must carry four float
    measurements, a known species name and a string key.
    """

    # NOTE(review): the "test" subset is also read from "iris-train.csv" --
    # confirm whether a separate test file was intended here.
    subsets = {
        "train": _data_file("iris-train.csv"),
        "test": _data_file("iris-train.csv"),
    }

    fieldnames = (
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "species",
    )

    dataset = CSVDataset(subsets, fieldnames, _raw_data_loader)
    dataset.check()

    data = dataset.subsets()

    for name in ("train", "test"):
        samples = data[name]
        assert len(samples) == 75
        for sample in samples:
            for column in range(4):
                assert isinstance(sample.data[column], float)
            assert isinstance(sample.data[4], str)
            # Both subsets are read from the same file through the same
            # loader, so the species check applies equally to each of them.
            assert sample.data[4] in ("setosa", "versicolor", "virginica")
            assert isinstance(sample.key, str)
65
66
def test_json_loading():
    """Check that a simple JSON loader can be built for the Iris Flower dataset.

    Every sample of the "train" and "test" subsets of the "default" protocol
    must carry four float measurements, a string label and a string key.
    """

    protocols = {"default": _data_file("iris.json")}

    fieldnames = (
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "species",
    )

    dataset = JSONDataset(protocols, fieldnames, _raw_data_loader)
    dataset.check()

    data = dataset.subsets("default")

    # Both subsets are held to the same expectations, so they share one loop.
    for name in ("train", "test"):
        samples = data[name]
        assert len(samples) == 75
        for sample in samples:
            for column in range(4):
                assert isinstance(sample.data[column], float)
            assert isinstance(sample.data[4], str)
            assert isinstance(sample.key, str)