A Coding Tutorial on OpenMythos on Recurrent-Depth Transformers with Depth Extrapolation, Adaptive Computation, and Mixture-of-Experts Routing
訓練 3 個迴圈、推理跑 16 次準確率仍持續提升——OpenMythos 以可執行程式碼示範深度外推 Transformer 的四大核心機制
- MLA 注意力透過低秩壓縮讓 KV-cache 大幅小於 GQA,長上下文推理尤其省記憶體
- 遞迴矩陣 A 的特徵值被約束在 (0,1),即使 lr=1.0 極端訓練仍能保持跨深度穩定
- ACT 機制讓每個位置自主決定迭代輪數,避免為簡單 token 浪費多餘的遞迴算力
把模型只訓練在 3 個迴圈,推理時跑到 16 次準確率仍持續爬升——OpenMythos 的「深度外推」特性顛覆了「要更準就得加更多參數」的直覺。這份教學以可執行程式碼實作 Claude Mythos 架構的開源版本,驗證其 GQA/MLA 記憶體效率、遞迴穩定性、ACT 停機與 MoE 路由四大機制,是目前少數能把「以推理算力換準確率」這個想法做成可量測基準的公開實作。
GQA 對比 MLA:同等推理效果、KV-cache 佔用差距懸殊
OpenMythos 支援兩種注意力機制,由 attn_type 參數切換。GQA(分組查詢注意力,Grouped Query Attention) 讓多個查詢頭共用少數幾個 Key-Value 頭,本例設定查詢頭 4 個、KV 頭僅 2 個,以此壓縮快取大小。MLA(多頭潛在注意力,Multi-Head Latent Attention) 則是 DeepSeek 提出的更激進方案:把 Key 和 Value 壓縮為低秩潛在向量(kv_lora_rank=32, q_lora_rank=64),計算注意力時再動態還原,從架構上消除了 KV-cache 的主要開銷。
教學對兩個模型各執行一次序列長度 64、迴圈深度 4 的前向傳遞,逐一加總 kv_cache 字典內所有張量的位元組數。輸出直接列出兩者的快取大小(KB)與倍數比值,MLA 明顯佔優。對長上下文或大批次推理場景,這個差距直接轉換為可容納的序列長度或並發量——同樣顯存下,MLA 能服務更多請求。
MLA 雖多了幾個低秩投影矩陣,但因矩陣秩遠低於原始維度,整體參數量增幅有限。兩種機制的取捨清晰:GQA 架構精簡、易理解;MLA 推理時記憶體效率更優,長序列場景競爭力更強。
遞迴更新矩陣 A 的特徵值必須落在 (0, 1) 才能跨深度穩定
OpenMythos 的核心是一個迴圈遞迴模組,每次迴圈用同一組共享參數更新隱藏狀態,類似 RNN 的時間展開。穩定性的關鍵在更新矩陣 A 的特徵值(此處為對角元素):若有元素 ≥ 1,遞迴發散;若有元素 ≤ 0,則出現震盪;唯有全數落在開區間 (0, 1) 才能讓跨深度推理保持穩定。
教學用 show_stability() 函式讀取 model.recurrent.injection.get_A(),輸出 A 的最小值、最大值、均值與穩定性布林判斷。初始化後兩個模型均通過。為壓力測試,教學刻意用 lr=1.0(正常訓練的 3000 倍以上)對 MLA 做 30 步極端訓練,隨後再度查驗——特徵值依然全數合規。
這說明 OpenMythos 對 A 實施了硬約束,推測是 sigmoid 參數化讓輸出天然落在 (0, 1),而非讓 A 自由更新。對工程師而言意義重大:訓練迴圈深度和推理迴圈深度不同,若缺少這層保障,訓練 T=3 後在 T=16 推理很可能直接發散崩潰。
用奇偶校驗任務實測:訓練 T=3,推理 T=16 準確率仍持續攀升
教學選用「累積奇偶校驗(cumulative parity)」作示範:輸入序列中每個元素是 1 或 2(對應 bit 0/1),模型需在每個位置預測到目前為止 bit 的累積奇偶性。這個任務具有天然遞推結構,必須記住前面每一步,迴圈越多理論上狀態追蹤越準確。
訓練配置固定 n_loops=3,跑 600 步,batch size 64,序列長 24,AdamW 優化器,學習率 3e-4。訓練完成後,在 512 個測試樣本上掃描 T=[1, 2, 3, 4, 6, 8, 10, 12, 14, 16] 各自的準確率,輸出格式附帶條形圖與「← trained here」標記,一眼可見訓練基準點在哪裡。
結果呈現教科書式的外推曲線:T=1、T=2 迴圈不足,準確率明顯低落;T=3 達到訓練時的水準;T 繼續加大,準確率持續爬升到 T=16 的峰值,整個過程參數一個也沒動。這個現象在奇偶校驗任務上尤其顯著,因結構天然適合遞推;能否推廣到語言建模等更複雜任務,仍是開放問題。
ACT 自適應停機:讓每個位置按難度決定要算幾輪
傳統 Transformer 對所有 token 施以相同計算深度,簡單的功能詞和複雜的多義詞花費一樣算力。ACT(自適應計算時間,Adaptive Computation Time) 為每個位置在每個迴圈輸出一個停機概率,當累積值超過閾值(本例 act_threshold=0.99)時,該位置提前停止後續迭代,理論上讓簡單 token 早早退出。
教學以 monkey-patch 方式掛 hook 在 model.recurrent.act.forward,收集 16 次迴圈的停機概率張量後,輸出形狀 (loops, positions) 的矩陣,並以 viridis 熱力圖視覺化(x 軸序列位置、y 軸迴圈次數、顏色代表停機概率)。程式同時印出每個迴圈的平均停機概率,方便快速判斷是否出現「全部 token 都跑滿迴圈」的退化情況。
ACT 的實際效益取決於任務難度分布:若大多數 token 在前 3-4 輪就達到停機閾值,跑 16 輪的真實計算量遠小於表面的 16 倍,效率大幅提升。ACT 也為深度外推提供了額外保障——即便推理時給的迴圈數比訓練時多,簡單位置不會白跑多餘的輪次。
MoE 路由均衡性驗證與三種迴圈深度的生成對比
MoE(混合專家,Mixture of Experts) FFN 層讓每個 token 只激活少數幾個子網路(本例 n_experts=4, n_experts_per_tok=2),其餘專家不參與計算,在不增加推理成本的前提下擴大模型有效容量。然而 MoE 有個常見陷阱:router 可能發生「贏家通吃」,少數專家承攬絕大多數 token,其餘退化為無效死神經元。
教學 hook 追蹤每次 MoE 前向傳遞的 topk 索引,統計在 32 個 batch、T=3 推理中各 expert 被選中的比例。均衡的路由應讓四個 expert 各佔約 25%,偏差過大是訓練不穩定的信號。輸出格式「expert X: Y% of topk slots」讓工程師一眼判斷路由健康度。
生成實驗以長度 8 的奇偶序列作 prompt,用 T_gen=[1, 4, 12] 各自生成後續 8 個 token,溫度 0.1、top-k=2,直接比較三種深度的輸出序列是否更貼合累積奇偶規律。教學最後把訓練 loss 曲線、深度外推準確率(紅色虛線標記 T=3)與 ACT 停機熱力圖合併為一張三面板圖輸出,提供整合式的視覺化摘要。
訓練深度固定 3 個迴圈,推理增到 16 次準確率仍持續爬升——OpenMythos 把「以推理算力換準確率」做成可執行、可量測的程式碼基準,是值得追蹤的計算自適應推理新方向。