Dispatch-Aware Ragged Attention for Pruned Vision Transformers
ViT 剪枝後運算量減少 96%,延遲卻未改善,研究揭露 62 微秒的 API 調度開銷才是真正瓶頸。
- FlashAttention-2 在短序列運算面臨 62 微秒的調度開銷極限,導致 ViT 剪枝省下的算力無法反映在推論時間上。
- 基於 Triton 開發的精簡注意力核心去除了 Python 參數驗證與 C++ 綁定,將調度地板降至 40 微秒。
- 結合打包管線,該架構能在 25% 剪枝下達到每秒 4162 張影像的吞吐量,比未剪枝模型更快且準確度幾乎無損。
視覺變換器(Vision Transformer)在剪枝掉 80% 的標記(Token)後,雖然注意力機制的理論運算量大幅減少 96%,但若使用目前最先進的 FlashAttention-2 可變長度 API,執行延遲卻僅下降不到 1%。這個極端反直覺的現象,源自於一個隱形的效能瓶頸:在 ViT 常見的短序列(小於 197 個標記)運算中,實際的矩陣運算不到 1 微秒即可完成,但系統的調度開銷(Dispatch Overhead)卻固定吃掉約 62 微秒,讓原先預期的算力節省完全無法反映在真實的推論時間上。
標記剪枝與填充對齊讓 PyTorch 效能倒退
視覺變換器(ViT)已經成為影像分類、物件偵測與分割技術的主流骨幹網路。然而,其核心的「多頭自注意力機制」(Multi-Head Self-Attention, MHSA)的運算量會隨著標記(Token)數量的平方增長。為了減少運算成本,學界發展出豐富的「標記剪枝(Token Pruning)」方法(例如 DynamicViT、EViT、ATS 等),透過丟棄影像中缺乏資訊的區塊來降低運算負載。
然而,標準的深度學習框架(如 PyTorch)將批次資料表示為固定形狀的密集張量。當每張影像經過剪枝後剩餘的標記數量不一致時,現有的實作方式必須對每一個序列進行填充(Padding),使其長度對齊批次中最長的存活序列,並套用注意力遮罩。在 GPU 上,這種填充操作完全抵銷了剪枝省下的浮點運算量(FLOPs):硬體依然需要讀寫完整的填充張量,記憶體頻寬被浪費在遮罩位置上,且算術吞吐量被大量乘以零的無效運算消耗殆盡。
實測數據證實了這個困境:在每一種剪枝比例下,使用填充對齊的 PyTorch SDPA(scaled_dot_product_attention)吞吐量僅有未剪枝狀態的 0.62 倍。這意味著在標準框架下,進行剪枝並加上填充,執行速度反而比完全不剪枝還要慢。
FlashAttention-2 面臨的 62 微秒調度天花板
為了解決填充造成的浪費,最自然的作法是將存活的標記壓縮成連續的「參差不齊(Ragged)」緩衝區,並使用支援可變長度的注意力核心來處理。FlashAttention-2(FA2)提供了 flash_attn_varlen_func API,而 PyTorch 原生的 NestedTensor 也提供了類似的路徑。
但實際在 NVIDIA A100 GPU 上的測試卻揭示了一個重大盲區。研究團隊觀察到,在 DeiT-Base 模型上設定 80% 的標記剪枝(每張影像從 197 個標記減少至約 39 個)時,在批次大小(Batch Size)為 32 的情況下,FA2 varlen 的注意力延遲減少了不到 1%。主要原因在於,FA2 varlen 的設計目標是針對大型語言模型(LLM)的長文本場景(4K 到 128K 個標記),在該情境下調度成本相對運算時間微乎其微。
然而,當任務轉移到 ViT 剪枝後的短序列時,整個運算架構變成了「調度開銷受限(Dispatch-overhead bound)」而非運算受限。無論工作負載多小,Python 參數驗證、pybind11 C++ 綁定跨越、輸出張量分配以及 CUDA 啟動等主機端(Host-side)的調度路徑,都會無條件消耗約 62 微秒。這個固定下限徹底掩蓋了剪枝帶來的微秒級運算節省。
繞過綁定:基於 Triton 的 40 微秒雙向核心
為突破此瓶頸,研究團隊採用 OpenAI Triton 開發了一個極簡的雙向參差注意力核心(Bidirectional Ragged Attention Kernel)。這個核心刻意捨棄了 LLM 必備的因果遮罩(Causal Mask)與 KV-Cache 機制,專注於 ViT 推論所需的純雙向注意力運算。
透過 Triton 的 JIT 編譯器,系統能夠生成直接將參數寫入核心的啟動存根(Launch Stub),完全繞過 pybind11 的邊界與 C++ 的參數轉換,同時省去 FA2 必須配置的 softmax_lse 工作空間分配。這項精簡設計成功將調度開銷的地板從 62 微秒壓低至約 40 微秒(降低約 1.55 倍)。
在零剪枝、大批次(BS=64)的極端運算受限情境下,FA2 憑藉高度手工最佳化的 CUDA 核心(113 微秒)仍勝過 Triton 編譯器生成的程式碼(207 微秒)。但在剪枝後的小工作負載區間內,這 22 微秒的「調度開銷落差」成為決定整體效能的唯一關鍵,讓 Triton 核心在 ViT 推論中取得壓倒性優勢。
整合打包管線實現 2.24 倍端到端吞吐量提升
單純最佳化注意力核心並不夠,團隊進一步建構了完整的「打包–注意力–解包(pack–attend–unpack)」端到端管線。該管線首先計算累積序列長度(cu_seqlens),接著透過 Triton 核心將存活的標記透過記憶體合併(Coalesced)技術複製到連續的扁平張量中,執行注意力機制與 MLP 後,再提取分類標記(CLS Token)。
在 ImageNet-1K 驗證集的測試中,這個 Triton 參差管線展現了優異的擴展性。相較於填充對齊的 PyTorch SDPA,Triton 方案在批次大小 32 時達到了 2.04 倍 的吞吐量,在批次大小 512 時更提升至 2.24 倍。且由於這套架構獨立於剪枝演算法之外,它可以無縫支援 Threshold-ℓ2、DynamicViT、EViT 等多種剪枝方法。
在準確度與吞吐量的取捨邊界(Pareto Front)上,該管線取得了突破性進展:在 DeiT-S 模型上套用 25% 剪枝,整體吞吐量高達每秒 4,162 張影像(Top-1 準確度 81.7%),直接超越了完全未剪枝的 DeiT-S 基準線(每秒 3,780 張影像,準確度 82.2%),且最大的對數機率值(Logit)差異小於 0.007,確保了數值層級的等效性。
對注意力機制運算與模型優化的後續影響
這項研究重新定義了視覺模型優化的評估基準。過去的模型剪枝論文通常將端到端的效能提升,直接歸功於注意力機制浮點運算量(FLOPs)的減少。然而實測數據證明,在現代 GPU 的注意力核心架構下,真正從剪枝中省下執行時間的,其實是與標記數量呈線性相關的 MLP 層運算,而非早已被調度開銷卡死的注意力層。
這對於未來的軟硬體設計提出了明確的指引:在影像處理或其他積極採用標記壓縮技術的領域,序列長度正逐漸縮短。在這些情境下,底層框架開發者(如 PyTorch、FlashAttention 團隊)必須將「最小化可變長度 API 的調度開銷」視為首要任務。同時,未來的剪枝演算法研究,也應當摒棄以「填充對齊基準(Padded Baselines)」進行效能對比的作法,改為在真實的參差執行環境中進行基準測試,才能反映模型在真實世界部署的效能潛力。
邊緣推論與短序列模型的加速關鍵已不再單純是算力的堆疊,而是如何削減極限微秒級的系統調度開銷,這將徹底改變視覺 Transformer 的底層優化方向。