当前位置: 首页 > article >正文

scikit-learn基本功能和示例代码

scikit-learn(简称sklearn)是一个广泛使用的Python机器学习库,提供了丰富的工具和算法,涵盖了数据预处理、模型训练、评估和优化等多个方面。
scikit-learn是一个功能强大的机器学习库,涵盖了数据预处理、分类、回归、聚类、降维、模型选择与评估等多个方面。通过上述代码示例,您可以快速上手并使用scikit-learn进行机器学习任务。以下是对scikit-learn主要功能的详细论述,并附上相关Python代码示例。

1. 数据预处理

数据预处理是机器学习流程中的重要步骤,scikit-learn提供了多种工具来处理数据。

1.1 数据清洗与缺失值处理

SimpleImputer 用于处理缺失值,可以用均值、中位数、众数等填充缺失值。

from sklearn.impute import SimpleImputer
import numpy as np

# 示例数据
X = [[1, 2], [np.nan, 3], [7, 6]]

# 使用均值填充缺失值
imputer = SimpleImputer(strategy="mean")
X_imputed = imputer.fit_transform(X)
print(X_imputed)
1.2 数据标准化

StandardScaler 用于标准化数据,使其均值为0,方差为1。MinMaxScaler 用于归一化数据,使其值在0到1之间。

from sklearn.preprocessing import StandardScaler, MinMaxScaler

# 示例数据
X = [[1, 2], [3, 4], [5, 6]]

# 标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
print(X_scaled)

# 归一化
minmax_scaler = MinMaxScaler()
X_normalized = minmax_scaler.fit_transform(X)
print(X_normalized)
1.3 编码分类特征

OneHotEncoder 用于将分类特征转换为独热编码。

from sklearn.preprocessing import OneHotEncoder

# 示例数据
X = [['cat'], ['dog'], ['cat'], ['bird']]

# 独热编码
encoder = OneHotEncoder()
X_encoded = encoder.fit_transform(X).toarray()
print(X_encoded)
1.4 生成多项式特征

PolynomialFeatures 用于生成多项式特征,适用于非线性模型。

from sklearn.preprocessing import PolynomialFeatures

# 示例数据
X = [[1, 2], [3, 4]]

# 生成二次多项式特征
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X)
print(X_poly)

2. 分类算法

scikit-learn提供了多种分类算法,适用于不同的分类任务。

2.1 支持向量机(SVM)
from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练SVM模型
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
2.2 决策树
from sklearn.tree import DecisionTreeClassifier

# 训练决策树模型
model = DecisionTreeClassifier()
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
2.3 随机森林
from sklearn.ensemble import RandomForestClassifier

# 训练随机森林模型
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
2.4 K近邻(KNN)
from sklearn.neighbors import KNeighborsClassifier

# 训练KNN模型
model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

3. 回归算法

scikit-learn提供了多种回归算法,用于预测连续值。

3.1 线性回归
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# 示例数据
X = [[1], [2], [3], [4]]
y = [1, 2, 3, 4]

# 训练线性回归模型
model = LinearRegression()
model.fit(X, y)

# 预测
y_pred = model.predict(X)
print("MSE:", mean_squared_error(y, y_pred))
3.2 多项式回归
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

# 生成多项式特征并训练模型
model = make_pipeline(PolynomialFeatures(degree=2), LinearRegression())
model.fit(X, y)

# 预测
y_pred = model.predict(X)
print("MSE:", mean_squared_error(y, y_pred))

4. 聚类算法

scikit-learn提供了多种聚类算法,适用于无监督学习任务。

4.1 K-means
from sklearn.cluster import KMeans

# 示例数据
X = [[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]]

# 训练K-means模型
kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
print("Cluster labels:", kmeans.labels_)
4.2 DBSCAN
from sklearn.cluster import DBSCAN

# 训练DBSCAN模型
dbscan = DBSCAN(eps=1, min_samples=2).fit(X)
print("Cluster labels:", dbscan.labels_)

5. 降维算法

降维算法可以帮助减少数据维度,提高可视化效果和模型性能。

5.1 主成分分析(PCA)
from sklearn.decomposition import PCA

# 示例数据
X = [[1, 2], [3, 4], [5, 6]]

# 降维
pca = PCA(n_components=1)
X_reduced = pca.fit_transform(X)
print(X_reduced)
5.2 t-SNE
from sklearn.manifold import TSNE

# 降维
tsne = TSNE(n_components=2)
X_reduced = tsne.fit_transform(X)
print(X_reduced)

6. 模型选择与评估

scikit-learn提供了多种工具来选择和评估模型。

6.1 交叉验证
from sklearn.model_selection import cross_val_score

# 交叉验证
scores = cross_val_score(model, X, y, cv=5)
print("Cross-validation scores:", scores)
6.2 网格搜索
from sklearn.model_selection import GridSearchCV

# 参数网格
param_grid = {'n_neighbors': [3, 5, 7]}

# 网格搜索
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
grid_search.fit(X, y)
print("Best parameters:", grid_search.best_params_)
6.3 评估指标
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 计算评估指标
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Precision:", precision_score(y_test, y_pred, average='macro'))
print("Recall:", recall_score(y_test, y_pred, average='macro'))
print("F1 Score:", f1_score(y_test, y_pred, average='macro'))

7. 高级功能

scikit-learn还提供了一些高级功能,如管道和集成方法。

7.1 管道(Pipeline)
from sklearn.pipeline import Pipeline

# 创建管道
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svm', SVC())
])

# 训练模型
pipeline.fit(X_train, y_train)

# 预测
y_pred = pipeline.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
7.2 集成方法
from sklearn.ensemble import BaggingClassifier, AdaBoostClassifier

# Bagging
bagging = BaggingClassifier(base_estimator=SVC(), n_estimators=10)
bagging.fit(X_train, y_train)

# Boosting
boosting = AdaBoostClassifier(base_estimator=DecisionTreeClassifier(), n_estimators=10)
boosting.fit(X_train, y_train)

8. 经典数据集

scikit-learn提供了一些经典数据集,方便初学者快速上手。

from sklearn.datasets import load_iris, load_breast_cancer

# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target

# 加载乳腺癌数据集
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target

9.综合应用

1. 基础概念
  • 题目: 解释 scikit-learn 中的 fit()transform()fit_transform() 方法的区别。
  • 答案:
    • fit(): 用于训练模型或计算数据集的统计信息(如均值、标准差等)。
    • transform(): 使用 fit() 计算的结果对数据进行转换(如标准化、降维等)。
    • fit_transform(): 结合 fit()transform(),先拟合数据,再对其进行转换。
2. 数据预处理
  • 题目: 如何使用 scikit-learn 对数据进行标准化处理?
  • 答案:
    from sklearn.preprocessing import StandardScaler
    
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
3. 模型训练与评估
  • 题目: 如何使用 scikit-learn 训练一个简单的线性回归模型,并计算其均方误差(MSE)?
  • 答案:
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error
    
    model = LinearRegression()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print("Mean Squared Error:", mse)
    
4. 交叉验证
  • 题目: 如何使用 scikit-learn 进行 K 折交叉验证?
  • 答案:
    from sklearn.model_selection import cross_val_score
    from sklearn.ensemble import RandomForestClassifier
    
    model = RandomForestClassifier()
    scores = cross_val_score(model, X, y, cv=5)
    print("Cross-Validation Scores:", scores)
    
5. 超参数调优
  • 题目: 如何使用 GridSearchCV 进行超参数调优?
  • 答案:
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC
    
    param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
    grid_search = GridSearchCV(SVC(), param_grid, cv=5)
    grid_search.fit(X_train, y_train)
    print("Best Parameters:", grid_search.best_params_)
    
6. 特征选择
  • 题目: 如何使用 scikit-learn 进行特征选择?
  • 答案:
    from sklearn.feature_selection import SelectKBest, f_classif
    
    selector = SelectKBest(f_classif, k=10)
    X_new = selector.fit_transform(X, y)
    
7. 模型持久化
  • 题目: 如何保存和加载 scikit-learn 模型?
  • 答案:
    import joblib
    
    # 保存模型
    joblib.dump(model, 'model.pkl')
    
    # 加载模型
    model = joblib.load('model.pkl')
    
8. Pipeline
  • 题目: 如何使用 Pipeline 将数据预处理和模型训练结合在一起?
  • 答案:
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import LogisticRegression
    
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('classifier', LogisticRegression())
    ])
    pipeline.fit(X_train, y_train)
    
9. 分类报告与混淆矩阵
  • 题目: 如何生成分类报告和混淆矩阵?
  • 答案:
    from sklearn.metrics import classification_report, confusion_matrix
    
    y_pred = model.predict(X_test)
    print("Classification Report:\n", classification_report(y_test, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
    
10. 聚类算法
  • 题目: 如何使用 KMeans 进行聚类?
  • 答案:
    from sklearn.cluster import KMeans
    
    kmeans = KMeans(n_clusters=3)
    kmeans.fit(X)
    labels = kmeans.labels_
    
11. 降维
  • 题目: 如何使用 PCA 进行降维?
  • 答案:
    from sklearn.decomposition import PCA
    
    pca = PCA(n_components=2)
    X_reduced = pca.fit_transform(X)
    
12. 模型评估
  • 题目: 如何计算模型的准确率、精确率、召回率和 F1 分数?
  • 答案:
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    
    y_pred = model.predict(X_test)
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print("Precision:", precision_score(y_test, y_pred))
    print("Recall:", recall_score(y_test, y_pred))
    print("F1 Score:", f1_score(y_test, y_pred))
    

13. 异常检测

  • 题目: 如何使用 IsolationForest 进行异常检测?
  • 答案:
    from sklearn.ensemble import IsolationForest
    
    iso_forest = IsolationForest(contamination=0.1)
    iso_forest.fit(X)
    outliers = iso_forest.predict(X)
    

14. 集成学习

  • 题目: 如何使用 RandomForestClassifier 进行分类?
  • 答案:
    from sklearn.ensemble import RandomForestClassifier
    
    model = RandomForestClassifier(n_estimators=100)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    

15. 时间序列处理

  • 题目: 如何使用 TimeSeriesSplit 进行时间序列交叉验证?
  • 答案:
    from sklearn.model_selection import TimeSeriesSplit
    
    tscv = TimeSeriesSplit(n_splits=5)
    for train_index, test_index in tscv.split(X):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
    

http://www.kler.cn/a/522330.html

相关文章:

  • 团体程序设计天梯赛-练习集——L1-022 奇偶分家
  • Rust语言进阶之zip用法实例(九十五)
  • windows lm studio 0.3.8无法下载模型,更换镜像
  • OpenCV:开运算
  • 【信息系统项目管理师-选择真题】2010上半年综合知识答案和详解
  • Tensor 基本操作2 理解 tensor.max 操作,沿着给定的 dim 是什么意思 | PyTorch 深度学习实战
  • postgresql 9.4.1 普通表,子表,父表的创建与测试
  • 系统设计的
  • JavaScript系列(46)-- WebGL图形编程详解
  • 专为课堂打造:宏碁推出三款全新耐用型 Chromebook
  • 【实用技能】如何借助Excel处理控件Aspose.Cells,使用 C# 锁定 Excel 中的单元格
  • 获取加工视图下所有元素
  • java后端之事务管理
  • 【C++探索之路】STL---string
  • Day27-【13003】短文,单链表应用代码举例
  • 解决MySQL删除/var/lib/mysql下的所有文件后无法启动的问题
  • 未来五年高速线缆市场有望翻3倍!AEC凭借传输距离优势占比将更高
  • CentOS7非root用户离线安装Docker及常见问题总结、各种操作系统docker桌面程序下载地址
  • 非注意力模型崛起:LLM架构新突破
  • 【JavaEE】Spring(5):Mybatis(上)
  • 【单链表算法实战】解锁数据结构核心谜题——环形链表
  • 基于PostgreSQL的自然语义解析电子病历编程实践与探索(下)
  • vim多文件操作如何同屏开多个文件
  • 软件测试丨Airtest 游戏自动化测试框架
  • 电梯系统的UML文档12
  • LangChain:使用表达式语言优化提示词链