
The MNIST dataset

The Modified National Institute of Standards and Technology (MNIST) dataset is a very well-known dataset of handwritten digits {0,1,2,3,4,5,6,7,8,9} used to train and test classification models.

A classification model is a model that predicts the probabilities of observing a class, given an input.
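For the ten digit classes, such a model typically outputs a vector of ten probabilities, one per class, and the prediction is the most probable class. A minimal sketch, with a hypothetical score vector and a softmax normalization:

```python
import numpy as np

def softmax(scores):
    # exponentiate shifted scores for numerical stability, then normalize
    e = np.exp(scores - scores.max())
    return e / e.sum()

# hypothetical raw scores a model might output for the ten digit classes
scores = np.array([1.0, 2.0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 3.0])
probs = softmax(scores)       # ten probabilities summing to one
print(probs.argmax())         # index of the most probable class
```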

Training is the task of learning the parameters to fit the model to the data as well as we can so that for any input image, the correct label is predicted. For this training task, the MNIST dataset contains 60,000 images with a target label (a number between 0 and 9) for each example.

To check that training is effective and to decide when to stop it, we usually split the training dataset in two: 80% to 90% of the images are used for training, while the remaining 10% to 20% are never presented to the algorithm during training and serve to validate that the model generalizes well to unobserved data.
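Such a split can be sketched as follows, shuffling the example indices before cutting them at the 90% mark (the split ratio here is an illustrative choice):

```python
import numpy as np

n_examples = 60000                     # size of the MNIST training set
rng = np.random.RandomState(0)
indices = rng.permutation(n_examples)  # shuffle before splitting
split = int(0.9 * n_examples)          # 90% train / 10% validation
train_idx, valid_idx = indices[:split], indices[split:]
print(len(train_idx), len(valid_idx))  # 54000 6000
```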

There is a separate dataset that the algorithm should never see during training, named the test set, which consists of 10,000 images in the MNIST dataset.

In the MNIST dataset, each example consists of a 28x28 normalized monochrome image as input and a label, represented as a simple integer between 0 and 9. Let's display some of them:

  1. First, download a pre-packaged version of the dataset that makes it easier to load from Python:
    wget http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz -P /sharedfiles
  2. Then load the data into a Python session:
    import pickle, gzip
    with gzip.open("/sharedfiles/mnist.pkl.gz", 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f)

    For Python 3, we need pickle.load(f, encoding='latin1') due to the way the file was serialized.

    train_set[0].shape
    (50000, 784)

    train_set[1].shape
    (50000,)

    import numpy
    import matplotlib
    import matplotlib.pyplot as plt

    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.cmap'] = 'gray'

    for i in range(9):
        plt.subplot(1, 9, i + 1)
        plt.imshow(train_set[0][i].reshape(28, 28))
        plt.axis('off')
        plt.title(str(train_set[1][i]))

    plt.show()

The first nine samples from the dataset are displayed with the corresponding label (the ground truth, that is, the correct answer expected by the classification algorithm) on top of them:

In order to avoid too many transfers to the GPU, and since the complete dataset is small enough to fit in GPU memory, we usually place the full training set in shared variables:

import theano
train_set_x = theano.shared(numpy.asarray(train_set[0], dtype=theano.config.floatX))
train_set_y = theano.shared(numpy.asarray(train_set[1], dtype='int32'))

Avoiding these data transfers lets us train faster on the GPU: even with recent GPUs and fast PCIe connections, per-batch host-to-device transfers would otherwise remain a bottleneck.

More information on the dataset is available at http://yann.lecun.com/exdb/mnist/.
