官术网_书友最值得收藏!

3.8 PyTorch Ignite

PyTorch是一個優雅而靈活的庫,因此它成為成千上萬的研究人員、DL愛好者、行業開發人員和其他人員的首選。但是靈活性有其自身的代價:需要寫太多的代碼來解決問題。有時,這是非常有益的,例如,實現一些尚未包含在標準庫中的新優化方法或DL技巧時。只需使用Python實現公式,PyTorch將神奇地完成所有梯度計算和反向傳播機制。另一個證明這種方法有益的場景是,當你必須關注底層原理時,比如調整梯度、了解優化器詳細信息以及NN轉換數據的方式。

但是,在完成日常任務(例如圖像分類器的簡單監督訓練)時,并不需要這種靈活性。對于此類任務,標準PyTorch可能太過底層,所以你需要一遍又一遍地處理相同的代碼。以下是DL訓練過程中主要部分的詳盡列表,但需要編寫一些代碼:

  • 數據準備和轉換以及批次的生成。
  • 計算訓練指標,例如損失值、精度和F1分數。
  • 在測試和驗證數據集中對模型進行周期性測試。
  • 經過一定數量的迭代或達到新的最佳度量標準后的模型的檢查點。
  • 將指標輸入到TensorBoard等監控工具中。
  • 超參隨著時間而變化,例如學習率的降低或增加。
  • 在控制臺上輸出有關訓練進度的消息。

當然,它們都能使用PyTorch來實現,但是可能需要編寫大量的代碼。這些任務在任何DL項目中都存在,一遍又一遍地編寫相同的代碼很快變得麻煩。解決此問題的常規方法是編寫函數,將其包裝到庫中,然后重復使用。如果該庫是開源的且質量很高(易于使用,提供了一定程度的靈活性,可以正確編寫等),那么隨著越來越多的人在其項目中使用它,該庫將變得流行。該過程不只發生在DL領域,它在軟件行業中無處不在。

PyTorch有多個庫可簡化常見任務,如ptlearnfastaiignite等。“PyTorch生態系統項目”參見https://pytorch.org/ecosystem。

開始就使用這些高級庫可能會很有吸引力,因為使用它們可以僅用幾行代碼即可解決常見問題,但是這里也存在一些危險。如果只知道如何使用高級庫而不了解底層細節,那么可能會陷入無法僅由標準方法解決問題的困境。在ML的動態領域中,這種情況經常發生。

本書的重點是確保你理解RL方法、它的實現及其適用性。因此,我們將使用遞進的方法。首先,僅使用PyTorch代碼來實現,但是隨著學習的推進,將使用高級庫來實現示例。對于RL,將使用由我編寫的小型庫:PTAN(https://github.com/Shmuma/ptan/)。PTAN將在第7章進行介紹。

為了減少DL樣板代碼的數量,我們將使用一個稱為PyTorch Ignite(https://pytorch.org/ignite/)的庫。本節將簡要介紹Ignite,然后使用Ignite重寫Atari GAN示例,并對其進行檢查。

Ignite概念

從高層次上講,Ignite簡化了PyTorch DL中訓練循環的編寫。在本章前面的“優化器”部分,可以看到最小的訓練循環包括:

  • 采樣一批訓練數據。
  • 將NN應用于這批數據,計算損失函數(要最小化的單個值)。
  • 對損失進行反向傳播,以獲取與損失有關的網絡參數梯度。
  • 使優化器將梯度應用于網絡。
  • 重復,直到滿意或不想再等待。

Ignite的核心部分是Engine類,該類遍歷數據源,并將處理函數應用于數據批。除此之外,Ignite還提供了在訓練循環的特定條件下,調用某函數的功能。這些特定條件稱為Event,可能在以下位置:

  • 整個訓練過程的開始或結束位置。
  • 訓練epoch(使用數據進行迭代)的開始或結束位置。
  • 單個批處理的開始或結束位置。

除此之外,還存在自定義事件,并且允許指定每N個事件調用一次函數,例如,每100個批次或每隔一個epoch進行一次計算。

以下代碼塊顯示了一個非常簡單的Ignite示例:

076-01

該代碼不可運行,因為它缺少很多內容,例如數據源、模型和優化器創建,但它展示了Ignite基本概念。Ignite的主要優勢在于它能夠利用現有功能擴展訓練模型。你希望平滑損失值并且每100批次將其寫入TensorBoard中嗎?沒問題!加兩行代碼即可完成。你想每10個epoch運行一次模型驗證嗎?寫一個函數來運行測試,并將其加入engine中,然后它將被如期調用。

關于Ignite功能的完整描述不在本書的討論范圍,可以閱讀官方網站(https://pytorch.org/ignite)的文檔來查看。

為了演示Ignite,我們更改一下用GAN訓練Atari圖像的例子。完整的示例代碼見Chapter03/04_atari_gan_ignite.py,以下代碼段將僅顯示有改動的部分。

076-02

首先,導入幾個Ignite類:EngineEventsignite.metrics包含與訓練過程的性能指標有關的類,例如混淆矩陣、精度和召回率。在本示例中,將使用RunningAverage類,該類提供一種平滑時間序列值的方法。在前面的示例中,我們通過對一系列損失值調用np.mean()來完成此操作,但是RunningAverage提供了一種更方便(并且在數學上更正確)的方法。此外,Ignite的contrib包中導入TensorBoard記錄器(該功能由其他人貢獻)。

077-01

下一步,我們需要定義處理函數,該函數將獲取批數據,并用該批數據對判別器和生成器模型進行更新。此函數可以返回訓練過程中要跟蹤的任何數據,在本示例中為兩個模型各自的損失值。這個函數還可以保存要在TensorBoard中顯示的圖像。

完成此操作后,我們要做的就是創建一個Engine實例,加上所需的處理程序,然后運行訓練過程。

077-02

在前面的代碼中,我們創建了engine,傳遞了處理函數,并為兩個損失值附加了RunningAverage轉換。每個RunningAverage都會產生一個所謂的“指標”,即在訓練過程中保持的派生值。平滑指標avg_loss_gen表示來自生成器的平滑損失,avg_loss_dis表示來自判別器的平滑損失。這兩個值在每次迭代后寫入TensorBoard中。

078-01

最后一段代碼附加了另一個事件處理程序,并且在每次迭代完成時由Engine調用。它會寫一行日志,其索引是迭代數,值是平滑后的度量值。最后一行啟動Engine,將已定義的函數作為數據源傳入(函數iterate_batches是一個生成器,分批返回迭代器,因此,將其輸出作為data參數傳遞是很好的)。

這就是Ignite的全部內容。如果運行示例Chapter03/04_atari_gan_ignite.py,它與前面示例的運行方式相同,這樣的小例子可能并不會令人印象深刻,但是在實際項目中,Ignite的使用通常會使代碼更簡潔、更可擴展。

主站蜘蛛池模板: 平和县| 朔州市| 石门县| 三都| 西平县| 砀山县| 赤水市| 景宁| 凉山| 绵阳市| 拜泉县| 大埔县| 崇信县| 武功县| 金堂县| 桦川县| 瓮安县| 贺兰县| 和林格尔县| 池州市| 济宁市| 雅江县| 区。| 鹤庆县| 介休市| 津南区| 禄劝| 精河县| 西吉县| 资阳市| 徐闻县| 桂阳县| 酒泉市| 马山县| 英德市| 张家港市| 宿迁市| 兴安盟| 阳原县| 莎车县| 日照市|