
Playing Rock, Paper, Scissors with LSTMs

Remembering sequences of data has huge applications in many areas, not least gaming. Of course, producing a simple, clean example is another matter. Fortunately, examples abound on the internet, and Chapter_2_5.py shows an example of using an LSTM to play Rock, Paper, Scissors.

This example was adapted from https://github.com/hjpulkki/RPS, but the code needed to be modified in several places to get it working for us.

Open up the Chapter_2_5.py sample file and follow these steps:
  1. Let's start as we normally do with the imports. For this sample, be sure to have Keras installed as we did for the last set of exercises:
```python
import numpy as np
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, LSTM
```
  1. Then, we set some constants as shown:
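```python
EPOCH_NP = 100               # training epochs per sequence
INPUT_SHAPE = (1, -1, 1)     # (batch, timesteps, features); -1 lets NumPy infer timesteps
OUTPUT_SHAPE = (1, -1, 3)    # (batch, timesteps, one-hot classes)
DATA_FILE = "data.txt"       # training sequences, one per line
MODEL_FILE = "RPS_model.h5"  # where the trained model is saved
```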
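INPUT_SHAPE and OUTPUT_SHAPE are later passed to NumPy's `reshape`, where `-1` tells NumPy to infer the number of timesteps from the data. The following standalone sketch uses NumPy only and a made-up five-move sequence to show the shapes that result; `np.eye(3)` is used here as a stand-in for `np_utils.to_categorical`:

```python
import numpy as np

INPUT_SHAPE = (1, -1, 1)
OUTPUT_SHAPE = (1, -1, 3)

# A short, made-up sequence of moves encoded as 0, 1 or 2
moves = np.array([0, 1, 2, 2, 1])

# (batch=1, timesteps inferred from the data, features=1)
x = moves.reshape(INPUT_SHAPE)
print(x.shape)  # (1, 5, 1)

# One-hot encode with np.eye (stand-in for to_categorical),
# then reshape to (batch=1, timesteps, 3 classes)
y = np.eye(3)[moves].reshape(OUTPUT_SHAPE)
print(y.shape)  # (1, 5, 3)
```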
  1. Then, we build the model, this time with three stacked LSTM layers followed by two Dense layers and a three-unit softmax output layer, one unit for each possible move (rock, paper, and scissors), like so:
```python
def simple_model():
    new_model = Sequential()
    # Three stacked LSTM layers; the input is a sequence of single-valued moves
    new_model.add(LSTM(64, input_shape=(None, 1), return_sequences=True, activation='sigmoid'))
    new_model.add(LSTM(64, return_sequences=True, activation='sigmoid'))
    new_model.add(LSTM(64, return_sequences=True, activation='sigmoid'))
    new_model.add(Dense(64, activation='relu'))
    new_model.add(Dense(64, activation='relu'))
    # One output unit per possible move, giving a probability for each at every step
    new_model.add(Dense(3, activation='softmax'))
    new_model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy', 'categorical_crossentropy'])
    return new_model
```

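  1. Next, we create a generator function that reads the training sequences from the data.txt file, one sequence per line. For each sequence, it yields the moves as input and the one-hot encoded next moves as the target, using the following code:
```python
def batch_generator(filename):
    with open(filename, 'r') as data_file:
        for line in data_file:
            line = line.strip()
            if len(line) < 2:
                continue  # need at least one input move and one target move
            # Each character is one move: 0, 1 or 2
            data_vector = np.array([int(move) for move in line])
            # Input: every move but the last, shaped (1, timesteps, 1)
            input_data = data_vector[np.newaxis, :-1, np.newaxis]
            # Target: one-hot encoding of the following move at each step
            temp = np_utils.to_categorical(data_vector, num_classes=3)
            output_data = temp[np.newaxis, 1:]
            yield (input_data, output_data)
```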
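To see exactly what batch_generator yields for one line, this NumPy-only sketch repeats its slicing on a made-up line, "01201". The model receives each move as input and is trained to output the move that follows it; `np.eye(3)` stands in for `np_utils.to_categorical` so the sketch runs without Keras:

```python
import numpy as np

line = "01201\n"  # hypothetical line from data.txt

data_vector = np.array([int(move) for move in line.strip()])

# Input: all moves but the last, shaped (1, timesteps, 1)
input_data = data_vector[np.newaxis, :-1, np.newaxis]

# Target: one-hot encoding of every move but the first, shaped (1, timesteps, 3)
one_hot = np.eye(3)[data_vector]
output_data = one_hot[np.newaxis, 1:]

print(input_data[0, :, 0])            # [0 1 2 0]
print(output_data[0].argmax(axis=1))  # [1 2 0 1]
```

Each input move lines up with the move that comes after it, which is what the LSTM learns to predict.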
Note that in this example, each training sequence (one line of the file) is trained for 100 epochs, in the same order as the lines appear in the file. A better method would be to train on the sequences in a random order.
  1. Then we create the model:
```python
# Create model
np.random.seed(7)  # seeds NumPy only; the Keras backend may need its own seed for repeatable runs
model = simple_model()
```
  1. Train the model in a loop, with each iteration pulling one sequence from the data.txt file:
```python
for (input_data, output_data) in batch_generator(DATA_FILE):
    try:
        model.fit(input_data, output_data, epochs=EPOCH_NP, batch_size=100)
    except Exception as err:
        # Report the problem instead of silently hiding it
        print("error:", err)
```
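The training loop processes the sequences in file order. To train on them in random order instead, as suggested earlier, one option (sketched here as an assumption, not part of the original sample) is to read every line into memory and shuffle the list before yielding it. The function name `shuffled_sequences` and its `seed` parameter are hypothetical, and `np.eye(3)` stands in for `np_utils.to_categorical` so the sketch runs with NumPy alone:

```python
import random
import numpy as np

def shuffled_sequences(filename, seed=None):
    """Hypothetical variant of batch_generator that yields sequences in random order."""
    with open(filename, 'r') as data_file:
        # Keep only lines with at least one input move and one target move
        lines = [line.strip() for line in data_file if len(line.strip()) >= 2]
    random.Random(seed).shuffle(lines)  # shuffle the in-memory list of sequences
    for line in lines:
        data_vector = np.array([int(move) for move in line])
        input_data = data_vector[np.newaxis, :-1, np.newaxis]   # (1, timesteps, 1)
        output_data = np.eye(3)[data_vector][np.newaxis, 1:]    # (1, timesteps, 3)
        yield input_data, output_data
```

You would loop over `shuffled_sequences(DATA_FILE)` in place of `batch_generator`; calling it again with a different seed gives a new order for another pass over the data.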
  1. Finally, we evaluate the model on a validation sequence, make a test prediction, and save the trained model, as shown in this code:
```python
print("evaluating")
validation = '100101000110221110101002201101101101002201011012222210221011011101011122110010101010101'
# Convert each character to an integer move, then shape as (1, timesteps, 1)
input_validation = np.array([int(move) for move in validation[:-1]]).reshape(INPUT_SHAPE)
# Targets are the one-hot encoded next moves, shaped (1, timesteps, 3)
output_validation = np_utils.to_categorical(
    np.array([int(move) for move in validation[1:]]), num_classes=3).reshape(OUTPUT_SHAPE)
loss_and_metrics = model.evaluate(input_validation, output_validation, batch_size=100)

print("\n Evaluation results")

for name, value in zip(model.metrics_names, loss_and_metrics):
    print(name, value)

input_test = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]).reshape(INPUT_SHAPE)
res = model.predict(input_test)
# res has shape (1, 9, 3); pick the most likely next move at each step
prediction = np.argmax(res[0], axis=1)
print(res, prediction)

model.save(MODEL_FILE)
del model
```
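The accuracy that `model.evaluate` reports is, in essence, the fraction of timesteps where the highest-probability move matches the actual next move. As a NumPy-only sketch (the helper name `next_move_accuracy` is hypothetical), the same kind of figure can be computed directly from an array shaped like the output of `model.predict`, which is handy for checking predictions on your own sequences:

```python
import numpy as np

def next_move_accuracy(probabilities, sequence):
    """Fraction of steps where the most likely predicted move equals the actual next move.

    probabilities: array shaped (1, timesteps, 3), like the output of model.predict
    sequence: string of moves such as '0120'; the model saw all but its last move,
              so the predictions line up with sequence[1:]
    """
    predicted = np.argmax(probabilities[0], axis=1)
    actual = np.array([int(move) for move in sequence[1:]])
    return float(np.mean(predicted == actual))

# Made-up probabilities for the sequence '0120' (3 predicted steps)
probs = np.array([[[0.1, 0.8, 0.1],
                   [0.2, 0.3, 0.5],
                   [0.2, 0.7, 0.1]]])
print(next_move_accuracy(probs, '0120'))  # 2 of 3 steps correct
```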
  1. Run the sample as you normally would. Check the results at the end and note how accurate the model gets at predicting the sequence.
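When judging that accuracy, it helps to have reference points: guessing uniformly at random is right about one time in three, and a simple count-based predictor is another easy yardstick. The sketch below is our own assumption, not part of the original sample, and the helper `markov_baseline_accuracy` is hypothetical. For each step, it predicts the move that has most often followed the current move so far; you could run it on the validation string from the sample and compare the result with the LSTM's accuracy:

```python
import numpy as np

def markov_baseline_accuracy(sequence):
    """Accuracy of a simple first-order count predictor on a string of moves.

    For each step, predict the move that has most often followed the current
    move earlier in the sequence (ties go to the lowest move number).
    """
    counts = np.zeros((3, 3), dtype=int)  # counts[current, following]
    correct = 0
    for current, following in zip(sequence[:-1], sequence[1:]):
        current, following = int(current), int(following)
        prediction = int(np.argmax(counts[current]))
        correct += prediction == following
        counts[current, following] += 1  # update only after predicting
    return correct / (len(sequence) - 1)

print(markov_baseline_accuracy('0101010'))  # 5 of 6 steps correct
```

If the LSTM cannot beat a predictor this simple on your data, that is a sign the network has not learned much beyond basic move frequencies.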

Be sure to run through this simple example a few times and understand how the LSTM layers are set up. Pay special attention to the parameters and how they are set.

That concludes our quick look at how to use recurrent layers, specifically LSTM blocks, to recognize and predict sequences of data. We will, of course, use this versatile layer type many more times throughout this book.

In the final section of this chapter, we again showcase a number of exercises you are encouraged to undertake for your own benefit.
