- 深度強化學習實踐(原書第2版)
- (俄)馬克西姆·拉潘
- 533字
- 2021-08-18 17:39:20
3.3 NN構建塊
torch.nn
包中有大量預定義的類,可以提供基本的功能。這些類在設計時就考慮了實用性(例如,它們支持mini-batch處理,設置了合理的默認值,并且權重也經過了合理的初始化)。所有模塊都遵循callable的約定,這意味著任何類的實例在應用于其參數時都可以充當函數。例如,Linear
類實現了帶有可選偏差的前饋層:

上述代碼創建了一個隨機初始化的前饋層,包含兩個輸入和五個輸出,并將其應用于浮點張量。torch.nn
包中的所有類均繼承自nn.Module
基類,可以通過該基類構建更高級別的NN模塊。下一節將介紹如何自己構建,但是現在,我們先看一下所有nn.Module
子類提供的方法。如下:
parameters()
:此函數返回所有需要進行梯度計算的變量的迭代器(即模塊權重)。zero_grad()
:此函數將所有參數的梯度初始化為零。to(device)
:此函數將所有模塊參數移至給定的設備(CPU或GPU)。state_dict()
:此函數返回一個包含所有模塊參數的字典,對于模型序列化很有用。load_state_dict()
:此函數使用狀態字典來初始化模塊。
所有的類都可在文檔(http://pytorch.org/docs)中找到。
現在,我將要提到一個非常方便的類,即Sequential
,它可以將不同的層串起來。演示Sequential
的最佳方法是通過一個示例:

上面的代碼定義了一個三層的NN,輸出層是softmax,softmax應用于第一維度(第零維度是批樣本),還包括整流線性函數(Rectified Linear Unit,ReLU)非線性層和dropout。我們給這個模型輸入一些數據:

mini-batch就是一個成功地遍歷了網絡的例子。