RMSNorm算子的CUDA实现
目录
- 前言
- 1. 算子的计算公式
- 2. 算子的调用接口
- 3. CPU算子的实现
- 4. CUDA上的算子实现
- 5. 补充
- 总结
- 参考
前言
KuiperInfer 作者推出的《自制大模型推理框架课程》,链接。记录下个人学习笔记,仅供自己参考😄
video:自制Cuda大模型推理框架-RMSNorm的Cuda实现
repo:https://github.com/zjhellofss/KuiperLLama
本次课程我们来学习第五次课程—RMSNorm 算子的 CUDA 实现
课程大纲可以看下面的思维导图
1. 算子的计算公式
Transformer 中的 Normalization 层一般都是采用 LayerNorm 来对 Tensor 进行归一化,LayerNorm 的公式如下:
L a y e r N o r m : y = x − E [ x ] V a r [ x ] + ϵ ∗ γ + β E [ x ] = 1 N ∑ i = 1 N x i V a r [ x ] = 1 N ∑ i = 1 N ( x i − E [ x ] ) 2 \begin{aligned} LayerNorm:y & =\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}}*\gamma+\beta \\ E[x] & =\frac{1}{N}\sum_{i=1}^Nx_i \\ Var[x] & =\frac{1}{N}\sum_{i=1}^N(x_i-E[x])^2 \end{aligned} LayerNorm:yE[x]Var[x]=Var[x]+ϵx−E[x]∗γ+β=N1i=1∑Nxi=N1i=1∑N(xi−E[x])2
而 RMSNorm 就是 LayerNorm 的变体,RMSNorm 省去了求均值的过程,也没有了偏置 β \beta β,即:
R M S N o r m : y = x M e a n ( x 2 ) + ϵ ∗ γ M e a n ( x 2 ) = 1 N ∑ i = 1 N x i 2 \begin{aligned} RMSNorm:y & =\frac{x}{\sqrt{Mean(x^{2})+\epsilon}}*\gamma \\ Mean(x^{2}) & =\frac{1}{N}\sum_{i=1}^Nx_i^2 \end{aligned} RMSNorm:yMean(x2)=Mean(x2)+ϵx∗γ=N1i=1∑Nxi2
其中 γ \gamma γ 和 β \beta β 为可学习的参数
2. 算子的调用接口
KuiperLLama 中 RMSNorm 算子的调用接口代码在 kernels/kernels_interface.h#L30 中,具体内容如下:
typedef void (*RMSNormKernel)(const tensor::Tensor& input, const tensor::Tensor& weight,
const tensor::Tensor& output, void* stream);
上述代码定义了一个函数指针类型,RMSNormKernel
是定义的一个类型别名,它代表一个特定的函数指针类型,其函数签名含义如下:(from ChatGPT)
void
表示该函数没有返回值(*RMSNormKernel)
表示RMSNormKernel
是一个指向函数的指针,函数的参数为:input
:一个常量引用类型的tensor::Tensor
对象,输入张量即 x x xweight
:权重张量即 γ \gamma γoutput
:输出张量即 y y ystream
: 一个void
指针,即 CUDA Stream
使用示例如下,假设有一个具体的函数实现:
void myRMSNormKernel(const tensor::Tensor& input, const tensor::Tensor& weight,
const tensor::Tensor& output, void* stream) {
// 实现 RMSNorm 操作的具体逻辑
}
现在可以通过如下方式定义一个 RMSNormKernel
类型的指针,并将其指向 myRMSNormKernel
:
RMSNormKernel kernel = &myRMSNormKernel; // 将函数指针指向 myRMSNormKernel
然后通过 kerenl
指针调用该函数:
kernel(inputTensor, weightTensor, outputTensor, stream);
在 KuiperLLama 中调用代码在 kernels/kernels_interfaces.cpp#L114,具体内容如下:
RMSNormKernel get_rmsnorm_kernel(base::DeviceType device_type) {
if (device_type == base::DeviceType::kDeviceCPU) {
return rmsnorm_kernel_cpu;
} else if (device_type == base::DeviceType::kDeviceCUDA) {
return rmsnorm_kernel_cu;
} else {
LOG(FATAL) << "Unknown device type for get an rmsnorm kernel.";
return nullptr;
}
}
KuiperLLama 中针对不同的硬件平台(CPU 和 GPU)提供了不同的实现,主要是通过 RMSNormKernel 这个函数指针来动态选择对应的实现,其中 rmsnorm_kernel_cpu
是 CPU 后端下 RMSNorm 的实现,rmsnorm_kernel_cu
是 CUDA 后端下 RMSNorm 核函数实现的启动函数
Note:在 CUDA 中,一个流(Stream)是一个序列化的任务队列,它允许开发者将一系列的 CUDA 操作(如内存复制、内核执行等)放入一个队列中,然后按顺序执行这些操作,而不需要等待前面的操作完全完成
3. CPU算子的实现
void rmsnorm_kernel_cpu(const tensor::Tensor& input, const tensor::Tensor& weight,
const tensor::Tensor& output, void* stream = nullptr);
可以看到我们在 CPU 上的算子定义符合前面所讲的函数指针类型 RMSNormKernel 的要求,也就是具有一个输入、输出以及权重,值得注意的是 CPU 上的 stream 指针参数恒为空,这是因为 stream 流是只有在 GPU 上才有的机制
CPU 上 RMSNorm 算子的具体实现如下:
void rmsnorm_kernel_cpu(const tensor::Tensor& input, const tensor::Tensor& weight,
const tensor::Tensor& output, void* stream) {
UNUSED(stream);
CHECK(!input.is_empty());
CHECK(!weight.is_empty());
CHECK(!output.is_empty());
CHECK(input.device_type() == base::DeviceType::kDeviceCPU &&
weight.device_type() == base::DeviceType::kDeviceCPU &&
output.device_type() == base::DeviceType::kDeviceCPU);
const float* in_ptr = input.ptr<float>();
const float* wei_ptr = weight.ptr<float>();
const float* out_ptr = output.ptr<float>();
const int32_t dim = static_cast<int32_t>(input.size());
arma::fvec in_tensor(const_cast<float*>(in_ptr), dim, false, true);
arma::fvec out_tensor(const_cast<float*>(out_ptr), dim, false, true);
arma::fvec wei_tensor(const_cast<float*>(wei_ptr), dim, false, true);
#ifdef QWEN2_SUPPORT
const float eps = 1e-6f;
#else
const float eps = 1e-5f;
#endif
const float mean = arma::as_scalar(arma::mean(arma::pow(in_tensor, 2))) + eps;
const float rsqrt = 1.f / std::sqrt(mean);
out_tensor = wei_tensor % (rsqrt * in_tensor);
}
Note:具体代码在 kuiper/source/op/kernels/cpu/rmsnorm_kernel.cpp
首先先检查输入张量的有效性以及设备类型是否存在问题,接着获取输入张量、权重以及输出张量的指针并计算输入张量的维度大小,然后使用 Armadillo 库处理张量:
arma::fvec in_tensor(const_cast<float*>(in_ptr), dim, false, true);
arma::fvec out_tensor(const_cast<float*>(out_ptr), dim, false, true);
arma::fvec wei_tensor(const_cast<float*>(wei_ptr), dim, false, true);
arma::fvec
是 Armadillo 库中的向量类型,它提供了许多高效的数学操作如向量加法、点积、均值计算等
#ifdef QWEN2_SUPPORT
const float eps = 1e-6f;
#else
const float eps = 1e-5f;
#endif
const float mean = arma::as_scalar(arma::mean(arma::pow(in_tensor, 2))) + eps;
const float rsqrt = 1.f / std::sqrt(mean);
out_tensor = wei_tensor % (rsqrt * in_tensor);
按照公式先计算输入张量的平方均值 mean
来获取归一化因子 rsqrt
,接着对输入张量 in_tensor
中的每个元素乘以归一化因子 rsqrt
即代码中的 (rsqrt * in_tensor)
最后将其与权重张量逐元素相乘,得到最终的归一化输出张量即代码中的 wei_tensor % (rsqrt * in_tensor)
,%
运算符表示元素级乘法(类似于向量的点乘)
4. CUDA上的算子实现
我们知道 CUDA 编程模型 SIMT(Single Instruction Multiple Threads)包含以下几个核心组成部分:
- 线程(Thread)是最小执行单元
- 线程块(Block)是线程的集合,多个线程一起执行任务
- 网格(Grid)是线程块的集合,多个线程块一起执行任务
- 流式多处理器(SM)是负责调度和执行线程块的硬件单元
- Warp 是 SM 内部调度的基本单元,由 32 个线程组成
值得注意的是 Thread、Block 和 Grid 是逻辑上的抽象概念,物理层面硬件上并不存在,而 SM、Warp 是物理上实际存在的硬件
CUDA 上实现 RMSNorm 算子需要大家对 CUDA 编程有一些基本的了解,大家感兴趣的话可以看看:CUDA编程入门极简教程
CUDA 实现 RMSNorm 算子的核函数如下:
static __global__ void row_rmsnorm_f32(float* in, float* wei, float* out, int size, float eps) {
const int tid = threadIdx.x;
constexpr int pack_size = 4;
const int pack_num = size / pack_size;
const int pack_off = pack_size * pack_num;
float sum = 0.0f;
float4* in_pack = reinterpret_cast<float4*>(in);
for (int i = tid; i < pack_num; i += blockDim.x) {
float4 in_float4 = *(in_pack + i);
sum += in_float4.x * in_float4.x;
sum += in_float4.y * in_float4.y;
sum += in_float4.z * in_float4.z;
sum += in_float4.w * in_float4.w;
}
for (int i = pack_off + tid; i < size; i += blockDim.x) {
sum += in[i] * in[i];
}
using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduce::TempStorage temp;
__shared__ float shared_val;
sum = BlockReduce(temp).Sum(sum);
if (threadIdx.x == 0) {
shared_val = sum;
}
__syncthreads();
sum = shared_val;
const float scale = rsqrtf(sum / static_cast<float>(size) + eps);
float4* wei_pack = reinterpret_cast<float4*>(wei);
float4* out_pack = reinterpret_cast<float4*>(out);
for (int i = tid; i < pack_num; i += blockDim.x) {
float4 in_float4 = *(in_pack + i);
float4 wei_float4 = *(wei_pack + i);
*(out_pack + i) =
make_float4(scale * in_float4.x * wei_float4.x, scale * in_float4.y * wei_float4.y,
scale * in_float4.z * wei_float4.z, scale * in_float4.w * wei_float4.w);
}
for (int i = pack_off + tid; i < size; i += blockDim.x) {
out[i] = wei[i] * in[i] * scale;
}
}
Note:具体代码在 kuiper/source/op/kernels/cuda/rmsnorm_kernel.cu#L5
核函数的输入参数如下:
in
:输入数据数组,类型为float*
wei
:权重数组,类型为float*
out
:输出数据数组,类型为float*
size
:输入数据的大小eps
:一个小的常数,用于防止除零错误
Note:核函数的启动参数为 <<<1, threads_num, 0, stream_>>>
其中 threads_num
为 128。因此该核函数使用的是 1D 线程块布局,即整个 kernel 只有一个线程块(Block),且线程块内只有 128 个线程(Thread),全部在线程块的 x 维度中分布
核心计算步骤如下:(from DeepSeek)
1. 模板参数与线程索引
template <int32_t BLOCK_DIM>
static __global__ void row_rmsnorm_f32(float* in, float* wei, float* out, int size, float eps) {
const int tid = threadIdx.x;
- 模板参数
BLOCK_DIM
:指定线程块的大小(如 128),用于优化共享内存和归约操作 - 线程索引
tid
:当前线程在块内的局部索引(0 到BLOCKDIM-1
)
2. 向量化内存访问
constexpr int pack_size = 4;
const int pack_num = size / pack_size;
const int pack_off = pack_size * pack_num;
pack_size = 4
:使用float4
类型进行向量化访问,每个线程一次读取/写入 4 个浮点数,提高内存吞吐量pack_num
:完整float4
分组的数量pack_off
:处理完完整float4
分组后的起始偏移(用于处理剩余元素)
3. 计算平方和(向量化部分)
float sum = 0.0f;
float4* in_pack = reinterpret_cast<float4*>(in);
for (int i = tid; i < pack_num; i += blockDim.x) {
float4 in_float4 = *(in_pack + i);
sum += in_float4.x * in_float4.x;
sum += in_float4.y * in_float4.y;
sum += in_float4.z * in_float4.z;
sum += in_float4.w * in_float4.w;
}
- 向量化读取输入:将输入指针
in
转换为float4
类型,每个线程处理pack_num
个float4
元素,步长为blockDim.x
- 平方和累加:对每个
float4
的四个分量计算平方并累加到sum
4. 处理剩余元素(非向量化部分)
for (int i = pack_off + tid; i < size; i += blockDim.x) {
sum += in[i] * in[i];
}
- 剩余元素处理:处理无法组成完整
float4
的剩余元素,每个线程处理一个元素,步长为blockDim.x
5. 块内归约求和
using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduce::TempStorage temp;
__shared__ float shared_val;
sum = BlockReduce(temp).Sum(sum);
if (threadIdx.x == 0) {
shared_val = sum;
}
__syncthreads();
sum = shared_val;
- CUB 块内归约:使用 CUB 库的
BlockReduce
对所有线程的sum
求和,结果存储在shared_val
- 共享内存同步:通过
__syncthreads()
确保所有线程读取到最终的sum
6. 计算归一化比例因子
const float scale = rsqrtf(sum / static_cast<float>(size) + eps);
- 归一化公式:计算 RMS 归一化回的比例因子
scale
,公式为: s c a l e = 1 1 N ∑ x i 2 + ϵ scale = \frac{1}{\sqrt{\frac{1}{N}\sum x_i^2+\epsilon}} scale=N1∑xi2+ϵ1 rsqrtf
:CUDA 内置函数,可用来快速计算平方根倒数
7. 应用归一化和权重(向量化部分)
float4* wei_pack = reinterpret_cast<float4*>(wei);
float4* out_pack = reinterpret_cast<float4*>(out);
for (int i = tid; i < pack_num; i += blockDim.x) {
float4 in_float4 = *(in_pack + i);
float4 wei_float4 = *(wei_pack + i);
*(out_pack + i) =
make_float4(scale * in_float4.x * wei_float4.x, scale * in_float4.y * wei_float4.y,
scale * in_float4.z * wei_float4.z, scale * in_float4.w * wei_float4.w);
}
- 向量化写入输出:将权重
wei
和输出out
转换为float4
,每个线程将输入乘以scale
和权重,结果写入输出
8. 处理剩余元素(非向量化部分)
for (int i = pack_off + tid; i < size; i += blockDim.x) {
out[i] = wei[i] * in[i] * scale;
}
- 剩余元素处理:对无法组成
float4
的剩余元素逐个处理,应用相同的缩放和权重
上述 RMSNorm 的核函数实现代码通过高效的向量化和线程归约计算输入张量的均方根值,并应用标准化因子对输入进行归一化,同时乘以权重张量。核心步骤包括:首先计算输入张量的平方和,使用 cub::BlockReduce
进行线程间归约,然后计算标准化因子,并将其应用于输入张量和权重的逐元素乘积,最后输出标准化后的结果。该实现通过使用 float4
向量化和共享内存优化了性能,适用于大规模数据并提高了 GPU 计算效率。
5. 补充
llama.cpp 中 RMSNorm 的具体实现如下:
template <int block_size>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
static void rms_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
Note:具体代码在 llama.cpp/ggml/src/ggml-cuda/norm.cu#L108
核心计算步骤如下:(from ChatGPT)
1. CUDA 核函数模板定义
template <int block_size>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
block_size
:模板参数x
:输入数据数组的起始指针dst
:输出数据数组的起始指针,用于存储归一化后的结果ncols
:当前向量的列数,即需要归一化的元素个数stride_row, stride_channel, stride_sample
:步长,用于在多维数据中定位正确元素eps
:防止除 0 的小常量,保证数值稳定性
2. 根据 CUDA grid 确定数据维度和 block 内线程索引
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
- 获取 grid 维度
gridDim.x
被定义为总行数(nrows)gridDim.y
为通道数(nchannels)
- 确定当前 block 的坐标
- 每个 block 对应一个特定的
(row, channel, sample)
,由blockIdx.x/y/z
分别确定 threadIdx.x
(赋值给tid
)标识 block 内当前线程的编号,用于后续分配工作
- 每个 block 对应一个特定的
3. 指针偏移:定位当前 block 要处理的数据
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
- 输入数据指针
x
的偏移:根据sample
、channel
和row
以及对应的步长,将指针移到当前 block 所对应的向量起始位置 - 输出数据指针
dst
的偏移:按照数据排列顺序(先 sample,再 channel,最后 row)计算偏移,使得归一化后的结果写入正确的位置
4. 计算平方和:每个线程的局部累加
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
- 局部累加变量
tmp
:初始化为 0,用于累加每个线程所负责元素的平方 - 循环遍历列元素:循环变量
col
从当前线程索引tid
开始,每次增加步长block_size
,确保整个向量被所有线程共同覆盖 - 对于每个元素,计算其平方
xi * xi
并累加到tmp
中
5. 在 Warp 内进行归约求和
tmp = warp_reduce_sum(tmp);
- warp 内归约:调用
warp_reduce_sum
(通常基于 CUDA 的 warp shuffle 操作实现)对同一个 warp 内所有线程的tmp
值进行求和,得到该 warp 部分的平方和
warp_redece_sum
的实现代码在 ggml-cuda/common.cuh#L224,如下所示:
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
}
warp_reduce_sum
函数的目的是对一个 warp 内的 32 个线程进行并行归约(求和),并利用 CUDA 的 __shfl_xor_sync
操作来高效地进行线程间数据交换。下面是该函数代码的具体解析:(from ChatGPT)
static __device__ __forceinline__ float warp_reduce_sum(float x) {
__device__
:函数修饰符表示该函数只能在 device(GPU)上调用__forceinline__
:强制内联,提示编译器尽量将函数展开到调用点,以减少函数调用的开销- 参数
x
:表示传入的线程的局部值
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
#pragma unroll
:这个编译指令提示编译器尽可能将循环展开(unroll),这意味着循环的每一轮不会真正有跳转,而是会将每次迭代展开为不同的语句。offset = 16
:offset
从 16 开始,逐步右移(offset >>= 1
),直到为 0,每次循环将offset
减半。具体地:- 第一次循环,
offset = 16
- 第二次循环,
offset = 8
- 第三次循环,
offset = 4
- 第四次循环,
offset = 2
- 第五次循环,
offset = 1
- 这样做的目的是实现逐步的并行归约,将每次迭代操作范围缩小一半。
- 第一次循环,
__shfl_xor_sync
操作:利用__shfl_xor_sync
来交换线程间的数据。__shfl_xor_sync
是 CUDA 的一个低级函数,用于在同一个 warp 中进行线程间的数据传递
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
0xffffffff
:这是掩码,表示与所有线程通信(即 warp 中的所有线程)x
:当前线程持有的数据offset
:指示与哪些线程交换数据。具体来说,offset
控制着线程之间的数据交换步长。它的值从 16 开始,逐渐变为 8、4、2、1。每次通过offset
来确定与哪些线程交换数据32
:表示 warp 的线程数为 32。__shfl_xor_sync
会在 32 个线程之间进行数据交换
__shfl_xor_sync
函数通过异或操作来交换数据,它基于每个线程的 lane_id
来确定与哪些线程交换数据。
假设线程编号为 0 到 31(lane_id
范围 0~31),那么 __shfl_xor_sync
会在每个线程间以步长为 offset
的方式进行数据交换:
- 第一次:
offset = 16
,线程 0 和 16、线程 1 和 17、线程 2 和 18 … 线程 15 和 31 进行数据交换 - 第二次:
offset = 8
,线程 0 和 8、线程 1 和 9、线程 2 和 10 … 进行数据交换 - 第三次:
offset = 4
,线程 0 和 4、线程 1 和 5、线程 2 和 6 … 进行数据交换 - 第四次:
offset = 2
,线程 0 和 2、线程 1 和 3、线程 4 和 6 … 进行数据交换 - 第五次:
offset = 1
,线程 0 和 1、线程 2 和 3、线程 4 和 5 … 进行数据交换
每次交换数据时,线程会将当前持有的值和通过异或交换得到的线程的数据加起来,实现归约求和。
return x;
- 返回的
x
是所有线程在 warp 内的部分和,归约后的最终结果
对于利用
__shfl_xor_sync
实现 warp 归约和计算的代码感觉似懂非懂,一个 Warp 内所有线程都会得到相同的归约结果吗?🤔
我们以一个较小的例子(比如 4 个线程)来说明这种归约方式的效果:
假设 warp 内有 4 个线程(编号 0、1、2、3),初始时每个线程的值分别为 a 0 , a 1 , a 2 , a 3 a_0,a_1,a_2,a_3 a0,a1,a2,a3
第一轮(设 offset = 2):
- 线程 0:读取线程
0
⊕
2
=
2
0\oplus2=2
0⊕2=2 的值,计算:
- x 0 = a 0 + a 2 x_0 = a_0+a_2 x0=a0+a2
- 线程 1:读取线程
1
⊕
2
=
3
1\oplus2=3
1⊕2=3 的值,计算:
- x 1 = a 1 + a 3 x_1 = a_1+a_3 x1=a1+a3
- 线程 2:读取线程
2
⊕
2
=
0
2\oplus2=0
2⊕2=0 的值,计算:
- x 2 = a 2 + a 0 x_2 = a_2+a_0 x2=a2+a0
- 线程 3:读取线程
3
⊕
2
=
1
3\oplus2=1
3⊕2=1 的值,计算:
- x 3 = a 3 + a 1 x_3 = a_3+a_1 x3=a3+a1
此时,线程 0 和线程 2 均拥有 a 0 + a 2 a_0+a_2 a0+a2;线程 1 和线程 3 均拥有 a 1 + a 3 a_1+a_3 a1+a3
第二轮(设 offset = 1):
- 线程 0:读取线程
0
⊕
1
=
1
0\oplus1=1
0⊕1=1 的值,即
x
1
=
a
1
+
a
3
x_1 = a_1+a_3
x1=a1+a3,更新:
- x 0 = ( a 0 + a 2 ) + ( a 1 + a 3 ) x_0=(a_0+a_2)+(a_1+a_3) x0=(a0+a2)+(a1+a3)
- 线程 1:读取线程
1
⊕
1
=
0
1\oplus1=0
1⊕1=0 的值,即
x
0
=
a
0
+
a
2
x_0 = a_0+a_2
x0=a0+a2,更新:
- x 1 = ( a 1 + a 3 ) + ( a 0 + a 2 ) x_1=(a_1+a_3)+(a_0+a_2) x1=(a1+a3)+(a0+a2)
- 线程 2:读取线程
2
⊕
1
=
3
2\oplus1=3
2⊕1=3 的值,即
x
3
=
a
3
+
a
1
x_3 = a_3+a_1
x3=a3+a1,更新:
- x 2 = ( a 0 + a 2 ) + ( a 3 + a 1 ) x_2=(a_0+a_2)+(a_3+a_1) x2=(a0+a2)+(a3+a1)
- 线程 3:读取线程
3
⊕
1
=
2
3\oplus1=2
3⊕1=2 的值,即
x
2
=
a
2
+
a
0
x2 = a_2+a_0
x2=a2+a0,更新:
- x 3 = ( a 1 + a 3 ) + ( a 2 + a 0 ) x_3 = (a_1+a_3)+(a_2+a_0) x3=(a1+a3)+(a2+a0)
最后,每个线程都得到了 a 0 + a 1 + a 2 + a 3 a_0+a_1+a_2+a_3 a0+a1+a2+a3
6. 跨 Warp 汇总
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
在前面步骤 5 中我们已经对同一 block 内各个线程利用 warp_reduce_sum
完成 warp 内部分归约求和了,但对于一个较大的 block(例如 1024 个线程,即 32 个 warp),我们需要将各个 warp 的部分和进一步合并,得到整个 block 的总和。由于 warp 内的归约只能在单个 warp 内完成,因此需要借助共享内存把每个 warp 的结果收集起来,再通过一次 warp 内的归约计算全局和。
下面是代码逐行解析:
if constexpr (block_size > WARP_SIZE) {
- 判断条件:该条件保证只有当线程块大小大于单个 warp 的大小(通常 32)时才执行以下代码。也就是说,只有当 block 包含多个 warp 时才需要进行跨 warp 的汇总。
static_assert(block_size == 1024, "unexpected block_size");
- 静态断言
__shared__ float s_sum[32];
- 共享内存数组:声明一个大小为 32 的共享内存数组,用来存储每个 warp 的局部归约结果。因为 1024 线程中有 32 个 warp,所以每个 warp 的结果将存放在
s_sum
的一个元素中
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
- 计算 warp 内部索引
warp_id
:将当前线程的索引除以 32 得到该线程所在 warp 的编号。比如,线程编号 45 则属于第 1 个 warp(45/32 = 1)lane_id
:当前线程在所在 warp 内的位置,取余后得到的值在 0~31 之间。这个值用于在 warp 内做归约或决定特定线程执行某些操作。
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
- 每个 warp 写入共享内存:每个 warp 在前面已经通过
warp_reduce_sum
完成了本 warp 的局部归约,归约值取在 warp 内的第 0 号线程(lane_id == 0
)的结果 if (lane_id == 0)
判断表示:每个 warp 仅由其第 0 个线程将本 warp 计算出的部分和tmp
写入共享内存数组对应的位置s_sum[warp_id]
- 这样,32 个 warp 的结果就被收集到了
s_sum[0]
到s_sum[31]
中
__syncthreads();
- 线程同步:使用
__syncthreads()
确保所有线程都完成了对共享内存的写入操作
tmp = s_sum[lane_id];
- 从共享内存读取归约值:每个线程都根据自己的
lane_id
从共享内存中读取一个值 - 每个线程的
lane_id
都在 0~31 内,而共享内存数组s_sum
正好有 32 个元素,对应各个 warp 的结果 - 因此,不同的线程(无论属于哪个 warp)都会根据自己的
lane_id
读取同一位置的数据。例如,所有lane_id == 2
的线程会读取s_sum[2]
tmp = warp_reduce_sum(tmp);
- 跨 Warp 归约:虽然 block 内有 1024 个线程,但经过上一行的赋值,每个线程持有的是来自共享内存、与自己
lane_id
对应的部分结果 - 因为每个 warp 内所有线程的
lane_id
分布从 0 到 31,一共有 32 个唯一值 - 接下来,在每个 warp 内调用
warp_reduce_sum
会对这 32 个值进行归约(求和),最终每个 warp 内所有线程得到相同的全局总和,也就是所有 warp 部分和的和 - 虽然每个 warp 都会独立完成这一步,但实际上只需要其中一个 warp 的结果即可
7. 计算归一化因子并应用归一化
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
- 计算均值:将归约得到的总平方和除以列数
ncols
,得到所有元素平方的均值,即 M e a n ( x 2 ) = 1 N ∑ i = 1 N x i 2 Mean(x^{2}) =\frac{1}{N}\sum_{i=1}^Nx_i^2 Mean(x2)=N1∑i=1Nxi2 - 计算缩放因子
scale
:利用rsqrtf
计算(mean + eps)
的倒数平方根,即 s c a l e = 1 1 N ∑ x i 2 + ϵ scale = \frac{1}{\sqrt{\frac{1}{N}\sum x_i^2+\epsilon}} scale=N1∑xi2+ϵ1 - 归一化每个元素:再次利用循环(与之前遍历列的方式相同),每个线程处理分配的列,将对应的输入值乘以
scale
后写入输出数组dst
值得注意的是 rms_norm_f32
这个 kernel 只是在计算 RMSNorm 的前一部分,之后还需要通过如下 kernel 即 scale_f32
乘上
γ
\gamma
γ,就得到了完整的 RMSNorm
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = scale * x[i];
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}
scale_f32
算子比较简单就是挨个元素乘上对应的
γ
\gamma
γ,其实 rms_norm_f32
和 scale_f32
可以合并成一个 kernel
Note:llama.cpp 是一个大模型推理框架,大家如果想了解更多可以参考:llama.cpp源码解析–CUDA流程版本
总结
这节课我们学习了 RMSNorm 算子的实现,KuiperLLama 中 RMSNorm CPU 后端的实现是通过 Armadillo 库来实现的,CUDA 后端的实现是使用
float4
减少内存访存,并使用 CUB 库的BlockReduce
高效实现求和。llama.cpp 中 RMSNorm CUDA 后端的实现则是通过warp_reduce_sum
函数先在 warp 内归约计算平方和,接着通过 shared memory 和跨 warp 归约得到整个向量的总平方和OK,以上就是第 5 次课程的全部内容了,下节我们将去学习使用 Nsight compute 对 CUDA 算子进行调优 ,敬请期待😄
参考
- https://github.com/zjhellofss/KuiperLLama
- CUDA编程入门极简教程
- https://github.com/ggerganov/llama.cpp
- llama.cpp源码解析–CUDA流程版本