当前位置: 首页 > article >正文

训练数据重复采样,让正负样本比例1:1

详细解释

  1. resample 函数

    • resample 函数来自 sklearn.utils,用于从数据集中重新抽样。
    • replace=True 表示允许重复抽样,即同一个样本可以被多次选中。
    • n_samples 指定抽样的数量。
  2. 确保训练集数量相同

    • 通过 resample 函数,你可以确保正训练集和负训练集的数量相同,即使其中一个集的数量小于另一个集的数量。
    • 如果 n_train_num 小于 max_train_numresample 会从 n_train 中随机选择 max_train_num 个样本,允许重复选择。

示例代码

假设你有一个包含正样本和负样本的列表,并且需要确保训练集中的正样本和负样本数量相同。以下是一个完整的示例代码:

import random
from sklearn.utils import resample

# 假设 positive_ori 和 negative_ori 是包含正样本和负样本的列表
positive_ori = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
negative_ori = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

# 指定测试样本数量
p_test_num = 3
n_test_num = 3

# 抽取测试集
p_test = random.sample(positive_ori, p_test_num)
n_test = random.sample(negative_ori, n_test_num)

# 生成训练集
p_train = [item for item in positive_ori if item not in p_test]
n_train = [item for item in negative_ori if item not in n_test]

# 计算训练集的最大数量
max_train_num = max(len(p_train), len(n_train))

# 确保训练集数量相同
if len(p_train) < max_train_num:
    p_train = resample(p_train, replace=True, n_samples=max_train_num)
if len(n_train) < max_train_num:
    n_train = resample(n_train, replace=True, n_samples=max_train_num)

# 打印结果
print("正测试集:", p_test)
print("正训练集:", p_train)
print("负测试集:", n_test)
print("负训练集:", n_train)

示例输出

假设 random.sample 抽取的元素如下:

  • p_test = [2, 5, 9]
  • n_test = [12, 15, 18]

则输出可能如下:

正测试集: [2, 5, 9]
正训练集: [1, 3, 4, 6, 7, 8, 10]
负测试集: [12, 15, 18]
负训练集: [11, 13, 14, 16, 17, 19, 20, 11, 13]

解释

  1. 抽取测试集

    • p_test 从 positive_ori 中随机抽取了 3 个元素 [2, 5, 9]
    • n_test 从 negative_ori 中随机抽取了 3 个元素 [12, 15, 18]
  2. 生成训练集

    • p_train 从 positive_ori 中移除了 p_test 中的元素,生成了 [1, 3, 4, 6, 7, 8, 10]
    • n_train 从 negative_ori 中移除了 n_test 中的元素,生成了 [11, 13, 14, 16, 17, 19, 20]
  3. 确保训练集数量相同

    • max_train_num 计算为 7(p_train 和 n_train 的长度都是 7)。
    • 由于 p_train 和 n_train 的长度已经相等,不需要重新抽样。
    • 如果 n_train 的长度小于 7,resample 会从 n_train 中随机选择 7 个样本,允许重复选择。

重复抽样的示例

假设 n_train 的长度小于 max_train_num,例如 n_train 只有 5 个元素:

import random
from sklearn.utils import resample

# 假设 positive_ori 和 negative_ori 是包含正样本和负样本的列表
positive_ori = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
negative_ori = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

# 指定测试样本数量
p_test_num = 3
n_test_num = 5

# 抽取测试集
p_test = random.sample(positive_ori, p_test_num)
n_test = random.sample(negative_ori, n_test_num)

# 生成训练集
p_train = [item for item in positive_ori if item not in p_test]
n_train = [item for item in negative_ori if item not in n_test]

# 计算训练集的最大数量
max_train_num = max(len(p_train), len(n_train))

# 确保训练集数量相同
if len(p_train) < max_train_num:
    p_train = resample(p_train, replace=True, n_samples=max_train_num)
if len(n_train) < max_train_num:
    n_train = resample(n_train, replace=True, n_samples=max_train_num)

# 打印结果
print("正测试集:", p_test)
print("正训练集:", p_train)
print("负测试集:", n_test)
print("负训练集:", n_train)

输出示例

假设 random.sample 抽取的元素如下:

  • p_test = [2, 5, 9]
  • n_test = [12, 15, 18, 19, 20]

则输出可能如下:

正测试集: [2, 5, 9]
正训练集: [1, 3, 4, 6, 7, 8, 10]
负测试集: [12, 15, 18, 19, 20]
负训练集: [11, 13, 14, 16, 17, 11, 13]

解释

  1. 抽取测试集

    • p_test 从 positive_ori 中随机抽取了 3 个元素 [2, 5, 9]
    • n_test 从 negative_ori 中随机抽取了 5 个元素 [12, 15, 18, 19, 20]
  2. 生成训练集

    • p_train 从 positive_ori 中移除了 p_test 中的元素,生成了 [1, 3, 4, 6, 7, 8, 10]
    • n_train 从 negative_ori 中移除了 n_test 中的元素,生成了 [11, 13, 14, 16, 17]
  3. 确保训练集数量相同

    • max_train_num 计算为 7(p_train 的长度是 7,n_train 的长度是 5)。
    • 由于 n_train 的长度小于 7,resample 会从 n_train 中随机选择 7 个样本,允许重复选择。因此,n_train 可能包含重复的元素,例如 [11, 13, 14, 16, 17, 11, 13]

总结

使用 resample 函数并设置 replace=True 可以确保在训练集数量不一致时,通过允许重复抽样来平衡训练集的数量。这在数据集不平衡的情况下非常有用,可以确保模型在训练时看到相同数量的正样本和负样本。


http://www.kler.cn/a/584949.html

相关文章:

  • 【开源项目-爬虫】Firecrawl
  • MySQL行列转化
  • 开VR大空间体验馆,如何最低成本获取最大收入?
  • 深度学习环境配置指令大全
  • go-文件缓存与锁
  • C#中除了Dictionary,List,HashSet,HashTable 还有哪些可以保存列表的数据类型?
  • 批量将 Excel 文档中的图片提取到文件夹
  • 如何学习VBA_3.2.20:DTP与Datepicker实现日期的输入
  • 罗德与施瓦茨RTO1044,数字示波器
  • 大数据面试之路 (一) 数据倾斜
  • C++程序设计语言笔记——基本功能:异常处理
  • 如何接入DeepSeek布局企业AI系统开发技术
  • JVM内存结构笔记01-运行时数据区域
  • 记录致远OA服务器硬盘升级过程
  • Qt常用控件之水平布局QHBoxLayout
  • node基础
  • 【YOLOv8】YOLOv8改进系列(6)----替换主干网络之VanillaNet
  • Python 机器学习小项目:手写数字识别(MNIST 数据集)
  • 蓝桥杯备赛-基础练习 day1
  • linux 构建网站环境