- Python Reinforcement Learning Projects
- Sean Saito Yang Wenzhuo Rajalingappaa Shanmugamani
- 2021-07-23 19:05:05
Building the network
Several deep learning frameworks, including TensorFlow, already provide APIs for loading the Fashion-MNIST (F-MNIST) dataset. For our implementation, we will use Keras, another popular deep learning framework that is integrated with TensorFlow. The Keras datasets module provides a convenient interface for loading datasets as NumPy arrays.
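Before the arrays reach the network they usually need two preprocessing steps: scaling the 8-bit pixel values (0 to 255) into the range [0, 1], and turning the integer labels (0 to 9, one per clothing class) into one-hot vectors, which is the kind of output Keras's np_utils.to_categorical produces. The sketch below is a dependency-free illustration that uses plain Python lists in place of NumPy arrays; the helper names scale_pixels and one_hot are hypothetical, and whether the book applies exactly these steps is an assumption.

```python
from typing import List, Sequence

NUM_CLASSES = 10  # Fashion-MNIST has 10 clothing classes, labelled 0-9


def scale_pixels(image: Sequence[Sequence[int]]) -> List[List[float]]:
    """Map 8-bit pixel intensities (0-255) to floats in [0.0, 1.0]."""
    return [[pixel / 255.0 for pixel in row] for row in image]


def one_hot(labels: Sequence[int], num_classes: int = NUM_CLASSES) -> List[List[float]]:
    """Convert integer class labels into one-hot vectors (hypothetical helper)."""
    encoded = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")
        vector = [0.0] * num_classes
        vector[label] = 1.0  # mark the single active class
        encoded.append(vector)
    return encoded
```

For example, one_hot([0, 9, 3]) returns three 10-element vectors with a 1.0 at positions 0, 9, and 3. With real NumPy arrays the same steps are normally vectorized, e.g. dividing the whole image array by 255.0 in one operation.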
Finally, we can start coding! For this exercise, we only need one Python module, which we will call cnn.py. Open up your favorite text editor or IDE, and let's get started.
Our first step is to import the modules that we are going to use:
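import logging
import os
import sys

import numpy as np
import tensorflow as tf
from keras.datasets import fashion_mnist  # loads F-MNIST as NumPy arrays
from keras.utils import np_utils  # label utilities such as to_categorical

# Module-level logger for progress messages
logger = logging.getLogger(__name__)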
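The snippet above creates a module-level logger but never configures it. By default Python's root logger only lets WARNING and higher through, so logger.info calls would print nothing. A minimal sketch, assuming you want formatted records written to stdout; get_logger is a hypothetical helper, not part of the original cnn.py.

```python
import logging
import sys


def get_logger(name: str = "cnn", level: int = logging.INFO) -> logging.Logger:
    """Return a named logger that writes formatted records to stdout.

    Hypothetical helper; the original code only calls logging.getLogger.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid attaching duplicate handlers on repeat calls
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
        )
        logger.addHandler(handler)
    logger.setLevel(level)  # INFO and above will now be emitted
    return logger
```

Using get_logger(__name__) in place of the bare logging.getLogger(__name__) line would let messages such as logger.info("Epoch %d finished", epoch) appear on the console. The handler check keeps repeated calls (for example, re-running a cell in a notebook) from printing every message twice.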
The following describes what each module is for and how we will use it:
We will implement our CNN as a class called SimpleCNN. Its __init__ constructor takes four parameters:
class SimpleCNN(object):

    def __init__(self, learning_rate, num_epochs, beta, batch_size):
        # Training hyperparameters
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.beta = beta
        self.batch_size = batch_size

        # Directories for the network's saved parameters and training logs
        self.save_dir = "saves"
        self.logs_dir = "logs"
        os.makedirs(self.save_dir, exist_ok=True)  # no error if already present
        os.makedirs(self.logs_dir, exist_ok=True)
        self.save_path = os.path.join(self.save_dir, "simple_cnn")
        self.logs_path = os.path.join(self.logs_dir, "simple_cnn")
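The constructor only stores batch_size and num_epochs; the training loop that consumes them is not part of this section. As an assumption about how they are typically used, the sketch below shows a shuffled minibatch schedule in which each epoch visits every example once, in batches of batch_size. steps_per_epoch and iterate_minibatches are hypothetical helpers that work on example indices, so they need neither NumPy nor TensorFlow.

```python
import math
import random
from typing import Iterator, List, Tuple


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Minibatches needed to see every example once (the last one may be smaller)."""
    return math.ceil(num_examples / batch_size)


def iterate_minibatches(
    num_examples: int, batch_size: int, num_epochs: int, seed: int = 0
) -> Iterator[Tuple[int, List[int]]]:
    """Yield (epoch, indices) pairs, reshuffling the example order every epoch."""
    rng = random.Random(seed)  # seeded so the schedule is reproducible
    indices = list(range(num_examples))
    for epoch in range(num_epochs):
        rng.shuffle(indices)
        for start in range(0, num_examples, batch_size):
            # Slicing copies, so later shuffles do not alter yielded batches
            yield epoch, indices[start:start + batch_size]
```

With the 60,000-image Fashion-MNIST training split and batch_size=100, steps_per_epoch(60000, 100) gives 600 updates per epoch, so num_epochs=10 would mean 6,000 updates in total.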
The parameters our SimpleCNN is initialized with are described here:
In addition, save_dir and save_path specify where the network's parameters will be stored, while logs_dir and logs_path specify where statistics from the training run will be stored (we will show how to retrieve these logs later).
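Because "saves" and "logs" are relative paths, SimpleCNN creates them inside whatever directory the script is launched from, and exist_ok=True makes repeated runs safe. The sketch below is a hypothetical variation (make_run_paths is not in the original code) that roots the same layout at an explicit base directory, which can help keep separate experiments apart; it is exercised in a temporary directory so nothing is left behind.

```python
import os
import tempfile
from typing import Tuple


def make_run_paths(base_dir: str, run_name: str = "simple_cnn") -> Tuple[str, str]:
    """Create <base_dir>/saves and <base_dir>/logs; return (save_path, logs_path).

    Hypothetical helper mirroring the directory handling in SimpleCNN.__init__.
    """
    save_dir = os.path.join(base_dir, "saves")
    logs_dir = os.path.join(base_dir, "logs")
    os.makedirs(save_dir, exist_ok=True)  # exist_ok=True: no error on re-runs
    os.makedirs(logs_dir, exist_ok=True)
    # Only the directories are created; the returned paths are just names
    return os.path.join(save_dir, run_name), os.path.join(logs_dir, run_name)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        save_path, logs_path = make_run_paths(tmp)
        make_run_paths(tmp)  # calling again is safe
        print(save_path)
        print(logs_path)
```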