手写 CUDA FlashAttention(一):朴素实现

| 博客 | 14901 | 38分钟 | AIAI InfraFlashAttentionCUDA

1 朴素的Attention

1.1 Attention简介

关于 Attention 的介绍强烈推荐李宏毅老师的教学视频:自注意力机制和Transformer详细解析

1.1.1 从输入 XXQQKKVV

记模型的输入是长度 NN 的序列 XX,其中每个元素 xix_i 都是一个 dmodeld_{model} 维的向量:

X=[x1x2xN]xiRdmodelX = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_N \end{bmatrix} \\ x_i \in \mathbb{R}^{d_{model}}

从输入prompt到得到 XX ,分为以下几个步骤:

① 输入句子
1/3
Attention is all you need
Embedding Layer
x1= [+0.50-0.05-0.49···+0.16-0.48-0.12]Attention
x2= [-0.05+0.09-0.14···-0.31+0.27-0.23]is
x3= [-0.49-0.14+0.46···+0.42+0.32-0.33]all
x4= [+0.09+0.18+0.27···-0.49-0.46-0.41]you
x5= [+0.49-0.23-0.38···+0.50-0.06-0.47]need
  1. 输入的句子经过分词器(tokenizer)处理,转换成一系列离散的token。
  2. 每一个token被映射到一个高维空间中,形成一个向量,这个过程叫做嵌入(embedding)。嵌入层将每个token转换成一个 dmodeld_{model} 维的向量。

NN 个 token 转换成向量后,纵向排列成一个 (N,dmodel)(N, d_{model}) 的矩阵 XX

然后将每一个 token 的向量 xix_i ,映射到查询(Query)、键(Key)和值(Value)三个不同的空间中:

qi=xiWQ,ki=xiWK,vi=xiWVq_i = x_i W^Q , \quad k_i = x_i W^K, \quad v_i = x_i W^V

从线性代数的角度看,WQRd×dmodelW^Q \in \mathbb{R}^{d \times d_{model}} 本质上定义了一个从原始表示空间 Rdmodel\mathbb{R}^{d_{model}} 到查询空间 Rd\mathbb{R}^{d} 的线性映射。对每个 token 向量 xix_i,乘以 WQW^Q 相当于将它投影到一个专门用于“提问”的子空间中。

将所有token看做一个整体,一起做映射,将公式简化为:

Q=XWQ,K=XWK,V=XWVQ = X W^Q , \quad K = X W^K, \quad V = X W^V
① 几何直觉:W^Q 对空间的线性变换
1/3
xᵢ 在原始空间 ℝᵈ 中的位置

1.1.2 Attention公式拆解

Attention计算公式如下:

Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V

我们逐步拆解这个公式。

1. QKTQK^T

QQ 的第 iiqiq_ixix_i 在查询空间中的表示,KK 的第 jjkjk_jxjx_j 在键空间中的表示.

向量的点积衡量了两个向量的相似度/相关性。直观的讲,qikjq_i \cdot k_j 衡量了 xjx_jxix_i 的重要程度.

① 交互:拖动 kⱼ 观察 qᵢ · kⱼ
1/3
qi·kj=2.5×1.0 + 1.5×2.5=+6.25
方向相近 → 点积为正 → 相关性高

2. QKTd\frac{QK^T}{\sqrt{d}}

为什么这里要除以 d\sqrt{d} 呢?

假设:qiq_ikjk_j 的每个分量都是独立的、均值为 00、方差为 11 的随机变量。

点积展开:

qikj=l=1dqilkjlq_i \cdot k_j = \sum_{l=1}^{d} q_{il} \cdot k_{jl}

分析每一项:每个 qilkjlq_{il} \cdot k_{jl} 是两个独立的零均值、单位方差随机变量的乘积:

每个 qilkjlq_{il} \cdot k_{jl} 是两个独立的零均值、单位方差随机变量的乘积:

  • E[qilkjl]=E[qil]E[kjl]=00=0E[q_{il} \cdot k_{jl}] = E[q_{il}] \cdot E[k_{jl}] = 0 \cdot 0 = 0
  • Var(qilkjl)=E[qil2]E[kjl2](E[qil]E[kjl])2=110=1\text{Var}(q_{il} \cdot k_{jl}) = E[q_{il}^2] \cdot E[k_{jl}^2] - (E[q_{il}] \cdot E[k_{jl}])^2 = 1 \cdot 1 - 0 = 1

对 d 项求和,由于各分量独立,方差可加:

E[qikj]=0,Var(qikj)=dE[q_i \cdot k_j] = 0, \quad \text{Var}(q_i \cdot k_j) = d

所以点积的标准差为 d\sqrt{d}。当 dd 很大时(如 64、128),点积的值会很大,导致 softmax 输出趋近 one-hot 向量(梯度接近 0,训练困难)。

除以d\sqrt{d} 后:

Var(qikjd)=Var(qikj)d=dd=1\text{Var}\left(\frac{q_i \cdot k_j}{\sqrt{d}}\right) = \frac{\text{Var}(q_i \cdot k_j)}{d} = \frac{d}{d} = 1

3. softmax(QKTd)\text{softmax}(\frac{QK^T}{\sqrt{d}})

对上一步得到的缩放点积结果,沿最后一个维度(即 Key 的序列维度)求 Softmax,将注意力分数归一化为概率分布。

直接计算 exp(sij)\exp(s_{ij}) 时,如果 sijs_{ij} 很大,会导致浮点溢出。实践中使用 safe softmax:先减去每行的最大值 mi=maxjsijm_i = \max_j s_{ij},再取指数:

αij=exp(sijmi)k=1nexp(sikmi),其中  sij=qikjd\alpha_{ij} = \frac{\exp(s_{ij} - m_i)}{\sum_{k=1}^{n} \exp(s_{ik} - m_i)}, \quad \text{其中} \; s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}}

减去 mim_i 不改变 softmax 的结果(分子分母同乘 emie^{-m_i} 可以约掉),但保证了指数的输入 0\leq 0,避免数值溢出。

归一化后,每一行的注意力权重之和为 1,表示当前 Query 对所有 Key 的关注程度分布。

Softmax: S → P (逐行归一化)
S
+0.25+0.74-0.78
+0.48-0.01+0.09
-0.33+0.19-0.28
N × N
softmaxper row
P
???
???
???
N × N

4. softmax(QKTd)V\text{softmax}(\frac{QK^T}{\sqrt{d}}) V

将注意力权重与 VV 矩阵相乘,得到最终的加权输出:

O=αVO = \alpha V

其中输出矩阵 OO 的每一行 oi=j=1nαijvjo_i = \sum_{j=1}^{n} \alpha_{ij} v_j,即对所有 Value 向量按注意力权重进行加权求和。直觉上,模型通过注意力分数决定”关注哪些位置的信息”,再从这些位置提取对应的 Value 进行聚合。

O = P · V (加权求和)
P
0.220.350.08
0.280.170.19
0.150.260.16
N × N
·
V
+0.42+0.35+0.17
-0.03-0.49+0.04
-0.36+0.39-0.34
N × d
=
O
???
???
???
N × d

2 编写朴素Attention的 cuda 算子

首先定义好基础结构,分配 d_S d_P 作为中间变量:

#include <cmath>
#include <cfloat>
#include <cstdio>
#include <cuda_runtime.h>

#define TILE 16
#define BLOCK_SIZE 256
#define cdiv(a, b) (((a) + (b) - 1) / (b))

void launch_naive_attention(
    const float* d_Q, const float* d_K, const float* d_V, float* d_O,
    int N, int d
) {
    float* d_S, *d_P;
    cudaMalloc(&d_S, sizeof(float) * N * N);
    cudaMalloc(&d_P, sizeof(float) * N * N);
    
    // 1. S = Q K^T / sqrt(d)
    dim3 block1(TILE, TILE);
    dim3 grid1(cdiv(N, TILE), cdiv(N, TILE));
    ScaledDotProductKernel<<<grid1, block1>>>(d_Q, d_K, d_S, N, d);

    // 2. P = softmax(S); per row
    dim3 block2(BLOCK_SIZE);
    dim3 grid2(cdiv(N, BLOCK_SIZE));
    SoftmaxKernel<<<grid2, block2>>>(d_S, d_P, N);

    // 3. O = P V
    dim3 block3(TILE, TILE);
    dim3 grid3(cdiv(N, TILE), cdiv(d, TILE));
    PVMultiplyKernel<<<grid3, block3>>>(d_P, d_V, d_O, N, d);

    cudaFree(d_S);
    cudaFree(d_P);
}

2.1 ScaledDotProductKernel

第一阶段,计算 S=QKT/dS = QK^T / \sqrt{d},本质是一个标准的矩阵乘法。输出矩阵 SSN×NN \times N,每个线程负责计算一个输出元素 SijS_{ij},因此按 TILE × TILE 划分 thread block,再根据输出矩阵的大小确定 grid 维度:

dim3 block1(TILE, TILE);
// 输出矩阵 S 是 N×N,每个线程计算一个 S_ij,grid 按 TILE 分块覆盖整个输出矩阵
dim3 grid1(cdiv(N, TILE), cdiv(N, TILE));
ScaledDotProductKernel<<<grid1, block1>>>(
    d_Q, d_K, d_S, N, d
);

device 侧,注意 d_Qd_K 在显存中的布局都是 N×dN \times d(行优先),我们并没有真正构造 KTK^T。计算 Sij=qikjS_{ij} = q_i \cdot k_j 时,只需让 Q 按第 row 行、K 按第 col 行去取同一维度 i 的元素做内积即可:

__global__ void ScaledDotProductKernel(const float* d_Q, const float* d_K, float* d_S, int N, int d) {
    int row = threadIdx.x + blockIdx.x * blockDim.x;
    int col = threadIdx.y + blockIdx.y * blockDim.y;
    
    if (row < N && col < N) {
        float sum = 0.0f;
        for (int i = 0; i < d; i++) {
            // Q[row, i] * K[col, i]  (K[col] 即 K^T 的第 col 列)
            sum += d_Q[row * d + i] * d_K[col * d + i];
        }
        d_S[row * N + col] = sum / sqrtf((float)d);
    }
}

2.2 SoftmaxKernel

第二阶段对 S 逐行求 softmax。Softmax 是按行独立的操作,所以每个线程处理一整行,不需要二维 block:

dim3 block2(BLOCK_SIZE);
dim3 grid2(cdiv(N, BLOCK_SIZE));
SoftmaxKernel<<<grid2, block2>>>(
    d_S, d_P, N
);

device 侧,每个线程负责 S 的一行,内部分三步:求最大值 m、求指数和、归一化写回 P。注意 expf 中要减去 m 防止溢出:

__global__ void SoftmaxKernel(const float* d_S, float* d_P, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < N) {
        // 1. 求本行最大值
        float m = -FLT_MAX;
        for (int i = 0; i < N; i++) {
            m = fmaxf(m, d_S[row * N + i]);
        }

        // 2. 求指数和 (减去 m 防止溢出)
        float sum = 0.0f;
        for (int i = 0; i < N; i++) {
            sum += expf(d_S[row * N + i] - m);
        }

        // 3. 归一化,写回 P
        for (int i = 0; i < N; i++) {
            d_P[row * N + i] = expf(d_S[row * N + i] - m) / sum;
        }
    }
}

2.3 PVMultiplyKernel

第三阶段计算 O=PVO = PV,又是一个标准矩阵乘法。PPN×NN \times NVVN×dN \times d,输出 OON×dN \times d。与第一阶段类似,按输出矩阵的形状配置 grid:

dim3 block3(TILE, TILE);
dim3 grid3(cdiv(N, TILE), cdiv(d, TILE));
PVMultiplyKernel<<<grid3, block3>>>(
    d_P, d_V, d_O, N, d
);

device 侧,每个线程计算 Oij=kPikVkjO_{ij} = \sum_{k} P_{ik} \cdot V_{kj}

__global__ void PVMultiplyKernel(const float* d_P, const float* d_V, float* d_O, int N, int d) {
    int row = threadIdx.x + blockIdx.x * blockDim.x;
    int col = threadIdx.y + blockIdx.y * blockDim.y;

    if (row < N && col < d) {
        float sum = 0.0f;
        for (int i = 0; i < N; i++) {
            sum += d_P[row * N + i] * d_V[i * d + col];
        }
        d_O[row * d + col] = sum;
    }
}

3 通用的优化算法

我们先不使用 FlashAttention 用到的算法,从更常用的cuda优化方案思考上述代码的优化空间。

3.1 分块矩阵乘法

dim3 block1(TILE, TILE);
dim3 grid1(cdiv(N, TILE), cdiv(N, TILE));
ScaledDotProductKernel<<<grid1, block1>>>(
    d_Q, d_K, d_S, N, d
);

__global__ void ScaledDotProductKernel(const float* d_Q, const float* d_K, float* d_S, int N, int d) {
    int row = threadIdx.x + blockIdx.x * blockDim.x;
    int col = threadIdx.y + blockIdx.y * blockDim.y;
    
    if (row < N && col < N) {
        float sum = 0.0f;
        for (int i = 0; i < d; i++) {
            sum += d_Q[row * d + i] * d_K[col * d + i];
        }
        d_S[row * N + col] = sum / sqrtf((float)d);
    }
}

重新审视这段代码的访存模式。一个 block 有 TILE × TILE 个线程,它们共同计算输出矩阵 S 中一个 TILE × TILE 的子块。为了计算这个子块,同一行的 TILE 个线程都需要读取 Q 的同一行,同一列的 TILE 个线程都需要读取 K 的同一行——也就是说,每一行 Q 数据被同 block 内的 TILE 个线程重复读取,每一行 K 数据同理

朴素实现中,这些重复读取全部走 HBM(显存),而 HBM 的带宽是 GPU 计算的主要瓶颈。解决方案是利用每个 SM 上的 shared memory(共享内存/SRAM):block 内的线程先协作地将所需数据从 HBM 搬运到 shared memory,之后所有线程都从 shared memory 读取,避免重复的 HBM 访问。

但 shared memory 容量有限(通常 48~164 KB),无法一次装下 Q 和 K 沿 d 维度的完整行。因此需要分块(tiling):将 d 维度切成若干个大小为 TILE 的片段,每次只加载一小块到 shared memory,计算出部分内积并累加,循环直到遍历完整个 d 维度。每个 block 需要分配 2 块 TILE × TILE 的 shared memory 空间,分别缓存 Q 和 K 的当前 tile。

分块矩阵乘法
0/24
Q
30-313-1-3101-11-12-22-3-132-2-223112223333-1-221-303-12-23-33-32-3-2230-3-2112333210
8×8
×KT
Kᵀ
1-23-33-211-2-3-21331-13-2-230-322-313-1-222-3330-2-3023-23-320-13-3112223331-12-33-33-3
8×8
=
S
································································
8×8
(0,0)
(0,1)
(1,0)
(1,1)
点击 ▶ 开始 — S 被分成 2×2 = 4 个 block,每个 block 做 2 次 tile 迭代
__global__ void ScaledDotProductKernel(const float* d_Q, const float* d_K, float* d_S, int N, int d) {
    __shared__ float s_Q[TILE][TILE];
    __shared__ float s_K[TILE][TILE];
    
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int row = tx + blockIdx.x * TILE;
    int col = ty + blockIdx.y * TILE;
    
    float sum = 0.0f;
    
    for (int i = 0; i < cdiv(d, TILE); i++) {
        // 第一步,视角切换,block内每个线程管理一个s_Q和s_K的线程
        if (row < N && i * TILE + ty < d) {
            // s_Q[tx][ty] = Q[row][i*TILE + ty]
            s_Q[tx][ty] = d_Q[row * d + i * TILE + ty];
        } else {
            s_Q[tx][ty] = 0.0f;
        }
        if (col < N && i * TILE + tx < d) {
            // s_K[tx][ty] = K^T[i*TILE + tx][col] = K[col][i*TILE + tx]
            s_K[tx][ty] = d_K[col * d + i * TILE + tx];
        } else {
            s_K[tx][ty] = 0.0f;
        }
        __syncthreads(); // 同步,所有线程现在的状态都一样,才可以重新分工。
        // 视角切换。一个线程用于计算 S[row][col] 的值
        for (int k = 0; k < TILE; k++) {
            sum += s_Q[tx][k] * s_K[k][ty];       
        }
        __syncthreads(); // 同步,线程马上要重新分工去读数据。
    }
    
    if (row < N && col < N) {
        d_S[row * N + col] = sum / sqrtf((float)d);
    }
}

同理,修改第三步 PVMultiplyKernel

__global__ void PVMultiplyKernel(const float* d_P, const float* d_V, float* d_O, int N, int d) {
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int row = tx + blockIdx.x * blockDim.x;
    int col = ty + blockIdx.y * blockDim.y;
    
    __shared__ float s_P[TILE][TILE];
    __shared__ float s_V[TILE][TILE];
    
    float sum = 0.0f;
    for (int i = 0; i < cdiv(N, TILE); i++) {
        if (row < N && i * TILE + ty < N) {
            s_P[tx][ty] = d_P[row * N + i * TILE + ty];
        } else {
            s_P[tx][ty] = 0.0f;
        }
        if (i * TILE + tx < N && col < d) {
            s_V[tx][ty] = d_V[(i * TILE + tx) * d + col];
        } else {
            s_V[tx][ty] = 0.0f;
        }
        __syncthreads();
        
        for (int k = 0; k < TILE; k++) {
            sum += s_P[tx][k] * s_V[k][ty];
        }
        __syncthreads();
    }
    
    if (row < N && col < d) {
        d_O[row * d + col] = sum;
    }
}

3.2 并行规约

SoftmaxKernel 中,现在每一个线程负责一行的 S ,需要 O(N)O(N) 的遍历每一行 3 次,并发性太低了。

重新分析 Softmax 我们需要做什么:

  1. 获取每一行的最大值

    之前的做法是,每个Block计算一行,让1个线程遍历这一行。这一个线程就需要递增 NN 次。

    如果可以同时让 BLOCK_SIZE 个线程一起计算最大值,递增的次数就减少了 BLOCK_SIZE 倍。

    一个线程负责 N / BLOCK_SIZE 个数据,统计一个局部最大值。

    所有线程执行完成后,得到 BLOCK_SIZE 个局部最大值。之后再对这 BLOCK_SIZE 个局部最大值求整体最大值

    因此我们需要 BLOCK_SIZE * sizeof(float) 的空间存局部最大值,N个BLOCK每个BLOCK负责一行,每块 BLOCK_SIZE 个线程。

    // 2. P = softmax(S);
    dim3 block2(BLOCK_SIZE);
    dim3 grid2(N);
    SoftmaxKernel<<<grid2, block2>>>(
        d_S, d_P, N
    );
    
    __global__ void SoftmaxKernel(const float* d_S, float* d_P, int N) {
        int row = blockIdx.x;
        int tid = threadIdx.x;
        
        __shared__ float sdata[BLOCK_SIZE];
        
        // 1. 求局部最大值
        sdata[tid] = -FLT_MAX;
        for (int i = tid; i < N; i += BLOCK_SIZE) {
            sdata[tid] = fmax(sdata[tid], d_S[row * N + i]);
        }
    }

    这里每个线程反复对 sdata[tid] 进行读写,虽然 SRAM 的读取速度已经很快了,但是寄存器的读写速度更高。
    可以做一个简单的优化:

    // 1. 求局部最大值
    float local_max = -FLT_MAX;
    for (int i = tid; i < N; i += BLOCK_SIZE) {
        local_max = fmax(local_max, d_S[row * N + i]);
    }
    sdata[tid] = local_max;

    之后对 sdataBLOCK_SIZE 个局部最大值进行归约。此时每个线程的职责要发生变换,所以需要先同步。

    __syncthreads();
    // 归约,求整行的最大值
    for (int stride = BLOCK_SIZE / 2; stride >= 1; stride >>= 1) {
        if (tid < stride) {
            sdata[tid] = fmax(sdata[tid], sdata[tid + stride]);
        }
        __syncthreads();
    }
    float m = sdata[0];
并行归约求最大值
0/7
Thread 0
Thread 1
Thread 2
Thread 3
S[row] (HBM)
3[0]
1[1]
7[2]
2[3]
5[4]
9[5]
0[6]
4[7]
8[8]
6[9]
1[10]
3[11]
2[12]
7[13]
4[14]
5[15]
点击 ▶ 开始 — 4 个线程并行处理 16 个元素,每个线程负责 4 个
  1. 指数求和

    同样的两步走,先求局部和,再用归约的算法求和。

    float local_sum = 0.0f;
    for (int i = tid; i < N; i += BLOCK_SIZE) {
        local_sum += expf(d_S[row * N + i] - m);
    }
    sdata[tid] = local_sum;
    __syncthreads();
    
    for (int stride = BLOCK_SIZE / 2; stride >= 1; stride >>= 1) {
        if (tid < stride) {
            sdata[tid] += sdata[tid + stride];
        }
        __syncthreads();
    }
    float sum = sdata[0];
  2. 逐个元素的自然指数除以指数和

    for (int i = tid; i < N; i += BLOCK_SIZE) {
        d_P[row * N + i] = expf(d_S[row * N + i] - m) / sum;
    }

3.3 线程粗化(COARSE)

下面回过头,继续优化分块矩阵乘法。

参考3.1的动画,计算 C=Q×KTC = Q \times K^T 时,分块矩阵乘法的核心公式为:

Cij=kQikKkjTC_{ij} = \sum_{k} Q_{ik} \cdot K^T_{kj}

C00C_{00} 为例,展开得:

C00=Q00K00T+Q01K10TC_{00} = Q_{00} \cdot K^T_{00} + Q_{01} \cdot K^T_{10}

可以看到,Q00Q_{00} 在计算 C00C_{00}C01C_{01} 时都会被用到:

C00=Q00K00T+Q01K10TC_{00} = \mathbf{Q_{00}} \cdot K^T_{00} + Q_{01} \cdot K^T_{10}
C01=Q00K01T+Q01K11TC_{01} = \mathbf{Q_{00}} \cdot K^T_{01} + Q_{01} \cdot K^T_{11}

也就是说,Q00Q_{00} 这个 TILE 被两组不同的线程分别从全局内存读取了两次。

线程粗化(Thread Coarsening)的思路是:既然 Q00Q_{00} 反正要被读进共享内存,那就让同一组线程一次性把 C00C_{00}C01C_{01} 都算完,而不是分给两组线程各算一个。

具体来说,原本两个线程块分别计算:

  • 线程块 A:读取 Q00Q_{00},计算 C00C_{00}
  • 线程块 B:读取 Q00Q_{00},计算 C01C_{01}

粗化之后,一个线程块同时完成:

  • 读取 Q00Q_{00}(只读一次)
  • 计算 C00C_{00}C01C_{01}

这样,Q00Q_{00} 从全局内存到共享内存的读取次数从 2 次降为 1 次,减少了冗余的全局内存访问。

线程粗化对比
0/16
Q 读取0
K^T 读取0
总 HBM 读取0
Q
2130031212033121
4×4
×K^T
K^T
1320201302313102
4×4
=
C
················
4×4
线程块分配 (4 个 Block)
C00Blk(0,0)
C01Blk(0,1)
C10Blk(1,0)
C11Blk(1,1)
点击 ▶ 开始 — Q(4×4) × K^T(4×4), TILE=2, 沿 d 做 2 次 tile 迭代
dim3 block1(TILE, TILE);
dim3 grid1(cdiv(N, TILE), cdiv(N, TILE * COARSE));

__global__ void ScaledDotProductKernel(const float* d_Q, const float* d_K, float* d_S, int N, int d) {
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int row = threadIdx.x + blockIdx.x * TILE;
    int col = threadIdx.y + blockIdx.y * TILE * COARSE;
    
    __shared__ float s_Q[TILE][TILE];
    __shared__ float s_K[TILE][TILE];
    
    float sum[COARSE] = {};
    
    for (int i = 0; i < cdiv(d, TILE); i++) {
        if (row < N && i * TILE + ty < d)
            s_Q[tx][ty] = d_Q[row * d + i * TILE + ty];
        else
            s_Q[tx][ty] = 0.0f;
        
        #pragma unroll
        for (int c = 0; c < COARSE; c++) {
            int cur_col = col + c * TILE;
            
            if (cur_col < N && i * TILE + tx < d)
                s_K[tx][ty] = d_K[cur_col * d + i * TILE + tx];
            else
                s_K[tx][ty] = 0.0f;
            __syncthreads();
            
            for (int k = 0; k < TILE; k++) {
                sum[c] += s_Q[tx][k] * s_K[k][ty];    
            }
            __syncthreads();
        }
    }
    
    float scale = 1.0f / sqrtf((float)d);
    for (int c = 0; c < COARSE; c++) {
        int cur_col = col + c * TILE;
        if (row < N && cur_col < N)
            d_S[row * N + cur_col] = sum[c] * scale;
    }
}

同理,对 PVMultiplyKernel 也可以做线程粗化。

性能对比

到此为止,我们对比一下原版和四个优化的性能差异

N朴素实现分块矩阵乘法并行归约线程粗化
2560.134 ms0.106 ms0.062 ms0.122 ms
5120.533 ms0.418 ms0.318 ms0.410 ms
10241.453 ms1.063 ms0.836 ms0.929 ms
16384160.555 ms89.030 ms63.558 ms56.797 ms

注意到引入了线程粗化后,N 较小时(256~1024),性能反而比并行归约差。这是因为线程粗化用 COARSE=4 将每个线程块的工作量扩大了 4 倍,同时线程块总数减少为原来的 1/4。当 N 较小时,线程块本来就不多,再减少 4 倍后 GPU 的 SM 无法被充分占满,并行度不足反而拖慢了速度。到 N=16384 时,线程块数量足够多,粗化减少的冗余读取才真正体现出优势。

很有成就感,性能逐步提升了接近3倍。接下来,我们继续。

3.4 合并访存(COALESCING)

其实前面的代码,都有一个关键的性能问题。

CUDA 线程线性化顺序是 tid = tx + ty * blockDim.x,所以,同一个 warp 里,tx = 0..15,ty 固定。

而前面的代码中:

row = tx + blockIdx.x * TILE
s_Q[tx][ty] = d_Q[row * d + i*TILE + ty]

相邻线程的地址间隔是 d,不连续,非合并访问。

所以概念中我们应该记住,tx一般是列,ty是行,如果有第三个维度,tz是深度。应该让 tx 永远保持在最外层,代码从 tx 的角度看是连续的。

修复方法是:交换 txty 的角色,让 tx 对应列,ty 对应行。这样同一 warp 内相邻线程(tx 连续)访问的是连续的内存地址。

row = ty + blockIdx.y * TILE    
col = tx + blockIdx.x * TILE    
s_Q[ty][tx] = d_Q[row * d + i*TILE + tx]

相邻线程(tx 连续)访问 d_Q[row * d + ...] 的连续偏移,实现合并访存。

相应地,host 侧也要调换 grid 的 x、y 维度:

dim3 block1(TILE, TILE);
dim3 grid1(cdiv(N, TILE * COARSE), cdiv(N, TILE));

K 的转置访问方式变化

Q 的修改很直观,但 K 需要额外思考一下。我们要加载的是 KTK^T 的一个 TILE,忽略块内偏移,只看块号:

k_row = i * TILE
k_col = blockIdx.x * TILE * COARSE + c * TILE

KTK^T 的视角加载:

s_K[ty][tx] = d_KT[k_row + ty][k_col + tx]

转置回 KK

s_K[ty][tx] = d_K[k_col + tx][k_row + ty]   // 这里展开就是 3.3 的写法

但这里 tx 在行序(外层索引)上,同一 warp 内 tx 连续会导致不同行的访问,不是合并访存。

方法一:所以我们交换 d_K 下标中的 tx 和 ty

s_K[ty][tx] = d_K[k_col + ty][k_row + tx]   // tx 到了内层,合并访存

这个交换不会改变取的是哪个 TILE,只是让 TILE 内部多了一次转置——存进共享内存的数据相比 3.3 是转置的。

没关系,后续计算时取 s_KTs\_K^T 就转回来了:

// 3.3: s_Q[tx][k] * s_K[k][ty]    — 直接用 s_K
// 3.4: s_Q[ty][k] * s_K[tx][k]    — 用 s_K 的转置,抵消加载时多出的转置

两次转置抵消,最终结果正确。

方法二:或者,我们可以在读取的时候,对s_K做一次转置:

s_K[ty][tx] = d_K[k_col + ty][k_row + tx]   
s_K[tx][ty] = d_K[k_col + ty][k_row + tx] // 写入s_K时转置

后续计算时就不用再次转置了,还是:

s_Q[tx][k] * s_K[k][ty]

3.5 Bank Confict

3.4 最后提到的对 K 的两种处理方式,乍一看感觉本质上完全一样吧?我们实际测试一下。

N线程粗化方法一:计算时再转置方法二:先转置
2560.122 ms0.081 ms0.077 ms
5120.407 ms0.324 ms0.310 ms
10240.929 ms0.756 ms0.708 ms
1638457.096 ms33.187 ms24.386 ms

非常的 Amazing 啊,这一个微小的改动竟然有这么明显地性能提升。

要理解方法二为什么更快,需要先了解 Bank Conflict

Shared Memory 的物理结构

Shared memory 被切成 32 个 bank,每个 bank 是独立的内存模块,可以同时服务一个请求。

地址分配规则很简单:第 n 个 float 在第 n % 32 号 bank

地址:  0   1   2   3  ...  31  32  33 ...
bank:  0   1   2   3  ...  31   0   1 ...

理想情况:一个 warp 32 个线程同时读 shared memory,如果每个线程访问不同的 bank,32 个 bank 并行服务,一个周期搞定。

Bank Conflict:一个 bank 在同一周期只能服务一个线程。如果多个线程访问同一个 bank,硬件会串行化这些请求。N-way conflict = 该 bank 同时被 N 个线程访问 = 需要 N 个周期。

Shared Memory Bank Conflict
0/16
k=3 固定,tx=0..15 访问 s_K[tx][3] — stride=1616%32=16,每 2 个线程撞同一 bank
32 Banks
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
Thread Access Pattern (k=3)
t0off=3B3
t1off=19B19
t2off=35B3
t3off=51B19
t4off=67B3
t5off=83B19
t6off=99B3
t7off=115B19
t8off=131B3
t9off=147B19
t10off=163B3
t11off=179B19
t12off=195B3
t13off=211B19
t14off=227B3
t15off=243B19
点击 ▶ 逐线程查看 s_K[tx][3] 的 bank 访问模式

具体分析 s_K 的两种读取方式

方法一 s_K[tx][k],k=3 固定,tx=0..15(同一个 warp 的前半部分):

s_K[0][3]  → offset = 0*16+3  = 3   → bank 3
s_K[1][3]  → offset = 1*16+3  = 19  → bank 19
s_K[2][3]  → offset = 2*16+3  = 35  → bank 3   ← 撞 tx=0
s_K[3][3]  → offset = 3*16+3  = 51  → bank 19  ← 撞 tx=1
s_K[4][3]  → offset = 4*16+3  = 67  → bank 3   ← 撞
...
s_K[7][3]  → offset = 7*16+3  = 115 → bank 19  ← 撞

bank 3 被 8 个线程同时访问 → 硬件串行执行 8 次 → 8-way conflict,慢 8 倍

根本原因:tx 每增加 1,offset 增加 TILE=16。而 16 % 32 = 1632 % 32 = 0,每隔两个线程就回到同一 bank,16 个线程只踩 2 个 bank。

方法二 s_K[k][tx],k=3 固定,tx=0..15:

s_K[3][0]  → offset = 3*16+0  = 48  → bank 16
s_K[3][1]  → offset = 3*16+1  = 49  → bank 17
s_K[3][2]  → offset = 3*16+2  = 50  → bank 18
...
s_K[3][15] → offset = 3*16+15 = 63  → bank 31

tx 每增加 1,offset 增加 1,bank 也增加 1 → 16 个线程访问 16 个不同 bank,无冲突

Bank conflict 的根源是步长。步长是 32 的因数时出问题;步长是 1(连续访问)时最好。

  • s_K[tx][k]:tx 走在”行”方向,步长=TILE=16 → 冲突
  • s_K[k][tx]:tx 走在”列”方向,步长=1 → 无冲突

Padding 消除 Bank Conflict

还有一个更简单的方法:把 s_K[TILE][TILE] 改成 s_K[TILE][TILE+1]

读取 s_K[tx][k] 的问题本质是步长 16 是 32 的因数。加了 +1 之后,每行占 17 个 float:

s_K[tx][k], k=3 固定, tx=0..15:

offset = tx * 17 + 3

tx=0:  3   → bank 3
tx=1:  20  → bank 20
tx=2:  37  → bank 5
tx=3:  54  → bank 22
tx=4:  71  → bank 7
tx=5:  88  → bank 24
...

步长从 16 变成 17,而 17 和 32 互质(gcd(17,32)=1),所以 16 个线程一定映射到 16 个不同 bank,冲突消失。

修改之后方法一和方法二的性能基本一致了:

N线程粗化方法一:计算时再转置方法二:先转置
2560.119 ms0.074 ms0.073 ms
5120.373 ms0.271 ms0.272 ms
10240.892 ms0.669 ms0.672 ms
1638458.079 ms23.597 ms25.350 ms