
CUDA from Scratch: Hand-Rolling a BP Neural Network

Implementing a BP neural network in CUDA


Building on the matrix dot product from the previous post, I have implemented matrix addition, subtraction, multiplication, division, element-wise function application, and so on, and I reuse the gradient descent, Adam, and NAdam optimizers written in the earlier template-metaprogramming posts.
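For context, `bp_network` only touches a small surface of that `mat` class and of the optimizer templates. Below is a minimal sketch of the assumed interfaces, reconstructed from how they are called in the code that follows, not from the real `matrix.hpp`/`mat.hpp`/`update_methods.hpp`:

// Interface sketch only -- inferred from usage, not the actual headers.
template<int ROW, int COL, typename VAL>
struct mat
{
    template<typename init_type> void reset();              // re-initialize via a policy (e.g. xavier_gaussian_type)
    template<typename func>      mat  activate() const;     // element-wise f(x), launched as a CUDA kernel

    template<int K> mat<ROW, K, VAL> dot(const mat<COL, K, VAL>& rhs) const;    // this * rhs
    template<int K> mat<COL, K, VAL> t_dot(const mat<ROW, K, VAL>& rhs) const;  // this^T * rhs
    mat<COL, ROW, VAL> t() const;                            // transpose

    mat operator+(const mat& rhs) const;                     // element-wise add (scalar overloads exist too)
    mat operator-(const mat& rhs) const;                     // element-wise subtract
    mat operator*(const mat& rhs) const;                     // element-wise (Hadamard) product

    static mat ones(VAL v);                                  // constant-filled matrix
    static mat xavier_gaussian();                            // one of several initialization helpers
    void print() const;
};

// Optimizer policy (gradient descent / adam / nadam, ...): update() takes the
// current parameter and its gradient and returns the new parameter;
// update_inert() advances the momentum state.
template<typename mat_type>
struct nadam
{
    mat_type update(const mat_type& param, const mat_type& grad);
    void update_inert();
};

With those interfaces in mind, the BP neural network is implemented as follows: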

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"

template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{
    constexpr static int input_num = input_num_;
    constexpr static int output_num = output_num_;
    using val_type = val_type_;

    using input_type = mat<input_num, 1, val_type>;
    using input_t_type = mat<1, input_num, val_type>;
    using output_type = mat<output_num, 1, val_type>;
    using weight_type = mat<output_num, input_num, val_type>;

    using forward_func = typename func_pair<activate_type>::forward_func;
    using backward_func = typename func_pair<activate_type>::backward_func;

    using next_node_type = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;
    using term_output_type = typename next_node_type::term_output_type;

    weight_type weight;
    update_type_tpl<weight_type> weight_update_method;
    output_type bias;
    update_type_tpl<output_type> bias_update_method;

    input_type pre_input;
    output_type pre_func_input;
    next_node_type next_node;

    bp_network() : weight_update_method(), bias_update_method()
    {
        weight.template reset<init_type>();
        bias.template reset<init_type>();
        // next_node is default-constructed, so no explicit re-assignment is needed
    }

    auto forward(input_type& input)
    {
        output_type curr_output;
        pre_input = input;
        auto temp = weight.dot(input);
        pre_func_input = temp + bias;
        curr_output = pre_func_input.template activate<forward_func>();
        return next_node.forward(curr_output);
    }

    auto backward(term_output_type& delta, val_type lr)
    {
        output_type curr_delta = next_node.backward(delta, lr);
        curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;
        auto ret = weight.t_dot(curr_delta);
        // update the parameters
        weight_type delta_weight = curr_delta.dot(pre_input.t());
        weight = weight_update_method.update(weight, delta_weight);
        bias = bias_update_method.update(bias, curr_delta);
        return ret;
    }   

    // update the momentum (inertia) terms
    void update_inert()
    {
        weight_update_method.update_inert();
        bias_update_method.update_inert();
        next_node.update_inert();
    }

    void print()
    {
        weight.print();
        printf("-----------------\n");
        bias.print();
        printf("=================\n");
        next_node.print();
    }
};

template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{
    constexpr static int input_num = input_num_;
    constexpr static int output_num = output_num_;
    using val_type = val_type_;

    using input_type = mat<input_num, 1, val_type>;
    using input_t_type = mat<1, input_num, val_type>;
    using output_type = mat<output_num, 1, val_type>;
    using weight_type = mat<output_num, input_num, val_type>;

    using forward_func = typename func_pair<activate_type>::forward_func;
    using backward_func = typename func_pair<activate_type>::backward_func;
    using term_output_type = output_type;
    using weight_update_type = update_type_tpl<weight_type>;
    using bias_update_type = update_type_tpl<output_type>;

    weight_type weight;
    weight_update_type weight_update;
    output_type bias;
    bias_update_type bias_update;

    output_type pre_func_input;
    input_type pre_input;

    bp_network():weight_update(), bias_update()
    {
        weight.template reset<init_type>();
        bias.template reset<init_type>();
    }

    auto forward(input_type& input)
    {
        pre_input = input;
        auto temp = weight.dot(input);
        pre_func_input = temp + bias;
        return pre_func_input.template activate<forward_func>();
    }

    auto backward(output_type& delta, val_type lr)
    {
        output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;
        auto ret = weight.t_dot(curr_delta);
        // update the parameters
        weight_type delta_weight = curr_delta.dot(pre_input.t());
        weight = weight_update.update(weight, delta_weight);
        bias = bias_update.update(bias, curr_delta);
        return ret;
    }

    void update_inert()
    {
        weight_update.update_inert();
        bias_update.update_inert();
    }

    void print()
    {
        weight.print();
        printf("-----------------\n");
        bias.print();
        printf("*****************\n");
    }
};

#endif
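In equations, each layer performs the textbook backprop step; `forward` and `backward` above compute

\begin{aligned}
z &= Wx + b, \qquad a = f(z) && \text{(forward)} \\
\delta &= f'(z) \odot \delta_{\text{in}} && \text{(curr\_delta)} \\
\partial L/\partial W &= \delta\, x^{\mathsf{T}}, \qquad \partial L/\partial b = \delta && \text{(delta\_weight and the bias step)} \\
\partial L/\partial x &= W^{\mathsf{T}}\delta && \text{(ret, handed back to the previous layer)}
\end{aligned}

where \delta_{\text{in}} is the gradient arriving from the next layer; in main() below the chain is seeded with delta = output - expect, which is exactly the gradient of the squared-error loss L = \tfrac{1}{2}\lVert \text{output} - \text{expect} \rVert^2.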

Now let's put our BP network through a quick experiment.

#include <chrono>
#include <iostream>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{
    constexpr int row_num = 32;
    constexpr int adj_num = 32;
    constexpr int col_num = 32;
    /*
    matrix_device_proxy<row_num, adj_num, double> A;
    eyes(A(), 2.0);
    matrix_device_proxy<adj_num, col_num, double> B;
    eyes(B(), 1.0);
    matrix_device_proxy<row_num, col_num, double> C;
    mat_dot<sigmoid>(A(), B(), C());
    print(type_cast(C()));

    auto A = mat<row_num, adj_num, double>::eyes(2.0);
    auto B = mat<adj_num, col_num, double>::eyes(1.0);
    auto C = A.dot(B);
    C = C + 1.0;
    C = sqrtl(C);
    C = C - 2.0;
    C = C * 3.0;
    C = C / 4.0;
    C.print();

    std::cout << "---------- D ----------" << std::endl;
    auto D = mat<row_num, col_num, double>::xavier_gaussian();
    D.print();
    std::cout << "---------- E ----------" << std::endl;
    auto E = mat<row_num, col_num, double>::xavier_mean();
    E.print();
    std::cout << "---------- F ----------" << std::endl;
    auto F = mat<row_num, col_num, double>::he_gaussian();
    F.print();
    std::cout << "---------- G ----------" << std::endl;
    auto G = mat<row_num, col_num, double>::he_mean();
    G.print();
    */
    bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;
    auto input = mat<row_num, 1, double>::ones(0.2);
    auto expect = mat<col_num, 1, double>::ones(0.4);

    int times = 8000;
    int update_inert_times = 100;
    int step = times / update_inert_times;
    // start the timer
    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < times; ++i)
    {
        auto output = node.forward(input);
        auto delta = (output - expect);
        node.backward(delta, 0.001);
        if (i == times - 1)
        {
            output.t().print();
        }

        if (i % step == 0 && i != 0)
        {
            node.update_inert();
        }

    }
    // stop the timer: grab the end time point
    auto end = std::chrono::high_resolution_clock::now();

    // compute the elapsed duration
    std::chrono::duration<double> duration = end - start;

    // print the execution time
    std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;
    //node.print();
    cudaDeviceReset();
    return 0;
}

One caveat: the learning-rate argument `lr` that `backward` accepts is not actually wired into the update methods anywhere yet; I will clean that up in a future revision, apologies. The execution results are as follows:
[Screenshot: final output vector and execution time]
As you can see, after 8000 training iterations this 32-32-32 BP network, with sigmoid activation, NAdam optimization, and Xavier-Gaussian initialization, drives the error down to the order of 0.0001, and training takes only 8.54 seconds. Quite impressive.
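If you would rather track a single error number than eyeball the printed output vector, one option is to reuse `t_dot` inside the training loop. This is just a sketch: it assumes `t_dot` computes this^T * rhs, the same convention `backward` relies on, so `delta.t_dot(delta)` yields a 1x1 matrix holding the sum of squared errors:

    auto delta = output - expect;      // col_num x 1 residual
    auto sq_err = delta.t_dot(delta);  // 1 x 1: sum of squared per-component errors
    sq_err.print();                    // one scalar per logged iteration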
None of this has anything to do with my job, but I wanted to build it anyway. After all, "the more useless the knowledge, the more useful it turns out to be; the more useful something looks, the more useless it ends up."

