- Hands-On Deep Learning for Games
- Micheal Lanham
Training a GAN
Training a GAN requires a fair bit more attention to detail and an understanding of more advanced optimization techniques. We will walk through each section of this function in detail in order to understand the intricacies of training. Let's open up Chapter_3_1.py and look at the train function and follow these steps:
- At the start of the train function, you will see the following code:
def train(self, epochs, batch_size=128, save_interval=50):
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
- The data is first loaded from the MNIST training set and then rescaled to the range of -1 to 1. We do this to center the data around 0 and to accommodate our activation function, tanh. If you go back to the generator function, you will see that its final activation is tanh.
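As a quick sanity check (this snippet is not from the book's source), the rescaling maps the original 0-255 pixel values onto the same -1 to 1 range that tanh produces:

```python
import numpy as np

# MNIST pixels are uint8 values in [0, 255]; dividing by 127.5 and
# subtracting 1 maps them onto [-1, 1], matching the tanh output range.
pixels = np.array([0., 127.5, 255.])
scaled = pixels / 127.5 - 1.
print(scaled)  # [-1.  0.  1.]
```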
- Next, we build a for loop to iterate through the epochs like so:
    for epoch in range(epochs):
- Then we randomly select half of the real training images, using this code:
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
- After that, we sample noise and generate a set of forged images with the following code:
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
- Now, half of the images are real and the other half are faked by our generator.
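To make the shapes concrete, here is a small sketch; the latent_dim of 100 and the 28 x 28 x 1 image size are assumptions based on the MNIST setup, and your model may use different values:

```python
import numpy as np

batch_size, latent_dim = 128, 100             # latent_dim = 100 is an assumption
noise = np.random.normal(0, 1, (batch_size, latent_dim))
print(noise.shape)                            # (128, 100)
# self.generator.predict(noise) would then return an array of generated
# images shaped (128, 28, 28, 1), matching the expanded MNIST data above.
```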
- Next, the discriminator is trained on both batches of images, producing one loss for the real images (labeled valid) and one for the generated images (labeled fake), as shown:
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
- Remember, this block of code is running across a set, or batch, of images. This is why we use the numpy np.add function to add d_loss_real and d_loss_fake, and then multiply by 0.5 to average the two losses. numpy is a library we will use often to work on sets or tensors of data.
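The sketch below shows what that averaging looks like with made-up numbers; it assumes the discriminator was compiled with metrics=['accuracy'], so train_on_batch returns a [loss, accuracy] pair:

```python
import numpy as np

# Hypothetical per-batch results; train_on_batch returns [loss, accuracy]
# when the model is compiled with metrics=['accuracy'] (an assumption here).
d_loss_real = [0.40, 0.85]   # loss and accuracy on the real batch
d_loss_fake = [0.60, 0.75]   # loss and accuracy on the generated batch
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print(d_loss)                # [0.5 0.8] -> averaged loss and accuracy
```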
- Finally, we train the generator using the following code:
        g_loss = self.combined.train_on_batch(noise, valid)
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
        if epoch % save_interval == 0:
            self.save_imgs(epoch)
- Note how the g_loss is calculated by training the combined model. As you may recall, the combined model stacks the generator and the discriminator, so the noise input flows through the generator into the discriminator, and the training error backpropagates through the entire stack. This is what allows us to train the generator against the discriminator as one combined model. An example of how this architecture looks is shown next, but note that the image sizes are a little different than ours:

Layer architecture diagram of DCGAN
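To make the combined-model idea concrete, here is a minimal, self-contained sketch of how a generator and discriminator are typically stacked in Keras. The tiny dense generator and discriminator are placeholders rather than the book's DCGAN layers, the sketch uses the tf.keras API rather than the standalone keras imports in Chapter_3_1.py, and the exact wiring in the book's code may differ:

```python
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Input, Reshape
from tensorflow.keras.models import Model, Sequential

latent_dim = 100  # assumed latent vector size

# Placeholder generator: noise -> fake 28x28x1 image (tanh matches the -1..1 data range).
generator = Sequential([
    Dense(28 * 28, activation='tanh', input_dim=latent_dim),
    Reshape((28, 28, 1)),
])

# Placeholder discriminator: image -> probability that the image is real.
discriminator = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Freeze the discriminator inside the combined model so that training the
# combined model only updates the generator's weights.
discriminator.trainable = False
z = Input(shape=(latent_dim,))
validity = discriminator(generator(z))
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='adam')

# Training on noise with "valid" labels pushes the generator to fool the discriminator.
noise = np.random.normal(0, 1, (16, latent_dim))
g_loss = combined.train_on_batch(noise, np.ones((16, 1)))
```

Freezing the discriminator before compiling the combined model is the usual reason that calling combined.train_on_batch(noise, valid) improves the generator rather than the discriminator.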
Now that we have a better understanding of the architecture, we need to go back and understand some details about the new layer types and the optimization of the combined model. We will look at how we can optimize a combined model such as our GAN in the next section.