A Coding Tutorial on OpenMythos on Recurrent-Depth Transformers with Depth Extrapolation, Adaptive Computation, and Mixture-of-Experts Routing

Asif Razzaq

View Original ↗
AI 導讀 technology AI 重要性 3/5

訓練 3 個迴圈、推理跑 16 次準確率仍持續提升——OpenMythos 以可執行程式碼示範深度外推 Transformer 的四大核心機制

  • MLA 注意力透過低秩壓縮讓 KV-cache 大幅小於 GQA,長上下文推理尤其省記憶體
  • 遞迴矩陣 A 的特徵值被約束在 (0,1),即使 lr=1.0 極端訓練仍能保持跨深度穩定
  • ACT 機制讓每個位置自主決定迭代輪數,避免為簡單 token 浪費多餘的遞迴算力

把模型只訓練在 3 個迴圈,推理時跑到 16 次準確率仍持續爬升——OpenMythos 的「深度外推」特性顛覆了「要更準就得加更多參數」的直覺。這份教學以可執行程式碼實作 Claude Mythos 架構的開源版本,驗證其 GQA/MLA 記憶體效率、遞迴穩定性、ACT 停機與 MoE 路由四大機制,是目前少數能把「以推理算力換準確率」這個想法做成可量測基準的公開實作。

GQA 對比 MLA:同等推理效果、KV-cache 佔用差距懸殊

OpenMythos 支援兩種注意力機制,由 attn_type 參數切換。GQA(分組查詢注意力,Grouped Query Attention) 讓多個查詢頭共用少數幾個 Key-Value 頭,本例設定查詢頭 4 個、KV 頭僅 2 個,以此壓縮快取大小。MLA(多頭潛在注意力,Multi-Head Latent Attention) 則是 DeepSeek 提出的更激進方案:把 Key 和 Value 壓縮為低秩潛在向量(kv_lora_rank=32, q_lora_rank=64),計算注意力時再動態還原,從架構上消除了 KV-cache 的主要開銷。

教學對兩個模型各執行一次序列長度 64、迴圈深度 4 的前向傳遞,逐一加總 kv_cache 字典內所有張量的位元組數。輸出直接列出兩者的快取大小(KB)與倍數比值,MLA 明顯佔優。對長上下文或大批次推理場景,這個差距直接轉換為可容納的序列長度或並發量——同樣顯存下,MLA 能服務更多請求。

MLA 雖多了幾個低秩投影矩陣,但因矩陣秩遠低於原始維度,整體參數量增幅有限。兩種機制的取捨清晰:GQA 架構精簡、易理解;MLA 推理時記憶體效率更優,長序列場景競爭力更強。

遞迴更新矩陣 A 的特徵值必須落在 (0, 1) 才能跨深度穩定

OpenMythos 的核心是一個迴圈遞迴模組,每次迴圈用同一組共享參數更新隱藏狀態,類似 RNN 的時間展開。穩定性的關鍵在更新矩陣 A 的特徵值(此處為對角元素):若有元素 ≥ 1,遞迴發散;若有元素 ≤ 0,則出現震盪;唯有全數落在開區間 (0, 1) 才能讓跨深度推理保持穩定。

教學用 show_stability() 函式讀取 model.recurrent.injection.get_A(),輸出 A 的最小值、最大值、均值與穩定性布林判斷。初始化後兩個模型均通過。為壓力測試,教學刻意用 lr=1.0(正常訓練的 3000 倍以上)對 MLA 做 30 步極端訓練,隨後再度查驗——特徵值依然全數合規。

這說明 OpenMythos 對 A 實施了硬約束,推測是 sigmoid 參數化讓輸出天然落在 (0, 1),而非讓 A 自由更新。對工程師而言意義重大:訓練迴圈深度和推理迴圈深度不同,若缺少這層保障,訓練 T=3 後在 T=16 推理很可能直接發散崩潰。

用奇偶校驗任務實測:訓練 T=3,推理 T=16 準確率仍持續攀升

教學選用「累積奇偶校驗(cumulative parity)」作示範:輸入序列中每個元素是 1 或 2(對應 bit 0/1),模型需在每個位置預測到目前為止 bit 的累積奇偶性。這個任務具有天然遞推結構,必須記住前面每一步,迴圈越多理論上狀態追蹤越準確。

訓練配置固定 n_loops=3,跑 600 步,batch size 64,序列長 24,AdamW 優化器,學習率 3e-4。訓練完成後,在 512 個測試樣本上掃描 T=[1, 2, 3, 4, 6, 8, 10, 12, 14, 16] 各自的準確率,輸出格式附帶條形圖與「← trained here」標記,一眼可見訓練基準點在哪裡。

結果呈現教科書式的外推曲線:T=1、T=2 迴圈不足,準確率明顯低落;T=3 達到訓練時的水準;T 繼續加大,準確率持續爬升到 T=16 的峰值,整個過程參數一個也沒動。這個現象在奇偶校驗任務上尤其顯著,因結構天然適合遞推;能否推廣到語言建模等更複雜任務,仍是開放問題。

ACT 自適應停機:讓每個位置按難度決定要算幾輪

傳統 Transformer 對所有 token 施以相同計算深度,簡單的功能詞和複雜的多義詞花費一樣算力。ACT(自適應計算時間,Adaptive Computation Time) 為每個位置在每個迴圈輸出一個停機概率,當累積值超過閾值(本例 act_threshold=0.99)時,該位置提前停止後續迭代,理論上讓簡單 token 早早退出。

教學以 monkey-patch 方式掛 hook 在 model.recurrent.act.forward,收集 16 次迴圈的停機概率張量後,輸出形狀 (loops, positions) 的矩陣,並以 viridis 熱力圖視覺化(x 軸序列位置、y 軸迴圈次數、顏色代表停機概率)。程式同時印出每個迴圈的平均停機概率,方便快速判斷是否出現「全部 token 都跑滿迴圈」的退化情況。

ACT 的實際效益取決於任務難度分布:若大多數 token 在前 3-4 輪就達到停機閾值,跑 16 輪的真實計算量遠小於表面的 16 倍,效率大幅提升。ACT 也為深度外推提供了額外保障——即便推理時給的迴圈數比訓練時多,簡單位置不會白跑多餘的輪次。

MoE 路由均衡性驗證與三種迴圈深度的生成對比

MoE(混合專家,Mixture of Experts) FFN 層讓每個 token 只激活少數幾個子網路(本例 n_experts=4, n_experts_per_tok=2),其餘專家不參與計算,在不增加推理成本的前提下擴大模型有效容量。然而 MoE 有個常見陷阱:router 可能發生「贏家通吃」,少數專家承攬絕大多數 token,其餘退化為無效死神經元。

教學 hook 追蹤每次 MoE 前向傳遞的 topk 索引,統計在 32 個 batch、T=3 推理中各 expert 被選中的比例。均衡的路由應讓四個 expert 各佔約 25%,偏差過大是訓練不穩定的信號。輸出格式「expert X: Y% of topk slots」讓工程師一眼判斷路由健康度。

生成實驗以長度 8 的奇偶序列作 prompt,用 T_gen=[1, 4, 12] 各自生成後續 8 個 token,溫度 0.1、top-k=2,直接比較三種深度的輸出序列是否更貼合累積奇偶規律。教學最後把訓練 loss 曲線、深度外推準確率(紅色虛線標記 T=3)與 ACT 停機熱力圖合併為一張三面板圖輸出,提供整合式的視覺化摘要。

訓練深度固定 3 個迴圈,推理增到 16 次準確率仍持續爬升——OpenMythos 把「以推理算力換準確率」做成可執行、可量測的程式碼基準,是值得追蹤的計算自適應推理新方向。

Abstract

In this tutorial, we explore the implementation of OpenMythos, a theoretical reconstruction of the Claude Mythos architecture that enables deeper reasoning through iterative computation rather than increased parameter size. We build and analyze models using both GQA and MLA attention mechanisms, examine memory efficiency through KV-cache comparisons, and validate stability via the spectral properties of […] The post A Coding Tutorial on OpenMythos on Recurrent-Depth Transformers with Depth Extrapolation, Adaptive Computation, and Mixture-of-Experts Routing appeared first on MarkTechPost.