当前位置: 首页 > article >正文

计算机视觉实战|NeRF 实战教程:基于 nerf_recon_dataset 的三维重建

一、引言

神经辐射场(NeRF) 是一种利用神经网络从多视角图像重建 3D 场景的技术,通过隐式表示场景的几何和外观,实现高质量渲染。本教程将引导您使用 PyTorch 实现 NeRF 模型,基于 ModelScope 提供的 nerf_recon_dataset 数据集完成三维重建实战,包括环境配置、数据准备、模型构建、训练和渲染。
在这里插入图片描述

二、环境搭建

1、依赖安装

确保您的系统已安装 Python(推荐 3.8 或以上),然后安装以下依赖:

pip install torch torchvision numpy matplotlib opencv-python imageio gitpython
  • PyTorch:深度学习框架,用于模型构建和训练。
  • torchvision:图像处理工具。
  • NumPy:数组操作。
  • Matplotlib:可视化结果。
  • OpenCV:图像处理。
  • imageio:保存渲染视频。
  • gitpython:从 Git 下载数据集。

2、硬件要求

  • GPU(推荐):加速训练和推理,NVIDIA GPU(如 GTX 1080 或以上)配合 CUDA。
  • CPU(可选):可运行,但速度较慢。
  • 内存:至少 16GB,复杂场景可能需更多。

三、数据准备

1、数据集下载

我们使用 ModelScope 提供的 nerf_recon_dataset 数据集,包含多视角图像和相机参数。下载方法:

import git
import os

dataset_url = "https://www.modelscope.cn/datasets/damo/nerf_recon_dataset.git"
dataset_dir = "nerf_recon_dataset"

if not os.path.exists(dataset_dir):
    git.Repo.clone_from(dataset_url, dataset_dir)
else:
    print("Dataset already downloaded.")

下载后,检查目录结构(假设包含图像和相机参数文件,可能类似以下形式,具体以实际数据集为准):

nerf_recon_dataset
|   ├── nerf_synthetic
|   |   └── lego
|   |     └── train
|   |     └── val
|   |     └── test
|   |     transforms_train.json
|   |     transforms_val.json
|   |     transforms_test.json
|   |   └── ship    
|   |   └── ...
|   ├── custom
|   |   └── sceneName
|   |     └── images
|   |   └── scene.mp4
        

2、数据预处理

假设数据集提供图像和相机位姿(JSON 或其他格式),我们需加载图像并生成光线。以下为示例代码(根据实际文件格式调整):

import json
import cv2
import numpy as np
import torch

def load_data(data_dir, split='train'):
    # 假设相机参数在 transforms.json 中
    with open(f'{data_dir}/transforms_{split}.json', 'r') as f:
        meta = json.load(f)
    
    images = []
    poses = []
    for frame in meta['frames']:
        img_path = f"{data_dir}/{frame['file_path']}"
        img = cv2.imread(img_path)[..., ::-1]  # BGR to RGB
        img = img / 255.0  # 归一化到 [0, 1]
        images.append(img)
        poses.append(np.array(frame['transform_matrix']))
    
    H, W = images[0].shape[:2]
    focal = 0.5 * W / np.tan(0.5 * meta['camera_angle_x'])
    return np.array(images), np.array(poses), H, W, focal

data_dir = 'nerf_recon_dataset'
images, poses, H, W, focal = load_data(data_dir)

注意:如果数据集格式不同(如相机参数为单独文件或无 JSON),需根据 README 或文档调整加载逻辑。

四、NeRF 模型构建

1、网络架构

NeRF 使用 多层感知机(MLP),输入 5D 向量(3D 坐标 ( x , y , z ) (x, y, z) (x,y,z) 和观察方向 ( θ , ϕ ) (\theta, \phi) (θ,ϕ)),输出颜色 ( r , g , b ) (r, g, b) (r,g,b) 和体密度 σ \sigma σ

import torch
import torch.nn as nn

class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=2, output_ch=4):
        super(NeRF, self).__init__()
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        
        # 主干网络
        layers = [nn.Linear(input_ch, W)]
        for i in range(D-1):
            layers += [nn.ReLU(), nn.Linear(W, W)]
        self.main = nn.Sequential(*layers)
        
        # 跳跃连接
        self.skip = nn.Linear(input_ch + W, W)
        
        # 输出颜色和密度
        self.sigma = nn.Sequential(nn.Linear(W, 1), nn.ReLU())
        self.rgb = nn.Sequential(
            nn.Linear(W + input_ch_views, W//2),
            nn.ReLU(),
            nn.Linear(W//2, 3),
            nn.Sigmoid()
        )
    
    def forward(self, x, d):
        input_pts, input_views = x, d
        h = self.main(input_pts)
        h = torch.cat([input_pts, h], -1)
        h = self.skip(h)
        sigma = self.sigma(h)
        h = torch.cat([h, input_views], -1)
        rgb = self.rgb(h)
        return torch.cat([rgb, sigma], -1)

model = NeRF().cuda()

2、位置编码

为捕捉高频细节,对输入进行位置编码:

γ ( p ) = [ sin ⁡ ( 2 0 π p ) , cos ⁡ ( 2 0 π p ) , … , sin ⁡ ( 2 L − 1 π p ) , cos ⁡ ( 2 L − 1 π p ) ] \gamma(p) = [\sin(2^0\pi p), \cos(2^0\pi p), \ldots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)] γ(p)=[sin(20πp),cos(20πp),,sin(2L1πp),cos(2L1πp)]

def positional_encoding(x, L):
    out = [x]
    for i in range(L):
        out.append(torch.sin(2**i * np.pi * x))
        out.append(torch.cos(2**i * np.pi * x))
    return torch.cat(out, dim=-1)

pos_enc_L = 10  # 位置编码层数
view_enc_L = 4  # 方向编码层数

五、训练过程

1、光线采样与体渲染

生成光线并进行体渲染:

def get_rays(H, W, focal, pose):
    i, j = torch.meshgrid(
        torch.linspace(0, W-1, W),
        torch.linspace(0, H-1, H)
    )
    i, j = i.cuda(), j.cuda()
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], -1)
    rays_o = pose[:3, -1].expand_as(rays_d)
    return rays_o, rays_d

def render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=64):
    t_vals = torch.linspace(near, far, N_samples).cuda()
    z_vals = near + (far - near) * torch.rand(rays_o.shape[:-1] + (N_samples,)).cuda()
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    
    # 位置编码
    pts_flat = pts.reshape(-1, 3)
    dirs_flat = rays_d.reshape(-1, 3)
    encoded_pts = positional_encoding(pts_flat, pos_enc_L)
    encoded_dirs = positional_encoding(dirs_flat, view_enc_L)
    
    # 前向传播
    outputs = model(encoded_pts, encoded_dirs)
    outputs = outputs.reshape(*pts.shape[:-1], 4)
    rgb, sigma = outputs[..., :3], outputs[..., 3]
    
    # 体渲染
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    delta_inf = 1e10 * torch.ones_like(deltas[..., :1])
    deltas = torch.cat([deltas, delta_inf], -1)
    alpha = 1. - torch.exp(-sigma * deltas)
    weights = alpha * torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1.-alpha + 1e-10], -1), -1)[..., :-1]
    rgb_map = torch.sum(weights[..., None] * rgb, -2)
    return rgb_map

rays_o, rays_d = get_rays(H, W, focal, torch.tensor(poses[0]).cuda())

2、训练循环

使用 MSE 损失函数优化模型:

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=5e-4)
N_iter = 20000  # 迭代次数
batch_size = 4096  # 每批光线数

for i in range(N_iter):
    img_idx = np.random.randint(len(images))
    target = torch.tensor(images[img_idx]).cuda()
    rays_o, rays_d = get_rays(H, W, focal, torch.tensor(poses[img_idx]).cuda())
    
    ray_idx = torch.randperm(H*W)[:batch_size]
    rays_o_batch = rays_o.reshape(-1, 3)[ray_idx]
    rays_d_batch = rays_d.reshape(-1, 3)[ray_idx]
    target_batch = target.reshape(-1, 3)[ray_idx]
    
    rgb_map = render_rays(model, rays_o_batch, rays_d_batch)
    loss = ((rgb_map - target_batch) ** 2).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if i % 1000 == 0:
        print(f"Iter {i}, Loss: {loss.item():.4f}")

注意:根据数据集场景范围调整 nearfar 参数。

六、推理与渲染

1、生成新视角图像

渲染新视角图像:

def render_image(model, H, W, focal, pose):
    rays_o, rays_d = get_rays(H, W, focal, torch.tensor(pose).cuda())
    rgb_map = render_rays(model, rays_o.reshape(-1, 3), rays_d.reshape(-1, 3))
    return rgb_map.reshape(H, W, 3).cpu().numpy()

test_pose = poses[len(images)//2]  # 示例:中间视角
img = render_image(model, H, W, focal, test_pose)

import matplotlib.pyplot as plt
plt.imshow(img)
plt.axis('off')
plt.savefig('render_result.png')
plt.show()

2、生成视频

插值相机路径生成视频:

import imageio

frames = []
for t in np.linspace(0, 1, 120):
    pose = poses[0] * (1-t) + poses[-1] * t  # 线性插值
    img = render_image(model, H, W, focal, pose)
    frames.append((img * 255).astype(np.uint8))

imageio.mimsave('nerf_video.mp4', frames, fps=30)

在这里插入图片描述在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

七、优化与扩展

  1. 层次采样:使用粗细网络优化采样效率。
  2. 正则化:添加 L2 正则防止过拟合。
  3. 动态场景:若数据集包含动态元素,可引入时间维度或变形场。
  4. 加速:尝试 Instant-NGP 等高效变体。

八、总结

本教程基于 nerf_recon_dataset 实现了 NeRF 三维重建,涵盖环境搭建、数据处理、模型构建和训练推理。完成基础实现后,您可根据数据集特性调整参数,或探索更高级技术。欢迎查阅 ModelScope 文档以优化代码。


延伸阅读

  • AI Agent 系列文章


  • 计算机视觉系列文章


  • 机器学习核心算法系列文章


  • 深度学习系列文章



http://www.kler.cn/a/582438.html

相关文章:

  • MySQL库和表的操作详解:从创建库到表的管理全面指南
  • 45.HarmonyOS NEXT Layout布局组件系统详解(十二):高级应用案例与性能优化
  • 无标记点动作捕捉系统,无需穿戴设备,摄像头智能采集人体运动姿态
  • Webpack 优化深度解析:从构建性能到输出优化的全面指南
  • TDengine SQL 函数
  • JVM和运行时数据区
  • 国产化信创操作系统的电脑,能运行windows程序吗
  • 关于回归中R2指标的理解
  • docker搭建elk
  • 【学写LibreCAD】 4.1 RS_Undoable文件
  • 【Linux内核系列】:文件系统
  • 一文说清docker及docker compose的应用和部署
  • UI显示不出来问题(有的能显示出来一个方法,有的数据显示不出来另一个方法),多次尝试无果
  • CSPM-3级国标认证,项目管理如何成为组织变革的核心引擎?
  • 裂变营销策略在“开源链动2+1模式AI智能名片S2B2C商城小程序”中的应用探索
  • JavaScript性能优化实战:让你的Web应用飞起来
  • AI+API引爆数据分析:BI已成过去?
  • 【漫话机器学习系列】133.决定系数(R²:Coefficient of Determination)
  • 微电网管理 实现分布式能源的智能调度和管理
  • ROS——节点、工作空间、功能包