当前位置: 首页 > article >正文

【Andrej Karpathy 神经网络从Zero到Hero】--2.语言模型的两种实现方式 (Bigram 和 神经网络)

目录

  • 统计 Bigram 语言模型
    • 质量评价方法
  • 神经网络语言模型

【系列笔记】
【Andrej Karpathy 神经网络从Zero到Hero】–1. 自动微分autograd实践要点

本文主要参考 大神Andrej Karpathy 大模型讲座 | 构建makemore 系列之一:讲解语言建模的明确入门,演示

  1. 如何利用统计数值构建一个简单的 Bigram 语言模型
  2. 如何用一个神经网络来复现前面 Bigram 语言模型的结果,以此来展示神经网络相对于传统 n-gram 模型的拓展性。

统计 Bigram 语言模型

首先给定一批数据,每个数据是一个英文名字,例如:

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

Bigram语言模型的做法很简单,首先将数据中的英文名字都做成一个个bigram的数据

其中每个格子中是对应的二元组,eg: “rh” ,在所有数据中出现的次数。那么一个自然的想法是对于给定的字母,取其对应的行,将次数归一化转成概率值,然后根据概率分布抽取下一个可能的字母:

g = torch.Generator().manual_seed(2147483647)
P = N.float() # N 即为上述 counts 矩阵
P = P / P.sum(1, keepdims=True) # P是每行归一化后的概率值

for i in range(5):
  
  out = []
  ix = 0  ## start符和end符都用 id=0 表示,这里是start
  while True:
    p = P[ix] # 当前字符为 ix 时,预测下一个字符的概率分布,实质是一个多项分布(即可能抽到的值有多个,eg: 掷色子是六项分布)
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
    if ix == 0: ## 当运行到end符,停止生成
      break
  print(''.join(out))

输出类似于:

mor.
axx.
minaymoryles.
kondlaisah.
anchshizarie.

质量评价方法

我们还需要方法来评估语言模型的质量,一个直观的想法是:
P ( s 1 s 2 . . . s n ) = P ( s 1 ) P ( s 2 ∣ s 1 ) ⋯ P ( s n ∣ s n − 1 ) P(s_1s_2...s_n) = P(s_1)P(s_2|s_1)\cdots P(s_n|s_{n-1}) P(s1s2...sn)=P(s1)P(s2s1)P(snsn1)
但上述计算方式有一个问题,概率值都是小于1的,当序列的长度比较长时,上述数值会趋于0,计算时容易下溢。因此实践中往往使用 l o g ( P ) log(P) log(P)来代替,为了可以对比不同长度的序列的预测效果,再进一步使用 l o g ( P ) / n log(P)/n log(P)/n 表示一个序列平均的质量

上述统计 Bigram 模型在训练数据上的平均质量为:

log_likelihood = 0.0
n = 0

for w in words: # 所有word里的二元组概率叠加
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1 # 所有word里的二元组数量之和

nll = -log_likelihood
print(f'{nll/n}') ## 值为 2.4764,表示前面做的bigram模型,对现有训练数据的置信度
                  ## 这个值越低表示当前模型越认可训练数据的质量,而由于训练数据是我们认为“好”的数据,因此反过来就说明这个模型好

但这里有一个问题是,例如:

log_likelihood = 0.0
n = 0

#for w in words:
for w in ["andrejz"]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1
    print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')

输出是

.a: 0.1377 -1.9829
an: 0.1605 -1.8296
nd: 0.0384 -3.2594
dr: 0.0771 -2.5620
re: 0.1336 -2.0127
ej: 0.0027 -5.9171
jz: 0.0000 -inf
z.: 0.0667 -2.7072
log_likelihood=tensor(-inf)
nll=tensor(inf)
inf

可以发现由于,jz 在计数矩阵 N 中为0,即数据中没有出现过,导致 log(loss) 变成了负无穷,这里为了避免这样的情况,需要做 平滑处理,即 P = N.float() 改成 P = (N+1).float(),这样上述代码输出变成:

.a: 0.1376 -1.9835
an: 0.1604 -1.8302
nd: 0.0384 -3.2594
dr: 0.0770 -2.5646
re: 0.1334 -2.0143
ej: 0.0027 -5.9004
jz: 0.0003 -7.9817
z.: 0.0664 -2.7122
log_likelihood=tensor(-28.2463)
nll=tensor(28.2463)
3.5307815074920654

避免了出现 inf 这种数据溢出问题。


神经网络语言模型

接下来尝试用神经网络的方式构建上述bigram语言模型:

# 构建训练数据
xs, ys = [], [] # 分别是前一个字符和要预测的下一个字符的id
for w in words[:5]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    print(ch1, ch2)
    xs.append(ix1)
    ys.append(ix2)    
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)
# 输出示例:. e
#          e m
#          m m
#          m a
#          a .
#       xs: tensor([ 0,  5, 13, 13,  1])
#       ys: tensor([ 5, 13, 13,  1,  0])

# 随机初始化一个 27*27 的参数矩阵
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True) # 基于正态分布随机初始化
# 前向传播
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=27).float() # 将输入数据xs做成one-hot embedding
logits = xenc @ W # 用于模拟统计模型中的统计数值矩阵,由于 W 是基于正态分布采样,logits 并非直接是计数值,可以认为是 log(counts)
## tensor([[-0.5288, -0.5967, -0.7431,  ...,  0.5990, -1.5881,  1.1731],
##        [-0.3065, -0.1569, -0.8672,  ...,  0.0821,  0.0672, -0.3943],
##        [ 0.4942,  1.5439, -0.2300,  ..., -2.0636, -0.8923, -1.6962],
##        ...,
##        [-0.1936, -0.2342,  0.5450,  ..., -0.0578,  0.7762,  1.9665],
##        [-0.4965, -1.5579,  2.6435,  ...,  0.9274,  0.3591, -0.3198],
##        [ 1.5803, -1.1465, -1.2724,  ...,  0.8207,  0.0131,  0.4530]])
counts = logits.exp() # 将 log(counts) 还原成可以看作是 counts 的矩阵
## tensor([[ 0.5893,  0.5507,  0.4756,  ...,  1.8203,  0.2043,  3.2321],
##        [ 0.7360,  0.8548,  0.4201,  ...,  1.0856,  1.0695,  0.6741],
##        [ 1.6391,  4.6828,  0.7945,  ...,  0.1270,  0.4097,  0.1834],
##        ...,
##        [ 0.8240,  0.7912,  1.7245,  ...,  0.9438,  2.1732,  7.1459],
##        [ 0.6086,  0.2106, 14.0621,  ...,  2.5279,  1.4320,  0.7263],
##        [ 4.8566,  0.3177,  0.2802,  ...,  2.2722,  1.0132,  1.5730]])
probs = counts / counts.sum(1, keepdims=True) # 用于模拟统计模型中的概率矩阵,这其实即是 softmax 的实现
loss = -probs[torch.arange(5), ys].log().mean() # loss = log(P)/n, 这其实即是 cross-entropy 的实现

接下来可以通过loss.backward()来更新参数 W:

for k in range(100):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=27).float() 
  logits = xenc @ W # predict log-counts
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True) 
  loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() ## 这里加上了L2正则,防止过拟合
  print(loss.item())
  
  # backward pass
  W.grad = None # 每次反向传播前置为None
  loss.backward()
  
  # update
  W.data += -50 * W.grad  

注意这里 logits = xenc @ W 由于 xenc 是 one-hot 向量,因此这里 logits 相当于是抽出了 W 中的某一行,而结合 bigram 模型中,loss 实际上是在计算实际的 log(P[x_i, y_i]),那么可以认为这里 W 其实是在拟合 bigram 中的计数矩阵 N(不过实际是 logW 在拟合 N)

另外上述神经网络的 loss 最终也是达到差不多 2.47 的最低 loss。这是合理的,因为从上面的分析可知,这个神经网络是完全在拟合 bigram 计数矩阵的,没有使用更复杂的特征提取方法,因此效果最终也会差不多。

这里 loss 中还加了一个 L2 正则,主要目的是压缩 W,使得它向全 0 靠近,这里的效果非常类似于 bigram 中的平滑手段,想象给一个极大的平滑:P = (N+10000).float()`,那么 P 会趋于一个均匀分布,而 W 全为 0 会导致 counts = logits.exp() 全为 1,即也在拟合一个均匀分布。这里前面的参数 0.01 即是用来调整平滑强度的,如果这个给的太大,那么平滑太大了,就会学成一个均匀分布(当然实际不会希望这样,所以不会给很大)


http://www.kler.cn/a/578442.html

相关文章:

  • MATLAB控制函数测试要点剖析
  • P8924 「GMOI R1-T1」Perfect Math Class 题解
  • STM32 内置的通讯协议
  • 在ubuntu20.4中如何创建一个虚拟环境(亲测有效)
  • 代码随想录-基础篇
  • .CSV file input into contact of outlook with gibberish. .csv文件导入outlook, 出现乱码
  • docker本地部署RagFlow
  • 0087.springboot325基于Java的企业OA管理系统的设计与实现+论文
  • Linux内核学习(一)——Vmware虚拟机安装Ubuntu20.4系统及QEMU模拟ARM64 Linux
  • 【星云 Orbit•STM32F4】13. 探索定时器:基本定时器
  • 生命周期总结(uni-app、vue2、vue3生命周期讲解)
  • 蓝桥杯 - 简单 - 工作协调
  • 如何在 Conda 环境中使用 PySide6 将 .ui 文件转换为 .py 文件
  • 【技海登峰】Kafka漫谈系列(六)Java客户端之消费者Consumer核心概念与客户端配置详解
  • 【文心索引】搜索引擎测试报告
  • Synchronized 锁升级过程
  • Python asyncIO 面试题及参考答案 草
  • linux---天气爬虫
  • 从0开始完成基于异步服务器的boost搜索引擎
  • Qt的QGraphics View的使用