플래시 어텐션: 작동 방식과 중요성

발행: (2026년 6월 10일 PM 08:20 GMT+9)
10 분 소요
원문: Dev.to

Flash Attention: 무엇을 하고 왜 중요한가

당신의 학습 작업은 A100을 시간당 $3에 사용하고 있습니다. 손실이 감소하고, 그래디언트가 흐르고, 모델의 손실 곡선은 교과서적인 로그 형태를 보이고 있습니다. 하지만 스텝 시간을 프로파일링하고 GPU가 실제로 무엇을 하고 있는지 보면 충격적인 사실을 발견합니다: GPU 연산 유닛이 4060% 정도 유휴 상태라는 것입니다. 병목 현상이 연산이 아니라 메모리 대역폭입니다. GPU의 HBM(고대역폭 메모리, A100에서는 1.52 TB/s)은 연산 유닛이 데이터를 소비하려는 속도를 따라가지 못합니다. 그리고 모든 트랜스포머 학습·추론 실행에서 가장 큰 메모리 트래픽은 바로 어텐션 연산이며, 이는 매 포워드 패스마다 전체 N × N 어텐션 행렬을 HBM에 읽고 쓰는 방식으로 이루어집니다.

Flash Attention은 바로 이 문제를 해결하기 위해 등장했습니다: 어텐션 연산을 GPU의 SRAM(칩 내 고속 메모리, A100에서는 약 20 MB) 안에서 완전히 처리하도록 타일링함으로써 중복된 HBM 트래픽을 없앱니다. 그 결과 어텐션에 의해 제한되는 워크로드에서 전체 2~4배의 속도 향상을 달성하면서도 정밀도 손실이 없으며, 모델을 변경할 필요도 없습니다.

표준 셀프 어텐션 레이어

단일 헤드의 셀프 어텐션 레이어는 Q, K, V라는 세 개의 행렬을 사용합니다. 각각의 형태는 (N, d)이며, 여기서 N은 시퀀스 길이, d는 헤드 차원입니다. 가장 단순한 구현은 다음과 같습니다.

Compute S = Q @ K^T          -- shape (N, N)
Compute P = softmax(S, dim=-1) -- shape (N, N)
Compute O = P @ V            -- shape (N, d)

핵심 비용은 SP가 각각 N × N 원소를 차지한다는 점입니다. 예를 들어 시퀀스 길이가 4096이고 d=128인 경우, 헤드당 1,600만 개의 원소가 필요합니다. FP16이면 헤드당 32 MB가 됩니다. 32개의 헤드가 있으면 전체 N × N 행렬은 1 GB에 달하는데, 이는 단일 A100 GPU의 약 20 MB SRAM보다 훨씬 큽니다. 표준 구현은 이 1 GB를 HBM에 쓰고(느리게), softmax를 위해 다시 읽고(HBM 읽기), 결과를 다시 쓰고(HBM 쓰기), 마지막으로 V와 곱하기 위해 또 다시 읽습니다.

Flash Attention은 이 N × N 행렬을 전혀 물리적으로 만들지 않고, SRAM에 들어갈 만큼 작은 블록으로 softmax 연산을 타일링함으로써 이를 회피합니다.

핵심 아이디어

Tri Dao와 스탠포드 그룹(2022)이 제시한 핵심 통찰은 어텐션 연산이 IO‑바운드이며 컴퓨트‑바운드가 아니라는 점입니다. 주요 비용은 HBM과 SRAM 사이의 데이터 이동입니다. A100에서는 SRAM 대역폭이 약 20 TB/s(연산 유닛 → SRAM)인 반면 HBM 대역폭은 약 2 TB/s에 불과합니다. 10배 차이죠. 연산을 SRAM 안에서만 수행하도록 구조화하면 이 차이를 활용할 수 있습니다.

알고리즘은 매우 직관적입니다.

  1. Q, K, V 행렬을 SRAM에 들어갈 만큼 작은 타일로 나눕니다.
  2. 각 타일에 대해 온라인 softmax 알고리즘(점진적으로 업데이트 가능한 안전한 softmax)을 사용해 부분 softmax를 계산합니다.
  3. 부분 결과를 출력에 누적하면서, 각 타일별 재스케일링 통계는 레지스터에 보관합니다.
  4. 레이어당 한 번만 최종 출력을 HBM에 기록합니다(헤드당 여러 번 읽고 쓰는 대신).

이것은 고전적인 타일링 기법이지만, softmax가 전역 정규화(전체 행에 대한 분모)를 필요로 하기 때문에 단순히 타일을 더하는 것만으로는 안 됩니다. 논문의 핵심 알고리즘 기여는 온라인‑세이프 softmax로, 각 타일이 로컬 softmax를 계산한 뒤 새로운 타일이 들어올 때마다 실행 중인 출력을 보정할 수 있게 합니다.

Flash Attention 한 단계 전방 패스 블록의 의사코드

def flash_attention_block(Q_block, K_block, V_block):
    # Q_block: (B_r, d), K_block: (B_c, d), V_block: (B_c, d)
    # B_r와 B_c는 SRAM에 맞게 선택된 타일 크기

    # 실행 중인 최대값과 정규화 분모 초기화
    m = -inf   # 행별 최대값 (수치 안정성을 위해)
    l = 0.0    # exp(x - m) 합계 (정규화 분모)
    O = zeros(B_r, d)

    for each K, V tile:
        S = Q_block @ K_tile.T        # 로컬 어텐션 스코어 (B_r, B_c)
        m_new = max(m, rowmax(S))     # 실행 중인 최대값 업데이트
        l_new = exp(m - m_new) * l + rowsum(exp(S - m_new))
        P = exp(S - m_new) / l_new    # 로컬 softmax
        O = (l * exp(m - m_new) / l_new) * O + P @ V_tile
        m, l = m_new, l_new

    return O

알고리즘은 Q, K, V를 HBM에서 한 번 읽고, 타일 단위로 SRAM에서 처리한 뒤, 최종 O를 HBM에 한 번 기록합니다. 반면 순수 구현은 시퀀스 길이 N에 대해 N × N 어텐션 행렬을 HBM에 읽고 쓰므로 O(N² d) 규모의 HBM 트래픽을 발생시킵니다. Flash Attention은 이를 O(N² d / M)(M은 SRAM 크기)로 감소시켜, SRAM 용량에 비례한 트래픽 감소를 이룹니다.

타일링이 전체 어텐션 행렬을 만들지 않는 과정 (다이어그램)

flowchart TB
    subgraph SRAM["GPU SRAM (~20 MB)"]
        QB[Q 타일\n(B_r x d)]
        KB[K 타일\n(B_c x d)]
        VB[V 타일\n(B_c x d)]
        ST[부분 S = QB @ KB^T\n(B_r x B_c)]
        OT[부분 O 누산기\n(B_r x d)]
    end
    subgraph HBM["GPU HBM (~40-80 GB)"]
        QF[전체 Q\n(N x d)]
        KF[전체 K\n(N x d)]
        VF[전체 V\n(N x d)]
        OF[전체 O\n(N x d)]
    end

    QF -->|한 번 읽음| QB
    KF -->|타일 단위로 읽음| KB
    VF -->|한 번 읽음| VB
    KB --> ST
    VB -->|부분 곱| OT
    OT -->|한 번 기록| OF

    style SRAM fill:#1e293b,stroke:#38bdf8,color:#e2e8f0
    style HBM fill:#0f172a,stroke:#64748b,color:#94a3b8

HBM → SRAM 간의 각 화살표는 느린 DMA 전송을 의미합니다. 순수 구현은 행당·헤드당 O(N)개의 전송을 수행하지만, Flash Attention은 K와 V를 두 번(읽기와 타일링)만 순회하고 O를 한 번만 기록합니다.

버전별 주요 개선 사항

VersionYear핵심 개선점순수 구현 대비 속도 향상주 대상 GPU
v12022타일링 + 온라인 softmax, O(N²) 회피A100 (Ampere)
v22023비‑행렬 연산 감소, 병렬성 향상, 2ⁿ이 아닌 길이 지원2~3.5×A100, H100
v32024‑2025H100 Tensor Core용 WGMMA, 비동기 파이프라인, FP8 지원3~7×H100/B200 (Hopper)

Flash Attention v2는 마스크 생성·스케일링에 필요했던 다수의 비‑행렬 연산을 크게 줄였습니다. 이는 Tensor Core가 순수 행렬 곱셈일 때 가장 효율적이기 때문에, 추가적인 원소별 연산이 있을 경우 활용도가 떨어지기 때문입니다. v2 논문에서는 65 M 파라미터 모델의 단일 포워드 패스가 기존 PyTorch 구현에서는 6.5 ms였던 것이 Flash Attention v2에서는 2.6 ms로 단축됐다고 보고했습니다.

Flash Attention v3(2024)는 H100 Hopper 아키텍처를 목표로 합니다. WGMMA(워프‑그룹 MMA) 명령을 사용해 타일링된 softmax 단계에서 데이터 이동과 연산을 겹치게 합니다. v1/v2의 동기식 SRAM 읽기는 비동기 복사로 대체돼 레이턴시를 숨깁니다.

0 조회
Back to Blog

관련 글

더 보기 »

Eidentic 소개

Today we're releasing Eidentic, an open-source TypeScript SDK for building AI agents with self-improving memory and the production fundamentals built in — not b...

Typescript의 타입

Introdução Tipos são uma forma de definir a “forma” ou o contrato dos dados que estamos usando no código. Pensando em Javascript puro, ele é dinâmico: você pode...