当前位置: 首页 > article >正文

java中使用BP网络进行回归

在神经网络中,激活函数扮演着至关重要的角色,它们为网络引入了非线性因素,使得网络能够学习和模拟复杂的非线性关系。下面列出的各种激活函数及其特点如下:

  1. LINEAR ("Linear"):
    • 线性激活函数实际上就是恒等函数,即 f(x)=x。
    • 它不会给网络带来任何非线性特性,因此在实际应用中很少使用(除了在某些特殊情况下,如回归问题或某些层作为输出层时)。
  2. RAMP ("Ramp"):
    • Ramp函数是一种分段线性函数,通常定义为当 x<0 时 f(x)=0,当 x≥0 时 f(x)=x。
    • 它与ReLU(Rectified Linear Unit)类似,但Ramp函数在 x<0 时保持输出为0,而不是像ReLU那样保持不变。Ramp函数有时被认为是一个更平滑的ReLU版本。
  3. STEP ("Step"):
    • Step函数是一种简单的二值激活函数,通常定义为当 x<0 时 f(x)=0,当 x≥0 时 f(x)=1。
    • 它将任意输入值映射到两个输出值之一,这种非连续性使得它在大多数现代神经网络中不太实用。
  4. SIGMOID ("Sigmoid"):
    • Sigmoid函数是一个S形曲线,定义为 f(x)=1+e−x1​。
    • 它将任意实值压缩到区间(0, 1)内,非常适合用于二分类问题的输出层。然而,由于梯度消失问题和计算量较大,它在深度神经网络中的使用已经减少。
  5. TANH ("Tanh"):
    • Tanh函数是双曲正切函数,定义为 f(x)=ex+e−xex−e−x​。
    • 它将输入值压缩到区间(-1, 1)内,并且是零中心的(即,输出的均值为0)。这有助于加快收敛速度,因此它比Sigmoid函数更受欢迎。然而,梯度消失问题仍然存在。
  6. GAUSSIAN ("Gaussian"):
    • 高斯(或称为高斯)激活函数使用高斯(或正态)分布函数,形式可能因具体实现而异,但通常与标准正态分布相关。
    • 它将输入值映射到一个钟形曲线,但这种激活函数在神经网络中不太常见,因为其他激活函数通常能提供更好的性能。
  7. TRAPEZOID ("Trapezoid"):
    • 梯形激活函数是一种非线性函数,其形状类似于梯形。
    • 它结合了线性和非线性的部分,但在实际应用中不太常见,可能是因为它缺乏足够的理论依据或优势。
  8. SGN ("Sgn"):
    • Sgn(符号)函数根据输入值的正负输出+1、-1或0(取决于具体定义)。
    • 它是一种简单的二值激活函数,但在神经网络中并不常用,因为它的梯度在大部分输入上都是0(除了0点附近的微小区域),这会导致梯度消失问题。
  9. SIN ("Sin"):
    • Sin函数是正弦函数,将输入值映射到[-1, 1]区间的正弦波上。
    • 它是一个周期性的非线性函数,但在神经网络中不常用,因为神经网络通常希望激活函数具有单调递增或递减的特性,以便进行有效的梯度下降。
  10. LOG ("Log"):
    • Log函数(可能指的是自然对数或其他对数)将输入值映射到对数尺度上。
    • 它在某些特定领域(如机器学习中的某些损失函数)中很有用,但在作为神经网络的激活函数时并不常见。
  11. RECTIFIED ("RectifiedLinear", ReLU):
    • ReLU(Rectified Linear Unit)函数是当前深度学习中使用最广泛的激活函数之一。
    • 它定义为 f(x)=max(0,x),即将所有负值置为0,保持正值不变。
    • ReLU函数简单、计算效率高,且有助于缓解梯度消失问题(在正值区域内)。然而,它也存在死亡ReLU问题(即某些神经元在训练过程中永远不会被激活)。

   为了分析1个y和3个自变量(x,z,  k) (0为训练集,1为测试集) 的回归问题使用java的BP网络激活函数使用RECTIFIED如下使用依赖如下:

  <properties>
<neuroph-core.version>2.96</neuroph-core.version>
<visrec-api.version>1.0.3</visrec-api.version>     
  </properties>

   <dependency>
            <groupId>org.neuroph</groupId>
            <artifactId>neuroph-core</artifactId>
            <version>${neuroph-core.version}</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/javax.visrec/visrec-api -->
        <dependency>
            <groupId>javax.visrec</groupId>
            <artifactId>visrec-api</artifactId>
            <version>${visrec-api.version}</version>
        </dependency>

进行回归的方法如下:

private static void poly3UseBpn(Double[] x0, Double[] z0, Double[] k0, Double[] y, Double[] x1, Double[] z1, Double[] k1, Double[] y0, Double[] y1) {
    // 归一化数据
    double minX = Arrays.stream(x0).min(Double::compareTo).get();
    double maxX = Arrays.stream(x0).max(Double::compareTo).get();
    double rangeX = maxX - minX;
    double minZ = Arrays.stream(z0).min(Double::compareTo).get();
    double maxZ = Arrays.stream(z0).max(Double::compareTo).get();
    double rangeZ = maxZ - minZ;
    double minK = Arrays.stream(k0).min(Double::compareTo).get();
    double maxK = Arrays.stream(k0).max(Double::compareTo).get();
    double rangeK = maxK - minK;
    // 创建数据集 (3 个输入,1 个输出)
    DataSet trainingSet = new DataSet(3, 1);
    for (int i = 0; i < x0.length; i++) {
        double normalizedInputX = (x0[i] - minX) / rangeX;
        double normalizedInputZ = (z0[i] - minZ) / rangeZ;
        double normalizedInputK = (k0[i] - minK) / rangeZ;
        trainingSet.add(new DataSetRow(new double[]{normalizedInputX, normalizedInputZ, normalizedInputK}, new double[]{y[i]}));
    }
    // 创建一个具有 3 个输入层神经元,10 个隐藏层神经元,1 个输出层神经元的神经网络
    MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(TransferFunctionType.RECTIFIED, 3, 10, 1);
    //使用固定的权重
    neuralNet.randomizeWeights(new Random(1));
    // 设置学习率和迭代次数
    double currentLearningRate = 0.01;
    int iterations = 10000;
    double errorThreshold = 0.001;
    double minLearningRate = 1e-5;
    double decay = 0.9;
    // 训练网络
    BackPropagation backPropagation = new BackPropagation();
    backPropagation.setLearningRate(currentLearningRate);
    neuralNet.setLearningRule(backPropagation);
    for (int i = 0; i < iterations; i++) {
        neuralNet.learn(trainingSet);
        // 获取当前误差
        double error = backPropagation.getTotalNetworkError();
        if (error < errorThreshold) {
            System.out.println("error:" + error + "---->" + (error < errorThreshold) + "====>" + i);
            break;
        }
        // 更新学习率(指数衰减)
        currentLearningRate = Math.max(minLearningRate, currentLearningRate * decay);
        backPropagation.setLearningRate(currentLearningRate);
    }

    // 在 x0 和 z0 上评估拟合值,并将结果存储在 y0 中
    for (int i = 0; i < x0.length; i++) {
        double normalizedInputX = (x0[i] - minX) / rangeX;
        double normalizedInputZ = (z0[i] - minZ) / rangeZ;
        double normalizedInputK = (k0[i] - minK) / rangeZ;
        neuralNet.setInput(normalizedInputX, normalizedInputZ, normalizedInputK);
        neuralNet.calculate();
        y0[i] = neuralNet.getOutput()[0];
    }
    // 在 x1 和 z1 上评估拟合值,并将结果存储在 y1 中
    for (int i = 0; i < x1.length; i++) {
        double normalizedInputX = (x1[i] - minX) / rangeX;
        double normalizedInputZ = (z1[i] - minZ) / rangeZ;
        double normalizedInputK = (k1[i] - minK) / rangeZ;
        neuralNet.setInput(normalizedInputX, normalizedInputZ, normalizedInputK);
        neuralNet.calculate();
        y1[i] = neuralNet.getOutput()[0];
    }

由于这边数据量比较小大概在70条左右,由于neuroph是单线程的,在执行到neuralNet.learn(trainingSet);这里时会卡住,调整学习率和减少隐藏层单元后仍旧无法训练故只能换个方法如下使用deeplearning4j

  <dl4j.version>1.0.0-M1</dl4j.version>
        <nd4j.version>1.0.0-M1</nd4j.version>
  
<dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

修改后的训练方法如下:

public void poly3UseBpn(double[] x0, double[] z0, double[] k0, double[] y,
                            double[] x1, double[] z1, double[] k1, double[] y0, double[] y1) {
        // 创建输入数据集 (3个输入,1个输出)
        List<org.nd4j.linalg.dataset.DataSet> dataSetList = new ArrayList<>();
        for (int i = 0; i < x0.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x0[i], z0[i], k0[i]}});
            INDArray output = Nd4j.create(new double[][]{{y[i]}});
            org.nd4j.linalg.dataset.DataSet DataSet = new org.nd4j.linalg.dataset.DataSet(input, output);
            dataSetList.add(DataSet);
        }
        ListDataSetIterator<org.nd4j.linalg.dataset.DataSet> trainingData = new ListDataSetIterator<>(dataSetList);
        // 构建网络配置
        MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(42)//随机种子
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) //随机梯度下降
                .updater(new org.nd4j.linalg.learning.config.Adam(0.01)) //初始化学习率
                .weightInit(WeightInit.RELU)//权重初始化
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(3)
                        .nOut(10)
                        .activation(Activation.RELU) //激活函数
                        .build()) // 输入层 3个输入
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) //MSE损失函数 MSE表示均方误差
                        .nIn(10)
                        .nOut(1)
                        .activation(Activation.IDENTITY)
                        .build()) // 输出层 1个输出
                .build());
        //网络初始化
        model.init();
        //用于输出误差信息 100次输出1个
        model.setListeners(new ScoreIterationListener(100));
        // 设置训练参数
        double errorThreshold = 0.001;
        int maxIterations = 1000;
        int patience = 10;
        double bestError = Double.MAX_VALUE;
        int noImprovementCount = 0;
        for (int i = 0; i < maxIterations; i++) {
            model.fit(trainingData);
            double currentError = model.score();
            System.out.println("error:" + currentError + "====>" + i);

            if (currentError < bestError) {
                bestError = currentError;
                noImprovementCount = 0;
            } else {
                noImprovementCount++;
            }

            if (currentError < errorThreshold || noImprovementCount >= patience) {
                break;
            }
        }

        // 评估在 x0 和 z0 上的拟合值,并将结果存储在 y0 中
        for (int i = 0; i < x0.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x0[i], z0[i], k0[i]}});
            INDArray output = model.output(input);
            y0[i] = output.getDouble(0);
        }

        // 评估在 x1 和 z1 上的拟合值,并将结果存储在 y1 中
        for (int i = 0; i < x1.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x1[i], z1[i], k1[i]}});
            INDArray output = model.output(input);
            y1[i] = output.getDouble(0);
        }
    }

修改完后重新运行,已经可以顺利训练出数据了,自此记录一下!


http://www.kler.cn/news/317097.html

相关文章:

  • 【ComfyUI】控制光照节点——ComfyUI-IC-Light-Native
  • 爵士编曲:爵士鼓编写 爵士鼓笔记 底鼓和军鼓 闭镲和开镲 嗵鼓
  • 9.23作业
  • 无人机之激光避障篇
  • 3.4 爬虫实战-爬去智联招聘职位信息
  • 什么是反射,反射用途,spring哪些地方用到了反射,我们项目中哪些地方用到了反射
  • 【python】requests 库 源码解读、参数解读
  • Maven笔记(一):基础使用【记录】
  • Spring Boot 中的拦截器 Interceptors
  • 【已解决】用JAVA代码实现递归算法-从自然数中取3个数进行组合之递归算法-用递归算法找出 n(n>=3) 个自然数中取 3 个数的组合。
  • 在云渲染中3D工程文件安全性怎么样?
  • 【HarmonyOS】深入理解@Observed装饰器和@ObjectLink装饰器:嵌套类对象属性变化
  • Unity-Screen屏幕相关
  • 【设计模式】万字详解:深入掌握五大基础行为模式
  • 鸿蒙 OS 开发零基础快速入门教程
  • ER论文阅读-Incomplete Multimodality-Diffused Emotion Recognition
  • 【LLM学习之路】9月22日 第九天 自然语言处理
  • 计算一个矩阵的逆矩阵的方法
  • 2024ICPC网络赛第一场C. Permutation Counting 4(线性代数)
  • nginx的反向代理和负载均衡
  • 16.3 k8s容器cpu内存告警指标与资源request和limit
  • 【数据结构-栈】力扣682. 棒球比赛
  • 0-1开发自己的obsidian plugin DAY 1
  • 鸿蒙操作系统(HarmonyOS)生态与机遇
  • YOLOv10改进,YOLOv10替换主干网络为PP-HGNetV1(百度飞桨视觉团队自研,全网首发,助力涨点)
  • watch和computed的使用及区别
  • Correcting Chinese Spelling Errors with Phonetic Pre-training(ACL2021)
  • Python Web 面试题
  • Spring Boot自定义配置项
  • [leetcode刷题]面试经典150题之6轮转数字(简单)