Hand-written SVM (primal form)
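The class below minimizes the standard soft-margin SVM primal objective with mini-batch subgradient descent. Stated here for reference (textbook formulation, not part of the original notes):

\min_{w,\,b}\;\; \frac{1}{2}\lVert w\rVert^{2} \;+\; C \sum_{i=1}^{n} \max\bigl(0,\; 1 - y_i\,(w^{\top}x_i + b)\bigr)

For a sample that violates the margin, i.e. y_i (w^T x_i + b) < 1, the hinge term contributes -y_i x_i to the gradient in w and -y_i to the gradient in b; the regularizer contributes w. com_gradient below accumulates these contributions over a mini-batch and scales the hinge part by C/n.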
svm.py
import numpy as np

class SVM:
    def __init__(self, C=1.0, lr=0.01, batch_size=32, epochs=100):
        self.C = C
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.w = None
        self.b = 0.0
        self.epoch = 0
    # Train with mini-batch SGD, tracking the best validation score and the corresponding w, b
    def fit(self, X, y, X_val=None, y_val=None):
        sample, feature = X.shape
        self.w = np.zeros(feature)
        self.b = 0.0
        best_score = -np.inf
        # best_w = self.w would be wrong: it aliases self.w, so take a copy
        best_w = self.w.copy()
        best_b = self.b
        best_epoch = 0
        for epoch in range(self.epochs):
            # Shuffle the sample order each epoch
            shu_index = np.random.permutation(sample)
            shu_X = X[shu_index]
            shu_y = y[shu_index]
            for i in range(0, sample, self.batch_size):
                end = i + self.batch_size
                # Current mini-batch
                x_batch = shu_X[i:end]
                y_batch = shu_y[i:end]
                dw, db = self.com_gradient(x_batch, y_batch)
                self.w -= self.lr * dw
                self.b -= self.lr * db
            if X_val is not None and y_val is not None:
                y_pred = self.predict(X_val)
                # np.mean(x, y) would be wrong; compare element-wise, then take the mean
                score = np.mean(y_pred == y_val)
                if score > best_score:
                    best_score = score
                    best_w = self.w.copy()
                    best_b = self.b
                    # best_epoch = self.epoch would be wrong; record the current epoch
                    best_epoch = epoch
                print(f"Epoch {epoch+1}: validation accuracy {score:.4f}")
        if X_val is not None and y_val is not None:
            self.w = best_w
            self.b = best_b
            self.epoch = best_epoch
    def com_gradient(self, X_batch, y_batch):
        n = X_batch.shape[0]
        dw_hinge = np.zeros_like(self.w)
        db_hinge = 0.0
        for i in range(n):
            xi = X_batch[i]
            yi = y_batch[i]
            # margin = y_i * (w·x_i + b); note it is xi here, and the bias belongs inside the parentheses
            margin = yi * (np.dot(xi, self.w) + self.b)
            if margin < 1:
                dw_hinge += -yi * xi
                db_hinge += -yi
        # Note: dw is formed only after dw_hinge has been accumulated over all n samples
        dw = self.w + (self.C / n) * dw_hinge
        db = (self.C / n) * db_hinge
        return dw, db
    def predict(self, X):
        linear = np.dot(X, self.w) + self.b
        return np.sign(linear)
    def evaluate(self, X, y):
        y_true = y
        y_pre = self.predict(X)
        # Note: the labels are -1 and 1, not 0 and 1
        tp = np.sum((y_pre == 1) & (y_true == 1))
        fp = np.sum((y_pre == 1) & (y_true == -1))
        tn = np.sum((y_pre == -1) & (y_true == -1))
        fn = np.sum((y_pre == -1) & (y_true == 1))
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        precision = tp / (tp + fp) if tp + fp != 0 else 0
        recall = tp / (tp + fn) if tp + fn != 0 else 0
        f1 = (2 * precision * recall) / (precision + recall) if precision + recall != 0 else 0
        # Return the metrics as a dict of key: value pairs
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }
    def save_weight(self, filename):
        # Note: w and b (plus the hyperparameters) must be saved into the file
        np.savez(filename, w=self.w, b=self.b, epoch=self.epoch, C=self.C,
                 lr=self.lr, batch_size=self.batch_size, epochs=self.epochs)

    @classmethod
    def load_weight(cls, filename):
        data = np.load(filename)
        svm = cls(C=float(data['C']), lr=float(data['lr']),
                  batch_size=int(data['batch_size']), epochs=int(data['epochs']))
        svm.w = data['w']
        svm.b = float(data['b'])
        svm.epoch = int(data['epoch'])
        return svm
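The per-sample loop in com_gradient is easy to read but slow on large batches. A vectorized equivalent is sketched below; com_gradient_vec is a hypothetical standalone helper, not part of the class above, and it computes the same mini-batch subgradient under the same objective:

import numpy as np

def com_gradient_vec(w, b, C, X_batch, y_batch):
    # Same subgradient as SVM.com_gradient, computed without a Python loop (sketch)
    n = X_batch.shape[0]
    margins = y_batch * (X_batch @ w + b)   # y_i * (w·x_i + b) for every sample
    mask = margins < 1                      # only margin violators contribute to the hinge term
    dw_hinge = -(y_batch[mask][:, None] * X_batch[mask]).sum(axis=0)
    db_hinge = -y_batch[mask].sum()
    dw = w + (C / n) * dw_hinge
    db = (C / n) * db_hinge
    return dw, db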
train.py
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from joblib import dump
from svm import SVM
data = datasets.load_breast_cancer()
X = data.data
y = data.target
y = np.where(y == 0, -1, 1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
dump(scaler, 'scaler.joblib')
# Best validation accuracy and best model so far
best_accu = -np.inf
best_model = None
C_values = [0.1, 1, 10, 100]
for C in C_values:
    print(f"Starting C: {C}")
    model = SVM(C=C, lr=0.01, batch_size=32, epochs=100)
    model.fit(X_train, y_train, X_val, y_val)
    # Note: evaluate on the validation set, i.e. pass X_val, y_val
    m_metrics = model.evaluate(X_val, y_val)
    if m_metrics['accuracy'] > best_accu:
        # Note: compare and store m_metrics['accuracy']
        best_accu = m_metrics['accuracy']
        best_model = model
best_model.save_weight("best_weight.npz")
print(f"Best C: {best_model.C}")
print(f"Best epoch for the best C: {best_model.epoch+1}")
test.py
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from svm import SVM
from joblib import load
data = datasets.load_breast_cancer()
X = data.data
y = data.target
y = np.where(y == 0, -1, 1)
_, X_test, _, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = load('scaler.joblib')
X_test = scaler.transform(X_test)
model = SVM.load_weight('best_weight.npz')
print(f"C: {model.C}")
print(f"Best epoch for this C: {model.epoch+1}")
t_metrics = model.evaluate(X_test, y_test)
print(f"Accuracy: {t_metrics['accuracy']:.4f}, Precision: {t_metrics['precision']:.4f}, Recall: {t_metrics['recall']:.4f}, F1: {t_metrics['f1']:.4f}")