计算机视觉实战|NeRF 实战教程:基于 nerf_recon_dataset 的三维重建
一、引言
神经辐射场(NeRF) 是一种利用神经网络从多视角图像重建 3D 场景的技术,通过隐式表示场景的几何和外观,实现高质量渲染。本教程将引导您使用 PyTorch 实现 NeRF 模型,基于 ModelScope 提供的 nerf_recon_dataset
数据集完成三维重建实战,包括环境配置、数据准备、模型构建、训练和渲染。
二、环境搭建
1、依赖安装
确保您的系统已安装 Python(推荐 3.8 或以上),然后安装以下依赖:
pip install torch torchvision numpy matplotlib opencv-python imageio gitpython
- PyTorch:深度学习框架,用于模型构建和训练。
- torchvision:图像处理工具。
- NumPy:数组操作。
- Matplotlib:可视化结果。
- OpenCV:图像处理。
- imageio:保存渲染视频。
- gitpython:从 Git 下载数据集。
2、硬件要求
- GPU(推荐):加速训练和推理,NVIDIA GPU(如 GTX 1080 或以上)配合 CUDA。
- CPU(可选):可运行,但速度较慢。
- 内存:至少 16GB,复杂场景可能需更多。
三、数据准备
1、数据集下载
我们使用 ModelScope 提供的 nerf_recon_dataset 数据集,包含多视角图像和相机参数。下载方法:
import git
import os
dataset_url = "https://www.modelscope.cn/datasets/damo/nerf_recon_dataset.git"
dataset_dir = "nerf_recon_dataset"
if not os.path.exists(dataset_dir):
git.Repo.clone_from(dataset_url, dataset_dir)
else:
print("Dataset already downloaded.")
下载后,检查目录结构(假设包含图像和相机参数文件,可能类似以下形式,具体以实际数据集为准):
nerf_recon_dataset
| ├── nerf_synthetic
| | └── lego
| | └── train
| | └── val
| | └── test
| | transforms_train.json
| | transforms_val.json
| | transforms_test.json
| | └── ship
| | └── ...
| ├── custom
| | └── sceneName
| | └── images
| | └── scene.mp4
2、数据预处理
假设数据集提供图像和相机位姿(JSON 或其他格式),我们需加载图像并生成光线。以下为示例代码(根据实际文件格式调整):
import json
import cv2
import numpy as np
import torch
def load_data(data_dir, split='train'):
# 假设相机参数在 transforms.json 中
with open(f'{data_dir}/transforms_{split}.json', 'r') as f:
meta = json.load(f)
images = []
poses = []
for frame in meta['frames']:
img_path = f"{data_dir}/{frame['file_path']}"
img = cv2.imread(img_path)[..., ::-1] # BGR to RGB
img = img / 255.0 # 归一化到 [0, 1]
images.append(img)
poses.append(np.array(frame['transform_matrix']))
H, W = images[0].shape[:2]
focal = 0.5 * W / np.tan(0.5 * meta['camera_angle_x'])
return np.array(images), np.array(poses), H, W, focal
data_dir = 'nerf_recon_dataset'
images, poses, H, W, focal = load_data(data_dir)
注意:如果数据集格式不同(如相机参数为单独文件或无 JSON),需根据 README 或文档调整加载逻辑。
四、NeRF 模型构建
1、网络架构
NeRF 使用 多层感知机(MLP),输入 5D 向量(3D 坐标 ( x , y , z ) (x, y, z) (x,y,z) 和观察方向 ( θ , ϕ ) (\theta, \phi) (θ,ϕ)),输出颜色 ( r , g , b ) (r, g, b) (r,g,b) 和体密度 σ \sigma σ。
import torch
import torch.nn as nn
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=2, output_ch=4):
super(NeRF, self).__init__()
self.input_ch = input_ch
self.input_ch_views = input_ch_views
# 主干网络
layers = [nn.Linear(input_ch, W)]
for i in range(D-1):
layers += [nn.ReLU(), nn.Linear(W, W)]
self.main = nn.Sequential(*layers)
# 跳跃连接
self.skip = nn.Linear(input_ch + W, W)
# 输出颜色和密度
self.sigma = nn.Sequential(nn.Linear(W, 1), nn.ReLU())
self.rgb = nn.Sequential(
nn.Linear(W + input_ch_views, W//2),
nn.ReLU(),
nn.Linear(W//2, 3),
nn.Sigmoid()
)
def forward(self, x, d):
input_pts, input_views = x, d
h = self.main(input_pts)
h = torch.cat([input_pts, h], -1)
h = self.skip(h)
sigma = self.sigma(h)
h = torch.cat([h, input_views], -1)
rgb = self.rgb(h)
return torch.cat([rgb, sigma], -1)
model = NeRF().cuda()
2、位置编码
为捕捉高频细节,对输入进行位置编码:
γ ( p ) = [ sin ( 2 0 π p ) , cos ( 2 0 π p ) , … , sin ( 2 L − 1 π p ) , cos ( 2 L − 1 π p ) ] \gamma(p) = [\sin(2^0\pi p), \cos(2^0\pi p), \ldots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)] γ(p)=[sin(20πp),cos(20πp),…,sin(2L−1πp),cos(2L−1πp)]
def positional_encoding(x, L):
out = [x]
for i in range(L):
out.append(torch.sin(2**i * np.pi * x))
out.append(torch.cos(2**i * np.pi * x))
return torch.cat(out, dim=-1)
pos_enc_L = 10 # 位置编码层数
view_enc_L = 4 # 方向编码层数
五、训练过程
1、光线采样与体渲染
生成光线并进行体渲染:
def get_rays(H, W, focal, pose):
i, j = torch.meshgrid(
torch.linspace(0, W-1, W),
torch.linspace(0, H-1, H)
)
i, j = i.cuda(), j.cuda()
dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], -1)
rays_o = pose[:3, -1].expand_as(rays_d)
return rays_o, rays_d
def render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=64):
t_vals = torch.linspace(near, far, N_samples).cuda()
z_vals = near + (far - near) * torch.rand(rays_o.shape[:-1] + (N_samples,)).cuda()
pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
# 位置编码
pts_flat = pts.reshape(-1, 3)
dirs_flat = rays_d.reshape(-1, 3)
encoded_pts = positional_encoding(pts_flat, pos_enc_L)
encoded_dirs = positional_encoding(dirs_flat, view_enc_L)
# 前向传播
outputs = model(encoded_pts, encoded_dirs)
outputs = outputs.reshape(*pts.shape[:-1], 4)
rgb, sigma = outputs[..., :3], outputs[..., 3]
# 体渲染
deltas = z_vals[..., 1:] - z_vals[..., :-1]
delta_inf = 1e10 * torch.ones_like(deltas[..., :1])
deltas = torch.cat([deltas, delta_inf], -1)
alpha = 1. - torch.exp(-sigma * deltas)
weights = alpha * torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1.-alpha + 1e-10], -1), -1)[..., :-1]
rgb_map = torch.sum(weights[..., None] * rgb, -2)
return rgb_map
rays_o, rays_d = get_rays(H, W, focal, torch.tensor(poses[0]).cuda())
2、训练循环
使用 MSE 损失函数优化模型:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=5e-4)
N_iter = 20000 # 迭代次数
batch_size = 4096 # 每批光线数
for i in range(N_iter):
img_idx = np.random.randint(len(images))
target = torch.tensor(images[img_idx]).cuda()
rays_o, rays_d = get_rays(H, W, focal, torch.tensor(poses[img_idx]).cuda())
ray_idx = torch.randperm(H*W)[:batch_size]
rays_o_batch = rays_o.reshape(-1, 3)[ray_idx]
rays_d_batch = rays_d.reshape(-1, 3)[ray_idx]
target_batch = target.reshape(-1, 3)[ray_idx]
rgb_map = render_rays(model, rays_o_batch, rays_d_batch)
loss = ((rgb_map - target_batch) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 1000 == 0:
print(f"Iter {i}, Loss: {loss.item():.4f}")
注意:根据数据集场景范围调整 near
和 far
参数。
六、推理与渲染
1、生成新视角图像
渲染新视角图像:
def render_image(model, H, W, focal, pose):
rays_o, rays_d = get_rays(H, W, focal, torch.tensor(pose).cuda())
rgb_map = render_rays(model, rays_o.reshape(-1, 3), rays_d.reshape(-1, 3))
return rgb_map.reshape(H, W, 3).cpu().numpy()
test_pose = poses[len(images)//2] # 示例:中间视角
img = render_image(model, H, W, focal, test_pose)
import matplotlib.pyplot as plt
plt.imshow(img)
plt.axis('off')
plt.savefig('render_result.png')
plt.show()
2、生成视频
插值相机路径生成视频:
import imageio
frames = []
for t in np.linspace(0, 1, 120):
pose = poses[0] * (1-t) + poses[-1] * t # 线性插值
img = render_image(model, H, W, focal, pose)
frames.append((img * 255).astype(np.uint8))
imageio.mimsave('nerf_video.mp4', frames, fps=30)
七、优化与扩展
- 层次采样:使用粗细网络优化采样效率。
- 正则化:添加 L2 正则防止过拟合。
- 动态场景:若数据集包含动态元素,可引入时间维度或变形场。
- 加速:尝试 Instant-NGP 等高效变体。
八、总结
本教程基于 nerf_recon_dataset
实现了 NeRF 三维重建,涵盖环境搭建、数据处理、模型构建和训练推理。完成基础实现后,您可根据数据集特性调整参数,或探索更高级技术。欢迎查阅 ModelScope 文档以优化代码。
延伸阅读
-
AI Agent 系列文章
-
计算机视觉系列文章
-
机器学习核心算法系列文章
-
深度学习系列文章