官术网_书友最值得收藏!

3.7 示例:將GAN應用于Atari圖像

幾乎每本有關DL的書都使用MNIST數據集來展示DL功能,多年來,該數據集都變得無聊了,就像遺傳研究人員的果蠅一樣。為了打破這一傳統,并添加更多樂趣,我嘗試避免沿用以前的方法,而使用其他方法說明PyTorch。本章前面簡要提到了GAN,它們是由伊恩·古德費洛(Ian Goodfellow)發明和推廣的。本示例中將訓練GAN生成各種Atari游戲的屏幕截圖。

最簡單的GAN架構有兩個網絡,第一個網絡充當“欺騙者”(也稱為生成器),另一個網絡充當“偵探”(另一個名稱是判別器)。兩個網絡相互競爭,生成器試圖生成偽造的數據,這些數據使判別器也難以將它與原數據集區分開,判別器試圖檢測生成的數據樣本。隨著時間的流逝,兩個網絡都提高了技能,生成器生成越來越多的真實數據樣本,而判別器發明了更復雜的方法來區分偽造的數據。

GAN的實際應用包括改善圖像質量、逼真圖像生成和特征學習。在本示例中,實用性幾乎為零,但這將是一個很好的示例,可以說明對于相當復雜的模型而言,PyTorch代碼可以很簡潔。

整個示例代碼在文件Chapter03/03_atari_gan.py中。這里將給出一些重要的代碼,不包括import部分和常量聲明:

070-02

此類是Gym游戲的包裝器,其中包括以下幾種轉換:

  • 將輸入圖像的尺寸從210×160(標準Atari分辨率)調整為正方形尺寸64×64。
  • 將圖像的顏色平面從最后一個位置移到第一個位置,以滿足PyTorch卷積層的約定,該卷積層輸入包含形狀為通道、高度和寬度的張量。
  • 將圖像從bytes轉換為float

然后,定義兩個nn.Module類:DiscriminatorGenerator。第一種將經過縮放的彩色圖像作為輸入,并通過應用五層卷積,再使用Sigmoid進行非線性變換將數據轉換為數字。Sigmoid的輸出被解釋為:判別器認為輸入圖像來自真實數據集的概率。

Generator將隨機數向量(隱向量)作為輸入,并使用“轉置卷積”操作(也稱為deconvolution)將該向量轉換為原始分辨率的彩色圖像。這里不會介紹這些類,因為它們很冗長且與示例無關,你可以在完整的示例文件中找到它們。

我們讓幾個隨機智能體同時玩Atari游戲,并將游戲截圖作為輸入。圖3.6是輸入數據的示例,它是由以下函數生成的:

071-01
072-01

圖3.6 三種Atari游戲的屏幕截圖示例

從提供的數組中對環境進行無限采樣,發出隨機動作,并在batch列表中記錄觀察結果。當批滿足所需大小時,將圖像歸一化,將其轉換為張量,然后從生成器中yield出來。由于其中一個游戲存在問題,因此需要檢查觀察值均值非零,以防止圖像閃爍。

現在,我們看一下主函數,它包括準備模型并運行訓練循環。

072-02

在此,我們處理命令行參數(只有一個可選參數--cuda,啟用GPU計算模式),創建環境池并用包裝器包裝。該環境數組將傳遞給iterate_batches函數以生成訓練數據。

072-03

上面的代碼創建了幾個類:一個Summary Writer、兩個網絡、一個損失函數和兩個優化器。為什么是兩個?因為這就是GAN訓練的方式:要訓練判別器,需要用適當的標簽(1代表真實的,0代表偽造的)來向它展示真實和偽造的數據樣本。在此過程中,僅更新判別器的參數。

此后,再次將真實和偽造樣本都通過判別器,但是這次,所有樣本的標簽均為1,并且僅更新生成器的權重。第二遍告訴生成器如何欺騙判別器,并將真實樣本與生成的樣本混淆起來。

073-01

這段代碼定義了數組(用于累積損失)、迭代器計數器以及帶有真假標簽的變量。

073-02

在訓練循環開始前,生成一個隨機向量并將其傳遞給Generator網絡。

073-03

首先,通過兩批數據來訓練判別器,即分別應用于真實數據樣本和生成的樣本。我們需要在生成器的輸出上調用detach()函數,以防止此次訓練的梯度流入生成器(detach()tensor的方法,該方法可以復制張量而不與原始張量的操作關聯)。

073-04

以上代碼用于生成器的訓練。將生成器的輸出傳遞給判別器,但是現在不停止梯度。相反,我們將目標函數與True標簽一起應用。它將使生成器向生成可欺騙判別器的樣本的方向發展。

那是與訓練相關的代碼,接下來的兩行代碼會上報損失,并將圖像樣本輸入給TensorBoard:

074-01

這個例子的訓練是一個漫長的過程。在GTX 1080 GPU上,100次迭代大約需要40秒。最初,生成的圖像完全是隨機噪聲,但是在經過1萬~2萬次迭代后,生成器變得越來越熟練,并且生成次圖像越來越類似于真實游戲的屏幕截圖。

經過4萬~5萬次訓練迭代后(在GPU上幾個小時),實驗給出了以下圖像(見圖3.7)。

074-02

圖3.7 生成器網絡產生的樣例圖片

主站蜘蛛池模板: 贵南县| 邓州市| 泉州市| 龙陵县| 静宁县| 渑池县| 贡嘎县| 天长市| 读书| 西贡区| 雷山县| 上虞市| 民权县| 宜春市| 湘西| 盱眙县| 黎城县| 黄陵县| 且末县| 五峰| 瓮安县| 仁布县| 墨玉县| 汾阳市| 高邑县| 田阳县| 六安市| 通海县| 宣恩县| 闵行区| 石首市| 和静县| 镇江市| 通许县| 赫章县| 渭源县| 贵阳市| 张北县| 新疆| 泰兴市| 伊川县|