尋夢新聞LINE@每日推播熱門推薦文章,趣聞不漏接❤️
機器之心報導
項目作者:thu-ml
參與:思、肖清
訓練模型的極速,與 1500 行程式的精簡,清華大學新開源強化學習平臺「天授」。值得注意的是,該項目的兩位主要作者目前都是清華大學的本科生。
是否你也有這樣的感覺,成熟 ML 工具的源碼很難懂,各種繼承與處理關係需要花很多時間一點點理清。在清華大學開源的「天授」項目中,它以極簡的代碼實現了很多極速的強化學習算法。重點是,天授框架的源碼很容易懂,不會有太龐雜的邏輯關係。
項目地址:https://github.com/thu-ml/tianshou
天授(Tianshou)是純 基於 PyTorch 代碼的強化學習框架,與目前現有基於 TensorFlow 的強化學習庫不同,天授的類繼承並不龐雜,API 也不是很繁瑣。最重要的是,天授的訓練速度非常快,我們試用 Pythonic 的 API 就能快速構建與訓練 RL 智能體。
目前天授支持的 RL 算法有如下幾種:
-
Policy Gradient (PG)
-
Deep Q-Network (DQN)
-
Double DQN (DDQN) with n-step returns
-
Advantage Actor-Critic (A2C)
-
Deep Deterministic Policy Gradient (DDPG)
-
Proximal Policy Optimization (PPO)
-
Twin Delayed DDPG (TD3)
-
Soft Actor-Critic (SAC)
另外,對於以上代碼天授還支持並行收集樣本,並且所有算法均統一改寫為基於 replay-buffer 的形式。
速度與輕量:「天授」的靈魂
天授旨在提供一個高速、輕量化的 RL 開源平臺。下圖為天授與各大知名 RL 開源平臺在 CartPole 與 Pendulum 環境下的速度對比。所有代碼均在配置為 i7-8750H + GTX1060 的同一臺筆記本電腦長進行測試。值得注意的是,天授實現的 VPG(vanilla policy gradient)算法在 CartPole-v0 任務中,訓練用時僅為 3 秒。
以上測試使用了 10 個不同的 seed。CartPole 和 Pendulum 任務中的累積獎賞閾值分別設置為 195.0 與-250.0。可能會有讀者感覺這兩個任務比較簡單,不太能突出框架的優勢。該項目也表示,在這幾天內,他們會更新天授在 Atari Pong / Mujoco 任務上的性能。
天授,只需 1500 行代碼
非常令人驚訝的是,天授平臺整體代碼量不到 1500 行,其實現的 RL 算法大多數都少於百行代碼。單從數量上來說,這樣的代碼量已經非常精簡了,各種類與函數之間的關係應該也容易把握住。
項目表示,天授雖然代碼量少,但可讀性並不會有損失。我們可以快速瀏覽整個框架,並理解運行的流程與策略到底是什麼樣的。該項目提供了很多靈活的 API,例如可以便捷地使用如下代碼令策略與環境交互 n 步:
result=collector.collect(n_step=n)
或者,如果你想通過采樣的批量數據訓練給定的策略,可以這樣寫:
result=policy.learn(collector.sample(batch_size))
正是通過大量精簡的 API 構造 RL 模型,天授才能保持在 1500 行代碼內。例如我們可以看看 DQN 的模型代碼,它是非常流行的一種強化學習模型,在天授內部,DQN 模型真的只用了 99 行代碼就完成了。當然,這 99 行代碼是不包含其它公用代碼塊的。
如下為 DQN 的主要代碼結構,我們省略了部分具體代碼,各個 RL 策略都會繼承基本類的結構,然後重寫就夠了。可以發現,在常規地定義好模型後,傳入這個類就能創建策略。DQN 策略的各種操作都會寫在一起,後續配置 Collector 後就能直接訓練。
項目作者把所有策略算法都模塊化為 4 部分:
-
__init__:初始化策略
-
process_fn:從 replay buffer 中處理數據
-
__call__:給定環境觀察結果計算對應行動
-
learn:給定批量數據學習策略
實際體驗
天授很容易安裝,直接運行「pip install tianshou」就可以。下面我們將該項目克隆到本地,實際測試一下。
!gitclonehttps://github.com/thu-ml/tianshou!pip3installtianshouimportosos.chdir('tianshou')
該項目在 test 文件夾下提供了諸多算法的測試範例,下面我們在 CartPole 任務下逐個測試一番。
!pythontest/discrete/test_pg.py
!pythontest/discrete/test_ppo.py
!pythontest/discrete/test_a2c.py
!pythontest/discrete/test_dqn.py
以上分別為 VPG、PPO、A2C 與 DQN 在 P100 GPU 上的訓練結果。可以看到,我們的測試結果與項目提供的結果出入不大。
由於 CartPole 任務在強化學習中相對簡單,相當於圖像識別中的 MNIST。為更進一步測試該 RL 框架的性能,我們也在 MinitaurBulletEnv-v0 任務中對其進行了測試。
Minitaur 是 PyBullet 環境中一個四足機器人運動控制任務,其觀測值為該機器人的位置、姿態等 28 個狀態資訊,控制輸入為電機的轉矩(每條腿 2 個電機,總共 8 個電機),策略優化的目標為最大化機器人移動速度的同時最小化能量消耗。也就是說,agent 需要根據獎賞值自主地學習到由 28 個狀態資訊到 8 個控制輸入的映射關係。
使用 SAC 算法在 Minitaur 任務中的訓練結果如下圖所示:
需要注意的是,天授的 SAC 實現在 Minitaur 任務中僅訓練了不到 200k 步即能獲得以上控制策略,效果可以說是很不錯的。
項目作者,清華本科生
在 GitHub 中,其展示了該項目的主要作者是 Jiayi Weng 與 Minghao Zhang,他們都是清華的本科生。其中 Jiayi Weng 今年 6 月份本科畢業,在此之前作為本科研究者與清華大學蘇航、朱軍等老師開展強化學習領域的相幹研究。Minghao Zhang 目前是清華大學軟體學院的本科二年級學生,同時還修了數學專業。
作為本科生,該項目的兩位作者已經有了非常豐富的研究經驗,Jiayi Weng 去年夏季就作為拜訪學生到訪 MILA 實驗室,並與 Yoshua Bengio 開展了關於意識先驗相幹的研究。在 Jiayi Weng 的主頁中,我們可以看到在本科期間已經發了 IJCAI 的 Oral 論文。
Minghao Zhang 也有豐富的研究經驗,之前他在軟體學院 iMoon Lab 做關於 3D 視覺相幹的研究,而後目前在清華交叉資訊學院做研究助理,從事強化學習方面的研究。盡管離畢業還有不短的時間,Minghao Zhang 已經做出了自己的研究成果。
所以綜合來看,因為在本科已經有了豐富的科研經驗,並且做過多個項目,那麼在這個階段能做一個非常不錯的強化學習開源項目也就理所當然了。
接下來的工作
天授目前還處於初期開發階段,尚有一些未實現的功能或有待完善的地方。項目作者表示今後主要在以下幾個方面來完善該 RL 框架:
-
Prioritized replay buffer
-
RNN support
-
Imitation Learning
-
Multi-agent
-
Distributed training
它們分別是提供更多 RL 環境的 benchmark、優先經驗回放、循環神經網路支持、模仿學習、多智能體學習以及分布式訓練。
本文為機器之心報導,轉載請聯繫本公眾號獲得授權。
✄————————————————
加入機器之心(全職記者 / 實習生):[email protected]
投稿或尋求報導:content@jiqizhixin.com
廣告 & 商務合作:[email protected]