- 深度強化學習實踐(原書第2版)
- (俄)馬克西姆·拉潘
- 1309字
- 2021-08-18 17:39:21
3.4 自定義層
上一節簡要地提到了nn.Module
在PyTorch中是所有NN構建塊的基礎父類。它不僅僅是現存層的統一父類,它遠不止于此。通過將nn.Module
子類化,可以創建自己的構建塊,它們可以組合在一起,后續可以復用,并且可以完美地集成到PyTorch框架中。
作為核心,nn.Module
為其子類提供了相當豐富的功能:
- 它記錄當前模塊的所有子模塊。例如,構建塊可以具有兩個前饋層,可以以某種方式使用它們來執行代碼塊的轉換。
- 提供處理已注冊子模塊的所有參數的函數。可以獲取模塊參數的完整列表(
parameters()
方法)將其梯度置零(zero_grads()
方法),將其移至CPU或GPU(to(device)
方法),序列化和反序列化模塊(state_dict()
和load_state_dict()
),甚至可以用自己的callable執行通用的轉換邏輯(apply()
方法)。 - 建立了
Module
針對數據的約定。每個模塊都需要覆蓋forward()
方法來執行數據的轉換。 - 還有更多的函數,例如注冊鉤子函數以調整模塊轉換邏輯或梯度流,它們更加適合高級的使用場景。
這些功能允許我們通過統一的方式將子模型嵌套到更高層次的模型中,在處理復雜的情況時非常有用。它可以是簡單的單層線性變換,也可以是1001層的residual NN(ResNet)
,但是如果它們遵循nn.Module
的約定,則可以用相同的方式處理它們。這對于代碼的簡潔性和可重用性非常有幫助。
為了簡化工作,PyTorch的作者遵循上述約定,通過精心設計和大量Python魔術方法簡化了模塊的創建。因此,要創建自定義模塊,通常只需要做兩件事——注冊子模塊并實現forward()
方法。
我們來看上一節中Sequential
的例子是如何使用更加通用和可復用的方式做到這一點的(完整的示例見Chapter03/01_modules.py
):

這是繼承了nn.Module
的模塊。在構造函數中,我們傳遞了三個參數:輸入大小、輸出大小和可選的dropout概率。我們要做的第一件事就是調用父類的構造函數來初始化。
第二步,我們需要創建一個已經熟悉的nn.Sequential
,包含一些不同的層,并將其賦給類中名為pipe
的字段。通過為字段分配一個Sequential
實例,自動注冊該模塊(nn.Sequential
繼承自nn.Module
,與nn
包中的其他類一樣)。注冊它不需要任何調用,只需將子模塊分配給字段即可。構造函數完成后,所有字段會被自動注冊(如果確實想要手動注冊,nn.Module
中也有函數可用)。

在這里,我們必須覆寫forward
函數并實現自己的數據轉換邏輯。由于模塊是對其他層的非常簡單的包裝,因此只需讓它們轉換數據即可。請注意,要將模塊應用于數據,我們需要調用該模塊(即假設模塊實例為一個函數并使用參數調用它)而不使用nn.Module
類的forward()
方法。這是因為nn.Module
會覆蓋__call__()
方法(將實例視為可調用實例時,會使用該方法)。該方法執行了nn.Module
中的一些神奇的操作,并調用forward()
方法。如果直接調用forward()
,則將干預nn.Module
的職責,這可能會導致錯誤的結果。
因此,這就是定義自己的模塊所需要做的。現在,我們來使用它:

我們創建模塊,為輸入和輸出賦值,然后創建張量,讓模塊對其進行轉換(遵守約定,將其視為callable)。之后,打印網絡結構(nn.Module
覆寫了__str__()
和__repr__()
方法),以更好的方式來展示內部結構。最后,展示運行的結果。
代碼輸出應如下所示:

當然,之前說了PyTorch支持動態特性。每一批數據都會調用forward()
方法,因此如果要根據所需處理的數據進行一些復雜的轉換,例如分層softmax或要應用網絡隨機選擇,那么你也可以這樣做。模塊參數的數量也不只限于一個。因此,如果需要,可以編寫一個帶有多個必需參數和幾十個可選參數的模塊,這都是可以的。
接下來,我們需要熟悉PyTorch庫的兩個重要部分(損失函數和優化器),它們將簡化我們的生活。