
The MNIST dataset

The Modified National Institute of Standards and Technology (MNIST) dataset is a very well-known dataset of handwritten digits {0,1,2,3,4,5,6,7,8,9} used to train and test classification models.

A classification model is a model that, given an input, predicts the probability of each class.
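As a minimal illustration of this idea (a NumPy sketch, not part of the book's code), a classifier typically produces one score per class and normalizes the scores into probabilities with a softmax:

```python
import numpy

def softmax(scores):
    # Subtract the maximum score for numerical stability, then normalize
    e = numpy.exp(scores - scores.max())
    return e / e.sum()

# Hypothetical scores for the 10 digit classes, as some model might output them
scores = numpy.array([0.1, 2.0, -1.0, 0.5, 0.0, 1.5, -0.5, 0.2, 0.8, -2.0])
probabilities = softmax(scores)

# The result is a valid probability distribution over the 10 classes
print(probabilities.sum())     # 1.0 (up to floating-point error)
print(probabilities.argmax())  # 1, the class with the highest score
```

The predicted class is simply the one with the highest probability.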

Training is the task of learning the parameters to fit the model to the data as well as we can so that for any input image, the correct label is predicted. For this training task, the MNIST dataset contains 60,000 images with a target label (a number between 0 and 9) for each example.

To validate that training is effective and to decide when to stop it, we usually split the training dataset into two sets: 80% to 90% of the images are used for training, while the remaining 10% to 20% are never presented to the algorithm during training and serve to check that the model generalizes well to unobserved data.
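Such a split can be sketched as follows (hypothetical NumPy code with a 90/10 ratio; the pre-packaged MNIST file loaded below already provides this split, so this is for illustration only):

```python
import numpy

numpy.random.seed(0)
images = numpy.random.rand(60000, 784)            # stand-in for the 60,000 MNIST images
labels = numpy.random.randint(0, 10, 60000)       # stand-in for their labels

# Shuffle the indices, then keep 90% for training and hold out 10% for validation
order = numpy.random.permutation(len(images))
split = int(0.9 * len(images))
train_idx, valid_idx = order[:split], order[split:]

train_images, train_labels = images[train_idx], labels[train_idx]
valid_images, valid_labels = images[valid_idx], labels[valid_idx]

print(train_images.shape, valid_images.shape)  # (54000, 784) (6000, 784)
```

Shuffling before splitting matters: if the examples are ordered by class, a contiguous split would give training and validation sets with different class distributions.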

There is a separate dataset that the algorithm should never see during training, named the test set, which consists of 10,000 images in the MNIST dataset.

In the MNIST dataset, each example consists of a 28x28 normalized monochrome image as input and a label, represented as a simple integer between 0 and 9. Let's display some of them:

  1. First, download a pre-packaged version of the dataset that makes it easier to load from Python:
    wget http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz -P /sharedfiles
  2. Then load the data into a Python session:
    import pickle, gzip
    with gzip.open("/sharedfiles/mnist.pkl.gz", 'rb') as f:
       train_set, valid_set, test_set = pickle.load(f)

    For Python 3, we need pickle.load(f, encoding='latin1'), because the file was serialized with Python 2.

    train_set[0].shape   # (50000, 784): 50,000 images of 784 = 28x28 pixels
    train_set[1].shape   # (50000,): one integer label per image

    import numpy
    import matplotlib.pyplot as plt

    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.cmap'] = 'gray'

    for i in range(9):
        plt.subplot(1, 9, i + 1)
        plt.imshow(train_set[0][i].reshape(28, 28))
        plt.axis('off')
        plt.title(str(train_set[1][i]))

    plt.show()

The first nine samples from the dataset are displayed with their corresponding labels (the ground truth, that is, the correct answer expected from the classification algorithm) on top of them.

In order to avoid too many transfers to the GPU, and since the complete dataset is small enough to fit in the memory of the GPU, we usually place the full training set in shared variables:

import theano
train_set_x = theano.shared(numpy.asarray(train_set[0], dtype=theano.config.floatX))
train_set_y = theano.shared(numpy.asarray(train_set[1], dtype='int32'))

Avoiding these data transfers allows us to train faster on the GPU: even with recent GPUs and fast PCIe connections, host-to-device transfers remain costly compared to on-device computation.
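During training, minibatches are then selected by index into the GPU-resident data rather than copied from host memory at each step. A sketch of the indexing pattern, using NumPy arrays to stand in for the shared variables (with Theano, the same slicing expression would typically appear in the givens clause of a compiled function, so the slice is taken directly on the device):

```python
import numpy

batch_size = 128
train_x = numpy.zeros((50000, 784), dtype='float32')  # stand-in for train_set_x
train_y = numpy.zeros(50000, dtype='int32')           # stand-in for train_set_y

def get_batch(index):
    # Select minibatch number `index` by slicing; no per-batch host-to-device
    # copy is needed when the full arrays already live in GPU memory
    start, stop = index * batch_size, (index + 1) * batch_size
    return train_x[start:stop], train_y[start:stop]

x_batch, y_batch = get_batch(3)
print(x_batch.shape, y_batch.shape)  # (128, 784) (128,)
```

One epoch then iterates the batch index from 0 to len(train_x) // batch_size - 1.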

More information on the dataset is available at http://yann.lecun.com/exdb/mnist/.
