Using Statistical Methods to Reliably Compare Algorithm Performance in Large Generative AI Models with the JAX Profiler on AMD GPUs
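This blog provides a detailed guide to measuring and comparing the performance of different algorithms in generative AI models implemented in JAX. Using the JAX Profiler and statistical analysis, it shows how to reliably evaluate key steps and compare algorithm performance on AMD GPUs.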
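JAX is an open-source numerical computing library from Google (though not an official Google product). It is attracting wide attention in generative AI because it can leverage hardware accelerators and supports automatic differentiation. JAX was originally developed for high-performance machine learning research. Its functional programming approach and its support for GPUs and TPUs have made it a preferred choice for building and deploying large language models (LLMs) and other cutting-edge generative AI applications. Notably, companies such as X.AI have used JAX to develop open-source models like Grok-1, which has further boosted the library's popularity in generative AI. JAX offers performance, flexibility and a good fit for developing and deploying advanced AI models, and its popularity continues to grow.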
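The ROCm blog series has previously explored various profiling tools that can be used to analyze model performance on AMD GPUs. It has also covered framework-specific profilers for TensorFlow and PyTorch.

The official JAX documentation covers the basic usage of its profiler, but this tutorial goes further into more advanced techniques. For example, it explains how to decide whether one algorithm is significantly better than another when the measurements contain substantial random noise. Using statistical analysis and hypothesis testing, this blog shows how to reliably measure and compare the performance of different algorithms that perform the same step in a large language model.

Specifically, it compares two ways of implementing the two matrix multiplication steps in the `CausalSelfAttention` component of a JAX-based Generative Pre-trained Transformer (GPT) model: `einsum` versus `matmul`. (See the blog post on implementing a GPT model in JAX.) To learn more about `einsum`, visit this blog post.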
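The two `einsum` subscript strings used in `CausalSelfAttention` compute the same contractions as the corresponding `matmul` calls.

The sketch below checks this on random tensors. It uses NumPy as a stand-in for `jax.numpy`, whose `einsum` and `matmul` follow the NumPy signatures. The tensor shapes are arbitrary values chosen only for illustration.

```python
import numpy as np

# Arbitrary illustrative shapes: batch B, heads nh, sequence length T, head size hs
B, nh, T, hs = 2, 4, 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((B, nh, T, hs))
k = rng.standard_normal((B, nh, T, hs))
v = rng.standard_normal((B, nh, T, hs))

# Step 1 (attn_q_k): contract over the head dimension s -> (B, nh, T, T)
att_einsum = np.einsum('bhts,bhqs->bhtq', q, k, optimize=True)
att_matmul = np.matmul(q, k.swapaxes(-2, -1))

# Step 2 (attn_att_v): contract over the key position s -> (B, nh, T, hs)
y_einsum = np.einsum('bhts,bhsq->bhtq', att_matmul, v, optimize=True)
y_matmul = np.matmul(att_matmul, v)

print(att_einsum.shape, y_einsum.shape)
print(np.allclose(att_einsum, att_matmul), np.allclose(y_einsum, y_matmul))
```

Both forms are mathematically identical. Any timing difference between them must therefore come from how each expression is compiled and executed, not from the math itself.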
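To run this code example, first set up the ROCm environment and install the necessary packages and Python scripts. The code is platform-agnostic. As long as the accelerated computing platform and the Python packages are configured correctly, it works on AMD GPUs as well as on other GPUs or TPUs.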
1. Pull and run the Docker container with the following command in a Linux shell:
docker run -it --ipc=host --network=host --device=/dev/kfd --device=/dev/dri \
--group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--name=nanogpt rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2 /bin/bash
2. Inside the Docker container, run the following commands to install the necessary Python packages and configure the XLA environment variable:
python3 -m pip install --upgrade pip
pip install optax==0.2.2 flax==0.8.2 transformers==4.38.2 tiktoken==0.6.0 datasets==2.17.1 perfetto==0.7.0 matplotlib==3.8.4 scipy==1.13.0
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/releases/download/jaxlib-v0.4.26/jaxlib-0.4.26+rocm610-cp310-cp310-manylinux2014_x86_64.whl
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/archive/refs/tags/jaxlib-v0.4.26.tar.gz
pip install numpy==1.22.0
export XLA_FLAGS="--xla_gpu_autotune_level=0"
Download the files used for this blog from the ROCm/rocm-blogs GitHub repository with the following commands:
git clone https://github.com/ROCm/rocm-blogs.git
cd rocm-blogs/blogs/artificial-intelligence/nanoGPT-JAX
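Replace the `model.py` and `sample.py` scripts in this folder with the versions for the current blog on GitHub.

In `model.py`, each of the two matrix multiplications in `CausalSelfAttention` is wrapped in a `jax.named_scope` so that its operations can be located in the profiling trace. In the listing, lines marked `+` were added and lines marked `-` were removed: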
class CausalSelfAttention(nn.Module):
    config: GPTConfig

    @nn.compact
    def __call__(self, x, train=False, rng1=None, rng2=None):
        assert self.config.n_embd % self.config.n_head == 0
        B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = jnp.split(nn.Dense(self.config.n_embd * 3, name="c_attn")(x), 3, axis=-1)
        k = k.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        q = q.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        v = v.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
+       with jax.named_scope("attn_q_k"):
+           att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
-       att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
        mask = jnp.tril(jnp.ones((T, T))).reshape((1, 1, T, T))
        att = jnp.where(mask == 0, float('-inf'), att)
        att = nn.softmax(att, axis=-1)
        att = nn.Dropout(self.config.dropout, name='attn_dropout', deterministic=not train)(att, rng=rng1)
+       with jax.named_scope("attn_att_v"):
+           y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v) # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-       y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v) # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.swapaxes(1, 2).reshape(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = nn.Dense(self.config.n_embd, name='c_proj')(y)
        y = nn.Dropout(self.config.dropout, name='resid_dropout', deterministic=not train)(y, rng=rng2)
        return y
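In `sample.py`, the generation of each sample is wrapped in its own trace using `jax.profiler.start_trace` and `jax.profiler.stop_trace`. Each trace is written to a separate directory named after `profile_dir` and the sample index:
for i in range(num_samples):
+   jax.profiler.start_trace(profile_dir+f'_{i}')
    output = generate([jnp.array(start_ids)], seed+i)
+   jax.profiler.stop_trace()
    print(f'\nGenerated output __{i}__: \n__________________________________\n{decode(output[0].tolist())}\n__________________________________')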
Next, generate 10 samples with each algorithm, saving one trace per sample:
# Generate profiling output using matmul
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_matmul"
# Generate profiling output using einsum
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_einsum" --override_args="{'use_einsum':True}"
The profiler writes gzip-compressed JSON trace files. Decompress them before the analysis:
for i in {0..9}; do
    gzip -d trace_file_einsum_$i/plugins/profile/202*/*.json.gz
    gzip -d trace_file_matmul_$i/plugins/profile/202*/*.json.gz
done
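If you prefer to stay in Python, the same decompression can be sketched with the standard library.

The helper name `decompress_traces` is hypothetical. The glob pattern assumes the `<profile_dir>_<i>/plugins/profile/<timestamp>/*.json.gz` layout used by the shell loop.

```python
import glob
import gzip
import shutil

def decompress_traces(pattern):
    """Decompress every *.json.gz file matching `pattern` (hypothetical helper).

    Each .json file is written next to its .json.gz source.
    Returns the list of decompressed file paths.
    """
    outputs = []
    for gz_path in sorted(glob.glob(pattern)):
        json_path = gz_path[:-len('.gz')]  # strip the .gz suffix
        with gzip.open(gz_path, 'rb') as src, open(json_path, 'wb') as dst:
            shutil.copyfileobj(src, dst)
        outputs.append(json_path)
    return outputs

if __name__ == '__main__':
    # Same directories as the shell loop: trace_file_{einsum,matmul}_0 .. _9
    for prefix in ('trace_file_einsum', 'trace_file_matmul'):
        for i in range(10):
            decompress_traces(f'{prefix}_{i}/plugins/profile/202*/*.json.gz')
```

Unlike `gzip -d`, this sketch keeps the original `.json.gz` files. The later `*.json` globs still match only the decompressed files.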
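You can now read the profiling data and perform the statistical analysis.

For each iteration (corresponding to one sample generated with each algorithm), the program compares the two algorithms' distributions of matrix multiplication execution times, measured in nanoseconds. Shorter execution times indicate better performance.

Two tools are used for the comparison:

- Box plots give a visual check of the difference.
- The Wilcoxon rank-sum test (equivalent to the Mann-Whitney U test) determines whether the location of the two distributions (for example, the median) differs significantly.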
import glob
from perfetto.trace_processor import TraceProcessor
from scipy.stats import ranksums
import matplotlib.pyplot as plt

def plot_boxplot(df1, df2, columns1, columns2=None, df1_lab='matmul', df2_lab='einsum'):
    """
    Plot boxplots for specified columns in two DataFrames. This function will
    be used to compare the distribution of running time for the two algorithms
    we profiled.

    Parameters:
    df1 (pandas.DataFrame): First DataFrame.
    df2 (pandas.DataFrame): Second DataFrame.
    columns1 (list): List of column names from the first DataFrame to plot.
    columns2 (list): List of column names from the second DataFrame to plot.
        Defaults to columns1.
    df1_lab (string): Label for df1 in the plot.
    df2_lab (string): Label for df2 in the plot.
    """
    if columns2 is None:
        columns2 = columns1
    # Combine data from both DataFrames
    data = [df1[col] for col in columns1] + [df2[col] for col in columns2]
    # Create labels for boxplots
    labels = [df1_lab + '_' + col for col in columns1] + [df2_lab + '_' + col for col in columns2]
    # Plot boxplots
    plt.figure(figsize=(10, 6))
    plt.boxplot(data, labels=labels)
    plt.ylabel('Time in nanoseconds')
    plt.title('Performance comparison on the scale of nanoseconds')
    plt.show()
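To make the rank-sum test less of a black box, here is a minimal pure-Python sketch of the statistic that `scipy.stats.ranksums` reports. It works in three steps:

1. Pool both samples.
2. Give tied values the average of their ranks.
3. Standardize the rank sum of the first sample using the normal approximation, without a tie correction.

The function names `average_ranks` and `rank_sum_test` are hypothetical. For real analyses, use `scipy.stats.ranksums` as this tutorial does.

```python
import math

def average_ranks(values):
    """Return 1-based ranks, giving tied values the average of their positions."""
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with the value at position i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # average of the 1-based positions i+1 .. j+1
        for pos in range(i, j + 1):
            ranks[order[pos]] = avg
        i = j + 1
    return ranks

def rank_sum_test(x, y):
    """Two-sided Wilcoxon rank-sum test using the normal approximation.

    A positive statistic means values in x tend to be larger than values in y.
    """
    n1, n2 = len(x), len(y)
    ranks = average_ranks(list(x) + list(y))
    s = sum(ranks[:n1])                             # rank sum of the first sample
    expected = n1 * (n1 + n2 + 1) / 2               # mean of s under H0
    std = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)   # std of s under H0 (no tie correction)
    z = (s - expected) / std
    p = math.erfc(abs(z) / math.sqrt(2))            # equals 2 * (1 - Phi(|z|))
    return z, p

z, p = rank_sum_test([4, 5, 6], [1, 2, 3])
print(f'statistic={z:.4f}, p={p:.4f}')
```

In the usage example, every value in the first sample outranks the second, so the statistic is positive (about 1.96).

With 600 durations per sample, as in this tutorial, the normal approximation is accurate. The sign of the statistic tells you which algorithm tends to be slower.

Because `dur` values are integer nanoseconds and may tie, `scipy.stats.mannwhitneyu` may be worth comparing against. It applies a tie correction.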
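The program then compares the execution times of the two algorithms in each sample-generation iteration. The SQL query uses `where display_value like "%attn_q_k%"` to select the operations enclosed in the first named scope, `attn_q_k`: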
# Start at i=1: sample 0 presumably includes one-time warm-up costs such as JIT compilation
for i in range(1, 10):
    # Process the profiling data for matmul (one decompressed trace file per directory)
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json')[0])
    # SQL query to get the operations enclosed by the named_scope
    query_text = '''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value, dur
    FROM _slice_with_thread_and_process_info
    INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_q_k%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    tp.close()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json')[0])
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    tp.close()
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')
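Output from two of the iterations:
Matmul: Mean=6461.875, std. dev.=504.8818364954699, shape of df:(600, 3)
Einsum: Mean=5813.346666666666, std. dev.=455.80420754410954, shape of df:(600, 3)
Test statistic=20.22982266255362, p_val=5.349499343834845e-91
Matmul: Mean=6293.076666666667, std. dev.=514.1309448993132, shape of df:(600, 3)
Einsum: Mean=5797.615, std. dev.=397.86885546863283, shape of df:(600, 3)
Test statistic=16.932946075063718, p_val=2.5717953759559878e-64

In both iterations, the test statistic is positive, which means the `matmul` durations tend to be larger. `einsum` also has the lower mean duration, and the p-value is far below 0.05. For the first matrix multiplication (`attn_q_k`), `einsum` is therefore significantly faster than `matmul`.

Each DataFrame has 600 rows. This is consistent with 50 generated tokens times the 12 attention layers of GPT-2.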
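If Perfetto is unavailable, you can roughly approximate the SQL query with the standard library by scanning the decompressed trace JSON directly. This sketch rests on two assumptions:

- The trace follows the Chrome Trace Event Format: a top-level `traceEvents` list whose complete events have `"ph": "X"` and a `dur` in microseconds.
- Following the query's `args.name` filter, each GPU operation's `args.name` contains the `named_scope` path.

The helper name `scope_durations_ns` is hypothetical. Unlike Perfetto, it does not validate or merge events.

```python
import json

def scope_durations_ns(trace, scope):
    """Return durations (ns) of complete events whose args.name contains `scope`.

    `trace` is a parsed Chrome-trace-format dict. `dur` is assumed to be in
    microseconds, as in that format, and is converted to nanoseconds to match
    the values Perfetto reports.
    """
    durations = []
    for event in trace.get('traceEvents', []):
        name = event.get('args', {}).get('name', '')
        if event.get('ph') == 'X' and scope in name:
            durations.append(round(event['dur'] * 1000))
    return durations

# Tiny synthetic trace for illustration; real traces come from the
# decompressed files under trace_file_*/plugins/profile/.
example = json.loads('''{"traceEvents": [
  {"ph": "X", "name": "fusion", "dur": 6.4, "args": {"name": "jit(f)/attn_q_k/dot_general"}},
  {"ph": "X", "name": "fusion", "dur": 5.2, "args": {"name": "jit(f)/attn_att_v/dot_general"}},
  {"ph": "M", "name": "process_name", "args": {"name": "attn_q_k metadata"}}
]}''')
print(scope_durations_ns(example, 'attn_q_k'))
```

The metadata event (`"ph": "M"`) is skipped even though its name mentions the scope. Only timed operations contribute durations.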
Next, the same analysis is repeated for the second matrix multiplication, selecting the operations in the named scope `attn_att_v`:
# Skip sample 0 here as well
for i in range(1, 10):
    # Process the profiling data for matmul (one decompressed trace file per directory)
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json')[0])
    # SQL query to get the operations enclosed by the named_scope
    query_text = '''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value, dur
    FROM _slice_with_thread_and_process_info
    INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_att_v%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    tp.close()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json')[0])
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    tp.close()
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')
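Output from two of the iterations:
Matmul: Mean=5204.543333333333, std. dev.=882.6151202759834, shape of df:(600, 3)
Einsum: Mean=6360.556666666666, std. dev.=373.461514250933, shape of df:(600, 3)
Test statistic=-21.986424230986046, p_val=3.884153635651101e-107
Matmul: Mean=5145.61, std. dev.=876.5247080600369, shape of df:(600, 3)
Einsum: Mean=6396.01, std. dev.=381.7892458942073, shape of df:(600, 3)
Test statistic=-22.450480914300588, p_val=1.2659476932444539e-111

Here the result is reversed. The test statistic is negative and the p-value is again far below 0.05, so for the second matrix multiplication (`attn_att_v`), `matmul` is significantly faster than `einsum`.

The `matmul` durations also have a noticeably larger standard deviation, about 880 ns compared with about 380 ns for `einsum`.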
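With 600 measurements per sample, even small differences produce tiny p-values. It is therefore worth checking the size of the difference as well.

As a simple complementary check, the sketch below computes the relative gap between the mean durations reported above. It uses the first output block of each scope. The helper `relative_gap` is hypothetical.

```python
def relative_gap(slower_mean, faster_mean):
    """Fraction of the slower algorithm's mean time saved by the faster one."""
    return (slower_mean - faster_mean) / slower_mean

# Mean durations (ns) copied from the first output block for each named scope
gaps = {
    'attn_q_k (einsum faster)': relative_gap(6461.875, 5813.346666666666),
    'attn_att_v (matmul faster)': relative_gap(6360.556666666666, 5204.543333333333),
}
for scope, gap in gaps.items():
    print(f'{scope}: {gap:.1%}')
```

In these runs, `einsum` cuts the mean time of the first multiplication by roughly 10%, and `matmul` cuts that of the second by roughly 18%. The statistically significant differences are therefore also sizable at the kernel level.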