官术网_书友最值得收藏!

8.1 基礎(chǔ)DQN

首先,我們將實(shí)現(xiàn)與第6章中一樣的DQN方法,但要使用第7章中介紹的高級(jí)庫(kù)來實(shí)現(xiàn)。這會(huì)使代碼更加緊湊,這點(diǎn)很重要,因?yàn)楹头椒ㄟ壿嫴幌嚓P(guān)的細(xì)節(jié)不會(huì)使我們分心。

同時(shí),本書的目的不是教你如何使用現(xiàn)有的庫(kù),而是開發(fā)你對(duì)RL方法的直覺,并在必要時(shí)從頭實(shí)現(xiàn)一切。從我的角度來看,這是更有價(jià)值的技能,因?yàn)閹?kù)有興起衰落,而對(duì)于領(lǐng)域的真正理解將使你能夠快速理解別人的代碼并有意識(shí)地使用它。

在基礎(chǔ)DQN的實(shí)現(xiàn)中,有三個(gè)模塊:

  • Chapter08/lib/dqn_model.py:DQN神經(jīng)網(wǎng)絡(luò),代碼和第6章中的一樣,所以這里不再贅述。
  • Chapter08/lib/common.py:本章其他代碼會(huì)用到的通用函數(shù)和聲明。
  • Chapter08/01_dqn_basic.py:60行使用了PTAN和Ignite庫(kù)的代碼,實(shí)現(xiàn)了基礎(chǔ)DQN方法。

8.1.1 通用庫(kù)

我們從lib/common.py的內(nèi)容開始。首先,我們需要一些上一章的Pong環(huán)境的超參數(shù)。超參數(shù)保存在SimpleNamespace對(duì)象中,該對(duì)象是Python標(biāo)準(zhǔn)庫(kù)中的類,它提供了對(duì)一組鍵值對(duì)的簡(jiǎn)單訪問。這使得我們可以輕松地為更復(fù)雜的各種Atari游戲添加一份配置,并嘗試使用超參數(shù):

160-01

SimpleNamespace類的實(shí)例為值提供一個(gè)通用的容器。例如,對(duì)于前面的超參數(shù),你可以這么用:

160-02

lib/common.py中的下一個(gè)函數(shù)叫unpack_batch,它將一批狀態(tài)轉(zhuǎn)移轉(zhuǎn)換成適合訓(xùn)練的NumPy數(shù)組。每一個(gè)來自ExperienceSourceFirstLast的狀態(tài)轉(zhuǎn)移的類型都是ExperienceFirstLast,底層類型是namedtuple,包含下列字段:

  • state:來自環(huán)境的觀察。
  • action:智能體執(zhí)行的整型動(dòng)作。
  • reward:如果使用steps_count=1來創(chuàng)建ExperienceSourceFirstLast,它就是立即獎(jiǎng)勵(lì)。對(duì)于更大的步數(shù),它包含這么多步的折扣累積獎(jiǎng)勵(lì)。
  • last_state:如果狀態(tài)轉(zhuǎn)移對(duì)應(yīng)于環(huán)境的最后一步,則這個(gè)字段是None;否則,它包含經(jīng)驗(yàn)鏈的最后一個(gè)觀察。

unpack_batch的代碼如下:

161-01

注意我們是如何處理批中的最后一個(gè)狀態(tài)轉(zhuǎn)移的。為了避免進(jìn)行這種特殊處理,對(duì)于終結(jié)狀態(tài)轉(zhuǎn)移,我們?cè)?code class="kindle-cn-computer-code">last_states中存了初始狀態(tài)。為了使對(duì)Bellman更新的計(jì)算正確,我們可以在損失計(jì)算的時(shí)候用dones數(shù)組對(duì)批進(jìn)行mask操作。另一個(gè)解決方案是只對(duì)非終結(jié)狀態(tài)轉(zhuǎn)移的最后一個(gè)狀態(tài)進(jìn)行計(jì)算,但是這會(huì)使損失函數(shù)變得有點(diǎn)復(fù)雜。

DQN損失函數(shù)的計(jì)算由calc_loss_dqn函數(shù)提供,代碼和第6章中的幾乎一樣。一個(gè)小的改動(dòng)是增加了torch.no_grad(),它可以停止記錄PyTorch計(jì)算圖。

161-02

除了這些核心的DQN函數(shù)之外,common.py還提供了幾個(gè)和訓(xùn)練循環(huán)、數(shù)據(jù)生成以及TensorBoard相關(guān)的工具。第一個(gè)工具是一個(gè)實(shí)現(xiàn)了在訓(xùn)練時(shí)衰減epsilon的小類。epsilon定義了智能體采取隨機(jī)動(dòng)作的概率。它應(yīng)該從一開始的1.0(完全隨機(jī)的智能體)衰減到一個(gè)比較小的值,例如0.02或0.01。代碼很簡(jiǎn)單,但幾乎在所有DQN中都需要,所以用下面這個(gè)小類實(shí)現(xiàn):

162-01

另外一個(gè)小函數(shù)是batch_generator,它使用ExperienceReplayBuffer(第7章中描述的PTAN類)作為參數(shù),并從緩沖區(qū)中無(wú)限地生成采樣得到的訓(xùn)練批。一開始函數(shù)會(huì)確保緩沖區(qū)中已經(jīng)包含了所需數(shù)量的樣本。

162-02

最后,一個(gè)名為setup_ignite的冗長(zhǎng)卻非常有用的函數(shù)會(huì)掛載所需的Ignite處理器,以顯示訓(xùn)練進(jìn)度并將評(píng)估指標(biāo)寫入TensorBoard。我們來逐一查看此函數(shù)。

162-03

首先,setup_ignite掛載了兩個(gè)由PTAN提供的處理器:

  • EndOfEpisodeHandler:每當(dāng)游戲片段結(jié)束的時(shí)候,它會(huì)發(fā)布一個(gè)Ignite事件。當(dāng)片段的平均獎(jiǎng)勵(lì)超過界限的時(shí)候,它也會(huì)發(fā)布一個(gè)事件。用它可以檢測(cè)游戲是否被解決。
  • EpisodeFPSHandler:記錄片段花費(fèi)的時(shí)間以及已經(jīng)和環(huán)境產(chǎn)生的交互數(shù)量的小類。用它可以計(jì)算每秒處理的幀數(shù),這是一個(gè)非常重要的性能評(píng)估指標(biāo)。
163-01

然后我們創(chuàng)建兩個(gè)事件處理器,一個(gè)會(huì)在片段結(jié)束時(shí)被調(diào)用,它會(huì)在控制臺(tái)顯示已完成片段的相關(guān)信息。另一個(gè)在平均獎(jiǎng)勵(lì)超過超參數(shù)中定義的界限(Pong示例中是18.0)時(shí)被調(diào)用,展示游戲被解決并停止訓(xùn)練的消息。

函數(shù)的剩下部分和我們想記錄的TensorBoard數(shù)據(jù)相關(guān):

163-02

首先,創(chuàng)建一個(gè)TensorboardLogger,它是Ignite提供的向TensorBoard寫數(shù)據(jù)的一個(gè)特殊類。處理函數(shù)會(huì)返回?fù)p失值,所以我們掛載一個(gè)RunningAverage轉(zhuǎn)換(也是由Ignite提供的)來獲得比較平滑的隨時(shí)間推移計(jì)算的損失。

163-03

TensorboardLogger可以記錄來自Ignite的兩組數(shù)據(jù):輸出(由轉(zhuǎn)換函數(shù)返回的值)和評(píng)估指標(biāo)(在訓(xùn)練過程中被計(jì)算出來并保存在engine的狀態(tài)中)。EndOfEpisodeHandlerEpisodeFPSHandler提供了評(píng)估指標(biāo),會(huì)在每個(gè)游戲片段結(jié)束時(shí)更新。所以,我們掛載OutputHandler來將每次片段結(jié)束時(shí)的相關(guān)信息寫入TensorBoard。

163-04

另外一組我們想記錄的值是訓(xùn)練過程中的評(píng)估指標(biāo):損失、FPS、以及可能的用戶自定義評(píng)估指標(biāo)。這些值在每次訓(xùn)練迭代都會(huì)更新,但是我們會(huì)執(zhí)行成千上萬(wàn)次迭代,所以每100次訓(xùn)練迭代才向TensorBoard保存一次數(shù)據(jù);否則的話,數(shù)據(jù)文件會(huì)變得特別大。所有這類功能可能看起來都太復(fù)雜了,但是它提供了在訓(xùn)練過程中獲取的統(tǒng)一評(píng)估指標(biāo)集。實(shí)際上,Ignite不是很復(fù)雜,它提供了一個(gè)非常靈活的框架。以上就是common.py的內(nèi)容。

8.1.2 實(shí)現(xiàn)

現(xiàn)在,我們來看一下01_dqn_basic.py,它創(chuàng)建所需類并開始訓(xùn)練。這里將省略無(wú)關(guān)代碼,只關(guān)注重要的部分。完整的版本可以在GitHub倉(cāng)庫(kù)找到。

164-01

首先,創(chuàng)建環(huán)境并應(yīng)用一組標(biāo)準(zhǔn)的包裝器。第6章已經(jīng)討論過它們了,并且在下一章優(yōu)化Pong解決方案的性能時(shí)還會(huì)討論它們。然后,創(chuàng)建DQN模型和目標(biāo)神經(jīng)網(wǎng)絡(luò)。

164-02

接著,創(chuàng)建智能體,并傳入一個(gè)ε-greedy動(dòng)作選擇器。在訓(xùn)練過程中,由已經(jīng)討論過的EpsilonTracker類降低ε值。它會(huì)降低隨機(jī)選擇的動(dòng)作數(shù)量并將更多控制權(quán)交給NN。

164-03

接下來的兩個(gè)非常重要的對(duì)象是ExperienceSourceExperienceReplayBuffer。第一個(gè)對(duì)象接受智能體和環(huán)境作為參數(shù)并提供游戲片段的狀態(tài)轉(zhuǎn)移。這些狀態(tài)轉(zhuǎn)移會(huì)被保存在經(jīng)驗(yàn)回放緩沖區(qū)。

165-01

然后,創(chuàng)建一個(gè)優(yōu)化器并定義處理函數(shù),每批狀態(tài)轉(zhuǎn)移都會(huì)調(diào)用該函數(shù)來訓(xùn)練模型。為了訓(xùn)練,我們調(diào)用common.calc_loss_dqn函數(shù)并反向傳播它的結(jié)果。

這個(gè)函數(shù)也要求EpsilonTracker降低epsilon,并周期性地同步目標(biāo)神經(jīng)網(wǎng)絡(luò)。

165-02

最后,創(chuàng)建Ignite的Engine對(duì)象,用common.py中的一個(gè)函數(shù)來配置它,并啟動(dòng)訓(xùn)練進(jìn)程。

8.1.3 結(jié)果

好了,我們開始訓(xùn)練吧!

165-03

控制臺(tái)中的每一行都是片段結(jié)束時(shí)輸出的,展示了片段的獎(jiǎng)勵(lì)、步數(shù)、速度以及總訓(xùn)練時(shí)長(zhǎng)。基礎(chǔ)DQN版本通常需要100萬(wàn)幀才能達(dá)到18的平均獎(jiǎng)勵(lì),所以耐心一點(diǎn)。訓(xùn)練過程中,我們可以在TensorBoard檢查訓(xùn)練過程的動(dòng)態(tài)情況,它會(huì)展示epsilon變化圖、原始獎(jiǎng)勵(lì)值、平均獎(jiǎng)勵(lì)以及速度。圖8.1和圖8.2展示了獎(jiǎng)勵(lì)和片段步數(shù)(底部的x軸表示經(jīng)過的時(shí)間,頂部的則是片段數(shù))。

166-01

圖8.1 訓(xùn)練過程中片段的相關(guān)信息

166-02

圖8.2 訓(xùn)練的評(píng)估指標(biāo):速度和損失

主站蜘蛛池模板: 崇左市| 潮安县| 永宁县| 和龙市| 九江市| 杨浦区| 阿拉善盟| 喜德县| 永川市| 蒙阴县| 东光县| 会同县| 珲春市| 麻栗坡县| 北碚区| 赤城县| 股票| 浮山县| 荔波县| 清流县| 新田县| 岚皋县| 巴彦淖尔市| 无极县| 济源市| 资兴市| 海门市| 万盛区| 石嘴山市| 库伦旗| 建始县| 谷城县| 清镇市| 禄丰县| 鸡泽县| 伊通| 东至县| 阳东县| 阿克陶县| 隆化县| 祥云县|