# Source code for bob.db.mnist.query

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
import shutil
import struct
import gzip
import numpy

from bob.db.base.utils import check_parameters_for_validity


class Database:
  """Wrapper class for the MNIST database of handwritten digits.

  The original database files are distributed over:
  http://yann.lecun.com/exdb/mnist/.
  """

  # Magic numbers found at the start of the MNIST IDX files: 0x00000801 for
  # label files and 0x00000803 for image files (two zero bytes, data type
  # 0x08 = unsigned byte, then the number of dimensions)
  _LABELS_MAGIC = 2049
  _IMAGES_MAGIC = 2051


  def __init__(self):

    from .driver import Interface
    f = Interface().files()

    # paths to the gzipped files (they are opened with gzip.open below),
    # presumably in the order the driver lists them - confirm in .driver
    self.train_images = f[0]
    self.train_labels = f[1]
    self.test_images  = f[2]
    self.test_labels  = f[3]

    self._labels = set(range(0,10))
    self._groups = ('train', 'test')


  def _read_labels(self, fname):
    """Reads the labels from an original (gzipped) MNIST label binary file

    Parameters:

      fname (str): Path to the gzipped label file

    Returns:

      numpy.ndarray: A 1D ``uint8`` array with one label per example. The
        array is backed by the bytes read from the file and is read-only.

    Raises:

      ValueError: If the file is not an MNIST label file, or if the number of
        labels stored does not match the count announced in the header.
    """

    with gzip.open(fname, 'rb') as f:

      # reads 2 big-endian unsigned integers: magic number and example count
      magic_nr, n_examples = struct.unpack(">II", f.read(8))
      if magic_nr != self._LABELS_MAGIC:
        raise ValueError("`%s' is not an MNIST label file (magic number %d, "
            "expected %d)" % (fname, magic_nr, self._LABELS_MAGIC))

      # reads the rest, using an uint8 dataformat (endian-less)
      labels = numpy.frombuffer(f.read(), dtype=numpy.uint8)

    if labels.size != n_examples:
      raise ValueError("`%s' announces %d labels but contains %d" % \
          (fname, n_examples, labels.size))

    return labels


  def _read_images(self, fname):
    """Reads the images from an original (gzipped) MNIST image binary file

    Parameters:

      fname (str): Path to the gzipped image file

    Returns:

      numpy.ndarray: A 2D ``uint8`` array with one row per example and
        ``rows*cols`` columns (pixels unrolled row by row). The array is
        backed by the bytes read from the file and is read-only.

    Raises:

      ValueError: If the file is not an MNIST image file, or if the number of
        pixels stored does not match the sizes announced in the header.
    """

    with gzip.open(fname, 'rb') as f:

      # reads 4 big-endian unsigned integers
      magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
      if magic_nr != self._IMAGES_MAGIC:
        raise ValueError("`%s' is not an MNIST image file (magic number %d, "
            "expected %d)" % (fname, magic_nr, self._IMAGES_MAGIC))

      # reads the rest, using an uint8 dataformat (endian-less)
      pixels = numpy.frombuffer(f.read(), dtype=numpy.uint8)

    shape = (n_examples, rows*cols)
    if pixels.size != shape[0] * shape[1]:
      raise ValueError("`%s' announces %d images of %dx%d pixels but contains "
          "%d pixels" % (fname, n_examples, rows, cols, pixels.size))

    return pixels.reshape(shape)


  def _read_group(self, group):
    """Reads the images and labels of one group (``train`` or ``test``)

    Returns:

      tuple: ``(images, labels)`` as returned by :py:meth:`_read_images` and
        :py:meth:`_read_labels`

    Raises:

      ValueError: If ``group`` is unknown, or if the image and label files
        of the group hold a different number of examples.
    """

    if group == 'train':
      images = self._read_images(self.train_images)
      labels = self._read_labels(self.train_labels)
    elif group == 'test':
      images = self._read_images(self.test_images)
      labels = self._read_labels(self.test_labels)
    else:
      raise ValueError("unknown group %r" % (group,))

    if images.shape[0] != labels.shape[0]:
      raise ValueError("group `%s' has %d images but %d labels" % \
          (group, images.shape[0], labels.shape[0]))

    return images, labels


  def labels(self):
    """Returns the set of valid labels (the digits 0 to 9)"""
    return self._labels


  def groups(self):
    """Returns the tuple of valid groups (``train`` and ``test``)"""
    return self._groups


  def data(self, groups=None, labels=None):
    """Loads the MNIST samples and labels and returns them in NumPy arrays

    Parameters:

      groups (:py:class:`str` or :py:class:`list`): One of the groups
        ``train`` or ``test``, or a list with both of them (which is the
        default)

      labels (:py:class:`int` or :py:class:`list`): A subset of the labels
        (digits 0 to 9) (everything is the default)

    Returns:

      numpy.ndarray: A 2D ``uint8`` array representing the digit images, with
        as many rows as examples in the dataset and as many columns as pixels
        (28x28 = 784 columns). The pixels of each image are unrolled in C-scan
        order (i.e., first row 0, then row 1, etc.). When both groups are
        requested, the ``train`` examples come before the ``test`` ones.

      numpy.ndarray: A 1D ``uint8`` array with as many elements as examples
        in the dataset, containing the labels for each image returned above.
    """

    # check if groups and labels set are valid
    groups = check_parameters_for_validity(groups, "group", self._groups)
    vlabels = check_parameters_for_validity(labels, "label", self._labels)

    # Reads the requested groups, always in canonical order (train, then
    # test), independently of the order in which the caller listed them
    loaded = [self._read_group(g) for g in self._groups if g in groups]

    if len(loaded) == 1:
      # avoids an extra full copy when a single group is requested
      images, targets = loaded[0]
    elif loaded:
      images = numpy.vstack([img for img, _ in loaded])
      targets = numpy.hstack([lbl for _, lbl in loaded])
    else:
      images = numpy.empty((0, 784), dtype=numpy.uint8)
      targets = numpy.empty((0,), dtype=numpy.uint8)

    # Mask of the examples whose label is in the list of requested labels.
    # ``list()`` matters: numpy.isin treats a set as a single object element
    keep = numpy.isin(targets, list(vlabels))

    # boolean indexing copies, so the returned arrays are writable even
    # though the arrays read from the files are read-only
    return images[keep], targets[keep]