- 深度強化學習實踐(原書第2版)
- (俄)馬克西姆·拉潘
- 1714字
- 2021-08-18 17:39:21
3.5 最終黏合劑:損失函數和優化器
將輸入數據轉換為輸出的網絡并不是訓練唯一需要的東西。我們還需要定義學習目標,即要有一個接受兩個參數(網絡輸出和預期輸出)的函數。它的責任是返回一個表示網絡預測結果與預期結果之間的差距的數字。此函數稱為損失函數,其輸出為損失值。使用損失值,可以計算網絡參數的梯度,并對其進行調整以減小損失值,以便優化模型的結果。損失函數和通過梯度調整網絡參數的方法非常普遍,并且以多種形式存在,以至于它們構成了PyTorch庫的重要組成部分。我們從損失函數開始介紹。
3.5.1 損失函數
損失函數在nn
包中,并實現為nn.Module
的子類。通常,它們接受兩個參數:網絡輸出(預測)和預期輸出(真實數據,也稱為數據樣本的標簽)。在撰寫本書時,PyTorch 1.3.0包含20個不同的損失函數,當然,你也可以顯式地自定義要優化的函數。
最常用的標準損失函數是:
nn.MSELoss
:參數之間的均方誤差,是回歸問題的標準損失。nn.BCELoss
和nn.BCEWithLogits
:二分類交叉熵損失。前者期望輸入是一個概率值(通常是Sigmoid
層的輸出),而后者則假定原始分數為輸入并應用Sigmoid
本身。第二種方法通常在數值上更穩定、更有效。這些損失(顧名思義)經常用于分類問題。nn.CrossEntropyLoss
和nn.NLLLoss
:著名的“最大似然”標準,用于多類分類問題。前者期望的輸入是每個類的原始分數,并在內部應用LogSoftmax
,而后者期望將對數概率作為輸入。
還有一些其他的損失函數可供使用,當然你也可以自己寫Module
子類來比較輸出值和目標值。現在,來看下關于優化過程的部分。
3.5.2 優化器
基本優化器的職責是獲取模型參數的梯度,并更改這些參數來降低損失值。通過降低損失值,使模型向期望的輸出靠攏,使得模型性能越來越好。更改參數聽起來很簡單,但是有很多細節要處理,優化器仍是一個熱門的研究主題。在torch.optim
包中,PyTorch提供了許多流行的優化器實現,其中最廣為人知的是:
SGD
:具有可選動量的普通隨機梯度下降算法。RMSprop
:Geoffrey Hinton提出的優化器。Adagrad
:自適應梯度優化器。Adam
:一種非常成功且流行的優化器,是RMSprop
和Adagrad
的組合。
所有優化器都公開了統一的接口,因而可以輕松地嘗試使用不同的優化方法(有時,優化方法可以在動態收斂和最終結果上表現優秀)。在構造時,需要傳遞可迭代的張量,該張量在優化過程中會被修改。通常的做法是傳遞上層nn.Module
實例的params()
調用的結果,結果將返回所有具有梯度的可迭代葉張量。
現在,我們來討論訓練循環的常見藍圖。

通常,需要一遍又一遍地遍歷數據(所有數據運行一個迭代稱為一個epoch)。數據通常太大而無法立即放入CPU或GPU內存中,因此將其分成大小相同的批次進行處理。每一批數據都包含數據樣本和目標標簽,并且它們都必須是張量(第2行和第3行代碼)。
將數據樣本傳遞給網絡(第4行),并將其輸出值和目標標簽提供給損失函數(第5行),損失函數的結果顯示了網絡結果和目標標簽的差距。網絡的輸入和網絡的權重都是張量,所以網絡的所有轉換只不過是中間張量實例的操作圖。損失函數也是如此——它的結果也是一個只有一個損失值的張量。
計算圖中的每一個張量都記得其來源,因此要對整個網絡計算梯度,只需要在損失函數的返回結果上調用backward()
函數(第6行)即可。調用結果是展開已執行計算的圖和計算requires_grad = True
的葉張量的梯度。通常,這些張量是模型的參數,比如前饋網絡的權重和偏差,以及卷積濾波器。每次計算梯度時,都會在tensor.grad
字段中累加梯度,所以一個張量可以參與多次轉換,梯度會相加。例如,循環神經網絡(Recurrent Neural Network,RNN)的一個單元可以應用于多個輸入項。
在調用loss.backwards()
后,我們已經累加了梯度,現在是優化器執行其任務的時候了——它獲取傳遞給它的參數的所有梯度并應用它們。所有這些都是使用step()
完成的(第7行)。
訓練循環最后且重要的部分是對參數梯度置零的處理。可以在網絡上調用zero_grad()
來實現,但是為了方便,優化器還公開了這樣一個調用(第8行)。有時候zero_grad()
被放在訓練循環的開頭,但這并沒有什么影響。
上述方案是一種非常靈活的優化方法,即使在復雜的研究中也可以滿足要求。例如,可以用兩個優化器在同一份數據上調整不同模型的選項(這是一個來自生成對抗網絡(Generative Adversarial Network,GAN)訓練的真實場景)。
我們已經介紹完了訓練NN所需的PyTorch的基本功能。本章以一個實際的場景結束,演示涵蓋的所有概念,但在開始之前,我們需要討論一個重要的主題——監控學習過程——這對NN從業人員來說是必不可少的。