Constraint-based Pre-training: From Structured Constraints to Scalable Model Initialization
WeiT 將知識與尺寸解耦,只需不到一分鐘微調,即可完成任意大小模型的高效初始化。
- 常規預訓練模型受限於固定尺寸,更改大小常面臨極高重訓成本。
- WeiT 透過克羅內克約束將通用知識封裝於模板,再由輕量縮放器重組參數。
- 只需幾百步梯度優化,新尺寸模型便能無縫繼承原有知識,實現高效部署。
常規的預訓練模型多綁定單一尺寸,若要適應邊緣硬體限制,往往得耗費數百小時重新訓練。東南大學團隊提出的 WeiT 架構,僅需短短幾分鐘與幾百步梯度更新,就能讓新尺寸模型繼承原始知識,打破重新蒸餾的算力魔咒。
ViT-B等固定配置的侷限與模型轉換的昂貴成本
微調預訓練模型已成為機器學習的主流路徑,特別在數據稀缺場景下,從零訓練像是 ViT (Vision Transformer) 等現代架構幾乎不可行。然而,現實環境的邊緣設備部署受到記憶體、運算資源與延遲的嚴格限制,迫使開發者必須採用各種不同大小的模型。市面上的開源模型大多只提供有限的配置,例如常見的 ViT-B (12層深)。這導致不符合預設尺寸的目標模型陷入必須重新大規模預訓練的困境。過去的解決方案如層級剪枝會直接破壞模型原有的結構化知識。而知識蒸餾雖較靈活,卻必須為每一種新尺寸重新進行繁重的運算,產生龐大的計算開銷與時間成本。
WeiT運用克羅內克約束分離模板與輕量縮放器
為了讓與尺寸無關的知識能夠被抽離並封裝,研究團隊提出了基於約束的預訓練 (Constraint-based Pre-training) 範式。這項技術將可變尺寸模型的初始化視為多任務適應問題,並據此開發出 WeiT 演算法。該方法將原本複雜的神經網路參數拼接成單一矩陣,並施加 Kronecker-based (克羅內克) 約束進行正則化。WeiT 將權重拆解為兩個部分:編碼了跨尺寸通用知識的權重模板 (Weight Templates),以及負責拼接聚合的輕量級權重縮放器 (Weight Scalers)。研究團隊同時引入低秩瓶頸,迫使模板在深度與寬度上最大程度地共享,實現了跨規模的高效知識傳遞。
模板縮放機制透過幾百步優化完成目標尺寸適應
早期針對多尺寸初始化的研究通常只支援深度擴展,在調整寬度時往往顯得捉襟見肘。WeiT 為了補足此缺陷,特別引入了模板縮放機制 (Template Scaling Mechanism)。該機制在預訓練階段對權重模板應用結構化的 Dropout 技術,讓模板在每次前向傳播時隨機調整有效寬度。這項作法避免了模型過度擬合於單一維度,讓其能靈活應對下游目標模型變窄或變寬的複雜需求。進入實際初始化階段時,工程師只要凍結這些充滿結構化知識的模板,並針對新模型實例化輕量級縮放器即可。由於負責適應尺寸的縮放器僅包含幾千個參數,系統只需使用極少量數據進行約 0.16 個 Epoch(數百步的梯度更新),短短不到一分鐘即可收斂完畢,達成了近乎零成本的參數高效初始化。
於DeiT分類與DiT擴散模型中展現SOTA性能
在多樣化的視覺基準測試中,WeiT 展現出極強的尺寸適應能力。研究團隊首先以 DeiT-B (12層深、12個注意力頭) 進行預訓練,隨後將其擴展與縮小至不同的架構尺寸。實驗數據證實,相比於傳統啟發式方法,WeiT 在初始化後成功保留了更高的表徵保真度。而在圖像生成 (Image Generation) 領域,團隊選用 DiT-L (Diffusion Transformers) 擴散模型作為骨幹進行嚴格測試。由於擴散模型對權重初始化極度敏感,以往方法在調整寬度時經常破壞原有的層級去噪結構。WeiT 克服了這項缺陷,在跨尺寸轉換時實現了顯著更低的 FID (Fréchet Inception Distance,用以衡量生成影像品質) 分數。即使面臨與訓練集差異極大的下游任務,其生成品質甚至超越了耗時費力的全參數微調。
突破架構限制並在具身控制任務提升數據效率
這套基於約束的預訓練範式不僅適用於 Transformer 架構,也能無縫延伸至卷積神經網路。團隊將 WeiT 應用於現代卷積架構 ConvNeXt-v2 上,證明了克羅內克約束同樣能捕捉卷積參數中的可轉移先驗知識。此外,在重視動態穩定性的具身控制 (Embodied Control) 任務中,WeiT 也展現了處理複雜強化學習環境的潛力。測試顯示,由 WeiT 初始化的小型策略網路,即使僅有 1 層深與 1 個注意力頭,也能在容量受限狀況下獲取具競爭力的累積獎勵。這種結構化初始化降低了對機器人形態變化的敏感度,在未知地形中實現了高數據效率的穩定學習。
WeiT 透過約束預訓練將知識與尺寸解耦,僅需分鐘級微調即可完成新模型初始化,大幅削減邊緣部署的算力成本。