
  • Deep Learning with Theano
  • Christopher Bourez

The MNIST dataset

The Modified National Institute of Standards and Technology (MNIST) dataset is a very well-known dataset of handwritten digits {0,1,2,3,4,5,6,7,8,9} used to train and test classification models.

A classification model is a model that predicts, for a given input, the probability of each class.
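A minimal sketch of what such a model computes, assuming a linear layer followed by a softmax (the weights below are random placeholders, not a trained model):

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability before exponentiating
    e = np.exp(z - z.max())
    return e / e.sum()

# Hypothetical toy classifier: a linear layer followed by softmax maps a
# flattened 28x28 image (784 values) to 10 class probabilities.
rng = np.random.default_rng(0)
W = rng.normal(scale=0.01, size=(784, 10))  # placeholder weights
b = np.zeros(10)                            # placeholder biases

x = rng.random(784)          # stand-in for one normalized input image
probs = softmax(x @ W + b)   # one probability per digit class 0-9

print(probs.shape)  # (10,)
```

Training would adjust `W` and `b` so that the probability of the correct class is as high as possible.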

Training is the task of learning the parameters that fit the model to the data as well as possible, so that for any input image, the correct label is predicted. For this training task, the MNIST dataset contains 60,000 images, each with a target label (a number between 0 and 9).
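To make "learning the parameters" concrete, here is a hedged sketch of gradient descent on a toy one-parameter model y = w * x with a squared-error loss (the data point, initial value, and learning rate are illustrative assumptions, not values from the book):

```python
# One (input, target) pair; the value of w that fits it exactly is 3
x, y_true = 2.0, 6.0
w = 0.0    # initial parameter
lr = 0.1   # assumed learning rate

for step in range(100):
    y_pred = w * x
    grad = 2 * (y_pred - y_true) * x  # d/dw of (w*x - y_true)**2
    w -= lr * grad                    # move w against the gradient

print(round(w, 3))  # converges close to 3.0
```

Real training works the same way, but with many parameters, many examples, and a classification loss instead of squared error.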

To check that training is effective and to decide when to stop it, we usually split the training dataset in two: 80% to 90% of the images are used for training, while the remaining 10% to 20% are never shown to the algorithm during training and serve instead to validate that the model generalizes well to unobserved data.
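Such a split can be sketched as follows, shuffling the example indices before cutting them in two (the 85/15 ratio below is an assumption within the 80-90% range given above):

```python
import numpy as np

rng = np.random.default_rng(0)
n_examples = 60000
indices = rng.permutation(n_examples)  # shuffle before splitting

n_train = int(0.85 * n_examples)
train_idx, valid_idx = indices[:n_train], indices[n_train:]

# train_idx and valid_idx would then index into the image and label
# arrays: train_images = images[train_idx], and so on.
print(len(train_idx), len(valid_idx))  # 51000 9000
```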

There is a separate dataset that the algorithm should never see during training, named the test set, which consists of 10,000 images in the MNIST dataset.

In the MNIST dataset, each example consists of a 28x28 normalized monochrome image as the input data, and a label, represented as a simple integer between 0 and 9. Let's display some of them:

  1. First, download a pre-packaged version of the dataset that makes it easier to load from Python:
    wget http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz -P /sharedfiles
  2. Then load the data into a Python session:
    import pickle, gzip
    with gzip.open("/sharedfiles/mnist.pkl.gz", 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f)

    Under Python 3, use pickle.load(f, encoding='latin1'), because the file was serialized with Python 2.

    train_set[0].shape
    (50000, 784)

    train_set[1].shape
    (50000,)

    import numpy
    import matplotlib.pyplot as plt

    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.cmap'] = 'gray'

    # Plot the first nine images in a row, with the label as title
    for i in range(9):
        plt.subplot(1, 10, i + 1)
        plt.imshow(train_set[0][i].reshape(28, 28))
        plt.axis('off')
        plt.title(str(train_set[1][i]))

    plt.show()

The first nine samples from the dataset are displayed with the corresponding label (the ground truth, that is, the correct answer expected from the classification algorithm) on top of them.

In order to avoid transferring data to the GPU too often, and since the complete dataset is small enough to fit in GPU memory, we usually place the full training set in shared variables:

import theano
train_set_x = theano.shared(numpy.asarray(train_set[0], dtype=theano.config.floatX))
train_set_y = theano.shared(numpy.asarray(train_set[1], dtype='int32'))

Avoiding these repeated data transfers allows us to train faster on the GPU: even with recent GPUs and fast PCIe connections, transfers remain costly.
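With the data resident on the GPU, each training step typically selects a minibatch by slicing the shared variables (in Theano, usually via the `givens` argument of `theano.function`), so that only an integer index crosses the CPU-GPU boundary per step. The indexing itself can be sketched with NumPy arrays standing in for the shared variables (the batch size of 500 is an assumption):

```python
import numpy as np

# Stand-ins for train_set_x / train_set_y living in GPU memory
train_x = np.random.default_rng(0).random((50000, 784), dtype=np.float32)
train_y = np.random.default_rng(1).integers(0, 10, size=50000).astype('int32')

batch_size = 500  # assumed minibatch size
n_batches = train_x.shape[0] // batch_size

for index in range(n_batches):
    # In Theano, this slice would be computed on the GPU; only `index`
    # needs to be passed in from the host at each step.
    x_batch = train_x[index * batch_size:(index + 1) * batch_size]
    y_batch = train_y[index * batch_size:(index + 1) * batch_size]

print(n_batches, x_batch.shape)  # 100 (500, 784)
```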

More information on the dataset is available at http://yann.lecun.com/exdb/mnist/.
