RNN实现阿尔茨海默症的诊断识别
本文为🔗365天深度学习训练营内部文章
原作者:K同学啊
一 导入数据
import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# Use the SimHei font so the Chinese plot titles below render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
# NOTE(review): this silences *every* warning (including deprecation warnings
# from torch/sklearn); consider narrowing the filter.
warnings.filterwarnings('ignore')
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the raw spreadsheet and display it (notebook cell output).
df = pd.read_excel('dia.xls')
df
二 数据处理分析
# Drop the first and the last column of the raw table. Presumably these are an
# ID column and a non-feature column; confirm against the raw file.
n_cols = df.shape[1]
df = df.iloc[:, 1:n_cols - 1]
print(df)
Age Gender Ethnicity EducationLevel BMI Smoking 0 73 0 0 2 22.927749 0 \ 1 89 0 0 0 26.827681 0 2 73 0 3 1 17.795882 0 3 74 1 0 1 33.800817 1 4 89 0 0 0 20.716974 0 ... ... ... ... ... ... ... 2144 61 0 0 1 39.121757 0 2145 75 0 0 2 17.857903 0 2146 77 0 0 1 15.476479 0 2147 78 1 3 1 15.299911 0 2148 72 0 0 2 33.289738 0 AlcoholConsumption PhysicalActivity DietQuality SleepQuality ... 0 13.297218 6.327112 1.347214 9.025679 ... \ 1 4.542524 7.619885 0.518767 7.151293 ... 2 19.555085 7.844988 1.826335 9.673574 ... 3 12.209266 8.428001 7.435604 8.392554 ... 4 18.454356 6.310461 0.795498 5.597238 ... ... ... ... ... ... ... 2144 1.561126 4.049964 6.555306 7.535540 ... 2145 18.767261 1.360667 2.904662 8.555256 ... 2146 4.594670 9.886002 8.120025 5.769464 ... 2147 8.674505 6.354282 1.263427 8.322874 ... 2148 7.890703 6.570993 7.941404 9.878711 ... FunctionalAssessment MemoryComplaints BehavioralProblems ADL 0 6.518877 0 0 1.725883 \ 1 7.118696 0 0 2.592424 2 5.895077 0 0 7.119548 3 8.965106 0 1 6.481226 4 6.045039 0 0 0.014691 ... ... ... ... ... 2144 0.238667 0 0 4.492838 2145 8.687480 0 1 9.204952 2146 1.972137 0 0 5.036334 2147 5.173891 0 0 3.785399 2148 6.307543 0 1 8.327563 Confusion Disorientation PersonalityChanges 0 0 0 0 \ 1 0 0 0 2 0 1 0 3 0 0 0 4 0 0 1 ... ... ... ... 2144 1 0 0 2145 0 0 0 2146 0 0 0 2147 0 0 0 2148 0 1 0 DifficultyCompletingTasks Forgetfulness Diagnosis 0 1 0 0 1 0 1 0 2 1 0 0 3 0 0 0 4 1 0 0 ... ... ... ... 2144 0 0 1 2145 0 0 1 2146 0 0 1 2147 0 1 1 2148 0 1 0 [2149 rows x 33 columns]
三 探索性数据分析
1.得病分布
# Class balance of the prediction target. 'Diagnosis' is the last column, the
# same one used as the label y for training (0 = 未患病, 1 = 已患病).
# The original grouped by 'Diabetes', a feature column, which does not match
# the chart title.
res = df.groupby('Diagnosis')['Age'].count()
print(res)
plt.figure(figsize=(8, 6))
# Pull the first wedge out slightly. The list is sized to the number of classes
# because plt.pie requires len(explode) == number of wedges.
explode = [0.1] + [0] * (len(res) - 1)
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,
        colors=['#ff9999', '#66b3ff', '#99ff99'], explode=explode,
        wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()
2.BMI分布直方图
# Histogram of BMI with a kernel-density overlay.
label_style = {'fontsize': 14, 'color': 'darkgreen'}
bmi_grid = sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', **label_style)
plt.ylabel('Frequency', **label_style)
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()
3.年龄分布直方图
# Histogram of patient age with a kernel-density overlay.
age_values = df['Age']
sns.displot(age_values, kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
axis_label_style = dict(fontsize=14, color='darkgreen')
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', **axis_label_style)
plt.ylabel('Frequency', **axis_label_style)
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()
四 构建划分数据集
# Features = every column except the last; label = the last column (Diagnosis).
X = df.iloc[:, :-1]
y = df.iloc[:, -1]
# Split the data BEFORE scaling. Fitting the scaler on the full dataset would
# leak the test rows' mean/std into the training preprocessing and inflate the
# test scores. The split itself (same rows, same random_state) is unchanged.
X_train, X_test, y_train, y_test = train_test_split(
    np.array(X), np.array(y), test_size=0.1, random_state=1)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)  # statistics come from the training rows only
X_test = sc.transform(X_test)        # reuse the training statistics
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.int64)
y_test = torch.tensor(y_test, dtype=torch.int64)
# Build the data loaders. Training batches are reshuffled every epoch;
# evaluation order does not matter, so the test loader is not shuffled.
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=64, shuffle=False)
五 训练模型
1.构建模型
# 构建模型
class model_rnn(nn.Module):
    """Classifier: one single-layer RNN followed by two linear layers.

    Args:
        input_size: features per time step (default 32, the number of feature
            columns in this dataset).
        hidden_size: RNN hidden width (default 200).
        num_classes: number of output logits (default 2).

    NOTE(review): there is no nonlinearity between fc0 and fc1, so together
    they compose to a single affine map. Consider F.relu between them. This is
    left unchanged to keep the published architecture.
    """

    def __init__(self, input_size=32, hidden_size=200, num_classes=2):
        super(model_rnn, self).__init__()
        self.rnn0 = nn.RNN(input_size=input_size, hidden_size=hidden_size,
                           num_layers=1, batch_first=True)
        self.fc0 = nn.Linear(hidden_size, 50)
        self.fc1 = nn.Linear(50, num_classes)

    def forward(self, x):
        """Return logits.

        A 2-D input (batch, features) gives output (batch, num_classes).
        A 3-D input (batch, seq, features) gives output (batch, seq, num_classes).
        """
        # nn.RNN treats a 2-D tensor as ONE unbatched sequence of length `batch`.
        # The hidden state would then carry over from row to row, so predictions
        # would depend on the other samples in the batch. Give each row its own
        # length-1 sequence instead.
        add_time_axis = x.dim() == 2
        if add_time_axis:
            x = x.unsqueeze(1)          # (batch, 1, features)
        out, _ = self.rnn0(x)           # (batch, seq, hidden)
        out = self.fc0(out)
        out = self.fc1(out)
        if add_time_axis:
            out = out.squeeze(1)        # back to (batch, num_classes)
        return out
# Build the network, move it to the selected device and print its layer summary.
model = model_rnn()
model = model.to(device)  # nn.Module.to returns the module itself
print(model)
model_rnn( (rnn0): RNN(32, 200, batch_first=True) (fc0): Linear(in_features=200, out_features=50, bias=True) (fc1): Linear(in_features=50, out_features=2, bias=True) )
2.训练函数
# Training loop
def train(dataloader,model,loss_fn,optimizer):
    """Run one training epoch and return (accuracy, mean batch loss).

    Each batch is moved to the module-level ``device``, and ``optimizer`` takes
    one step per batch. Accuracy is counted over every sample in
    ``dataloader.dataset``. Loss is averaged over the number of batches.
    The caller is responsible for putting ``model`` in training mode.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)  # ceil(dataset size / batch size)
    correct, loss_sum = 0, 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)
        # Backpropagate and update the parameters.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Accumulate statistics from this batch's (pre-update) predictions.
        correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
        loss_sum += batch_loss.item()
    return correct / n_samples, loss_sum / n_batches
3.测试函数
# Evaluation loop
def valid(dataloader,model,loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return (accuracy, mean batch loss).

    Gradient tracking is disabled because no parameters are updated.
    Accuracy is counted over every sample in the evaluated dataset. Loss is
    averaged over the number of batches.
    """
    n_samples = len(dataloader.dataset)  # size of the evaluated dataset
    n_batches = len(dataloader)
    with torch.no_grad():
        loss_sum, correct = 0, 0
        for features, labels in dataloader:
            features, labels = features.to(device), labels.to(device)
            logits = model(features)
            loss_sum += loss_fn(logits, labels).item()
            correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
    return correct / n_samples, loss_sum / n_batches
4.正式训练
# Loss, optimizer and schedule for the training run.
loss_fn = nn.CrossEntropyLoss()
learn_rate = 1e-4
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30
# Per-epoch history, read later by the plotting cell.
train_loss, train_acc, test_loss, test_acc = [], [], [], []
template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    model.eval()
    epoch_test_acc, epoch_test_loss = valid(test_dl, model, loss_fn)
    for history, value in ((train_acc, epoch_train_acc),
                           (train_loss, epoch_train_loss),
                           (test_acc, epoch_test_acc),
                           (test_loss, epoch_test_loss)):
        history.append(value)
    # Learning rate currently in effect (constant here, since no scheduler is used).
    lr = opt.param_groups[0]['lr']
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))
print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04 Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04 Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04 Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04 Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04 Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04 Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04 Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04 Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04 Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04 Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04 Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04 Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04 Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04 Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04 Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04 Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04 Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04 Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04 Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04 Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04 Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04 Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04 
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04 Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04 Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04 Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04 Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04 Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04 Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04 ==================== Done ====================
六 结果可视化
1.Loss和Acc图
# Accuracy and loss curves for training vs. validation.
# The x-range is derived from the recorded history instead of a hard-coded 30,
# so the plot stays valid when the epoch count changes.
epochs_range = range(len(train_acc))
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()
2.调用模型预测
# Predict the class of the first test sample. reshape(1, -1) keeps a batch
# axis of size 1.
test_X = X_test[0].reshape(1,-1)
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0 ======================================== 0:未患病 1:已患病
3.绘制混淆矩阵
# Confusion matrix of the model's predictions on the held-out test set.
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    pred = model(X_test.to(device)).argmax(1).cpu().numpy()
print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
# Rows are true labels and columns are predicted labels.
cm = confusion_matrix(y_test.cpu().numpy(), pred)
plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为============== X_test.shape: torch.Size([215, 32]) y_test.shape: torch.Size([215]) ==========输出数据shape为============== pred.shape: (215,)