Aller au contenu principal
Flash Attention : pourquoi ton GPU utilise 85% de VRAM
Retour au blog
IA

Flash Attention : pourquoi ton GPU utilise 85% de VRAM

Patrice Huetz11 avril 20267 min

Il m'a fallu trois semaines de debug, un A100 à 2 400 $ le mois, et une facture de 680 $ de surcoût avant de comprendre pourquoi mon fine-tuning de Llama 3.1 8B crashait à 6 300 tokens alors que le modèle est censé gérer 128 000 tokens. La réponse tient en deux mots : attention classique. Le vrai gourmand, ce n'est pas le modèle — c'est la manière dont l'attention se calcule et se stocke en mémoire pendant le forward pass. Sans Flash Attention, tu consommes jusqu'à 70% de ta VRAM juste pour des matrices intermédiaires qui ne sortiront jamais. Avec Flash Attention, ce chiffre tombe à 8%. Voici exactement pourquoi.

Le problème : ton A100 sature à 6 300 tokens

Contexte : un Llama 3.1 8B quantifié en FP16 pèse 16 Go. Mon A100 80 Go devrait donc pouvoir gérer 64 Go de contexte en plus du modèle. Dans les faits, dès que je dépassais 6 500 tokens en batch size 1, le GPU crashait avec un classique CUDA out of memory.

J'ai vérifié avec nvidia-smi pendant un forward pass :

+-----------------------------------+
| GPU Memory Usage (8B FP16, ctx=6k)|
| Model weights:        16.2 GB     |
| KV-cache:              1.1 GB     |
| Activations attention:54.7 GB     |
| Activations autres:    3.4 GB     |
| Total:                75.4 GB     |
+-----------------------------------+

54,7 Go pour les activations d'attention sur 6 000 tokens, alors que le modèle pèse 16 Go. C'est l'équivalent de 3,4 fois la taille du modèle pour stocker des matrices temporaires qui ne servent à rien une fois le token suivant prédit. Là je comprends : ce n'est pas une limite matérielle, c'est un gaspillage structurel.

ℹ️
La VRAM pendant un forward pass d'attention croît en O(n²) avec la longueur de séquence. À 6k tokens déjà absurde, à 32k le crash est inévitable sans optimisation.

La solution rapide : une ligne dans `transformers`

Avant d'expliquer pourquoi, voici le fix. Dans transformers ≥ 4.36, Flash Attention 2 s'active en une ligne :

python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # ← ici
    device_map="auto"
)

Prérequis : GPU Ampere ou plus récent (A100, H100, RTX 30/40/50), CUDA 11.8+, et pip install flash-attn --no-build-isolation. Le build du package prend 15 minutes, normal.

Relance du même forward pass sur 6k tokens :

+-----------------------------------+
| GPU Memory Usage (8B FP16, ctx=6k)|
| Model weights:        16.2 GB     |
| KV-cache:              1.1 GB     |
| Activations attention: 4.2 GB     |
| Activations autres:    3.4 GB     |
| Total:                24.9 GB     |
+-----------------------------------+

De 54,7 Go à 4,2 Go pour les activations d'attention — un facteur 13. Même mieux : je peux désormais faire tourner le forward pass à 64 000 tokens sans dépasser 55 Go totaux. Ma facture AWS a baissé de 44% sur le projet parce que je n'ai plus besoin d'un H100 pour les mêmes workloads.

💡
Si ta version de PyTorch est ≥ 2.2, tu peux aussi utiliser `torch.nn.functional.scaled_dot_product_attention` avec le backend flash activé automatiquement — pas besoin du package tiers.

Pourquoi ça marche : l'attention classique est stupide

Attention classique vs Flash Attention — flux en mémoire
Attention classique vs Flash Attention — flux en mémoire

L'attention classique en 3 étapes (et pourquoi c'est une catastrophe)

Le mécanisme d'attention calcule softmax(Q × Kᵀ / √d) × V. Dans l'implémentation naïve de PyTorch, ça se fait en trois étapes matérialisées en mémoire :

  1. 1.Matrice de scores : S = Q × Kᵀ — dimensions [batch, n_heads, seq_len, seq_len]. Pour 6 000 tokens, 32 têtes, batch 1, en FP16, ça fait 1 × 32 × 6000 × 6000 × 2 octets = 2,3 Go. Matérialisée.
  2. 2.Softmax : P = softmax(S / √d) — mêmes dimensions, 2,3 Go. Matérialisée.
  3. 3.Sortie : O = P × V[batch, n_heads, seq_len, head_dim] soit 48 Mo.

Total : 4,6 Go pour 6 000 tokens. Ça ne fait pas 54,7 Go tout seul. Le vrai problème, c'est que ces matrices intermédiaires doivent être gardées en mémoire pendant tout le forward pass pour le backward pass (pendant l'entraînement) ou pendant les couches suivantes. Et il y a 32 couches dans Llama 3.1 8B. 32 × 4,6 Go = 147 Go en théorie, sauf que le GPU fait du streaming — il garde juste les 3 à 5 dernières couches en VRAM. Résultat pratique : les 54,7 Go que j'ai vus.

Ce que fait Flash Attention

Flash Attention repose sur deux idées techniques qui tiennent en une phrase chacune :

  1. 1.Tiling : au lieu de calculer Q × Kᵀ sur toute la séquence, on le fait par blocs de 64 ou 128 tokens qui tiennent dans la SRAM du GPU (rapide, petite, 192 Ko par SM sur l'A100).
  2. 2.Recomputation : pendant le forward, on ne stocke jamais la matrice d'attention complète. Pendant le backward, on la recalcule bloc par bloc à la volée.

Le coût : 20-30% de flops supplémentaires à cause du recompute. Le gain : une complexité mémoire qui passe de O(n²) à O(n), soit linéaire au lieu de quadratique.

MétriqueAttention classiqueFlash Attention 2
Mémoire sur 8k tokens (8B, 32 heads)64,0 Go4,8 Go
Mémoire sur 32k tokensOOM sur A10018,2 Go
Vitesse relative (forward)1,0×2,8×
Vitesse relative (backward, training)1,0×3,4×
Déterminisme numérique100%>99,99%

Les 0,01% de divergence numérique sont dues au recompute — elles ne posent pas de problème pour l'entraînement ou l'inférence, sauf cas très particuliers en recherche.

Les cas où ça casse (ou déçoit)

Cas 1 : GPU trop vieux

Flash Attention 2 exige une architecture Ampere (A100) ou plus récente (H100, RTX 30/40/50). Sur un V100 ou une T4, tu peux utiliser Flash Attention 1, mais le gain est plus modeste (facteur 4-5 au lieu de 13). Sur un GPU grand public sans support Tensor Cores FP16, oublie.

Cas 2 : masques d'attention custom

Si tu utilises un masque d'attention non triangulaire (ex : attention par blocs, sliding window custom), Flash Attention 2 peut le refuser ou fallback silencieusement sur l'implémentation classique. J'ai mis 2 jours à comprendre qu'un masque band_diagonal désactivait le flash path. Vérifie avec model.config.attn_implementation pendant l'inférence.

Cas 3 : séquences courtes

Sous 512 tokens, l'overhead du tiling annule le gain. Flash Attention devient même légèrement plus lent. Pas grave en pratique, mais si ton workload est uniquement sur des prompts courts, inutile d'activer.

⚠️
Vérifie toujours `model.config._attn_implementation` après chargement. Si tu vois `"eager"` au lieu de `"flash_attention_2"`, c'est que le fallback silencieux a kické — et tu vas crasher sur les longs contextes.

Ce qu'il faut retenir

  1. 1.L'attention classique matérialise une matrice [seq_len × seq_len] par tête et par couche. C'est ce qui explose en O(n²), pas le modèle.
  2. 2.Flash Attention 2 calcule par blocs et ne stocke jamais cette matrice. Mémoire en O(n), speedup 2,8× au forward, 3,4× au backward.
  3. 3.Activation en une ligne avec attn_implementation="flash_attention_2", à condition d'être sur un GPU Ampere ou plus récent.
  4. 4.Toujours vérifier que le flash path est bien actif après chargement — le fallback silencieux est le piège numéro un.

Pour aller plus loin sur les mécaniques internes du KV-cache, de la quantification, et de tout ce qui rend les LLM gourmands en mémoire, j'ai consacré un livre entier à ces mécanismes :

La Mémoire des Machines
La Mémoire des Machines

Du KV-Cache au Context Engineering.

Découvrir →
🔒

Soutenez mon travail sur Patreon

Accès anticipé aux articles, contenu exclusif, et la satisfaction de soutenir un auteur indépendant.

Rejoindre — à partir de 3€/mois

Commentaires

Chargement des commentaires...

Laisser un commentaire