[论文解读]Street Gaussians: Modeling Dynamic Urban Scenes with Gaussian Splatting
Street Gaussians是年初的一篇动态场景重建论文, 在当时是做到了SOTA,至今为止很多自动驾驶或者动态场景重建的文章都会将Street Gaussians作为实验的比较对象,这也表明了这篇文章的重要性,今天就一起来看看这篇文章;
论文地址:[2401.01339] Street Gaussians: Modeling Dynamic Urban Scenes with Gaussian Splatting
源码链接:zju3dv/street_gaussians: [ECCV 2024] Street Gaussians: Modeling Dynamic Urban Scenes with Gaussian Splatting
项目主页:Street Gaussians
论文解读
这篇论文《Street Gaussians: Modeling Dynamic Urban Scenes with Gaussian Splatting》提出了一种名为Street Gaussians的新方法,用于解决自动驾驶场景中动态城市街道建模问题,该方法利用点云构建动态场景,显著提高了训练和渲染效率,实现了高效的实时渲染和高质量的视图合成。
- 研究背景与目的
- 背景:从图像建模动态3D街道在城市模拟、自动驾驶和游戏等领域有重要应用,但现有方法如基于NeRF的方法在处理大规模场景时存在训练时间长、无法处理动态车辆或实时渲染效果差等问题。
- 目的:提出一种新的场景表示方法Street Gaussians,能够高效重建和实时渲染高保真的动态城市街道场景。
- 方法原理
- Street Gaussians表示
- 背景模型:用世界坐标系中的点集表示,每个点带有3D高斯参数(协方差矩阵、位置向量)、透明度、球谐系数和语义logit,协方差矩阵由缩放矩阵和旋转矩阵恢复。
- 对象模型:每个前景车辆对象由一组可优化的跟踪车辆姿态和点云表示,点云的高斯属性与背景相似但坐标系不同,位置和旋转在对象局部坐标系,通过跟踪姿态转换到世界坐标系;外观模型使用4D球谐模型,用傅里叶变换系数编码时间信息以降低存储成本;语义表示为可学习的一维标量。
- 初始化:背景模型使用聚合LiDAR点云初始化,颜色通过投影到图像平面获取,对象模型收集3D边界框内的点并转换到局部坐标系,点数不足则随机采样,背景模型还进行体素下采样并结合SfM点云补偿LiDAR覆盖不足。
- 渲染过程:通过将所有点云投影到2D图像空间来渲染,先计算球谐函数,根据跟踪姿态转换对象点云到世界坐标系,连接背景和对象点云,计算2D高斯参数,用α - 混合计算像素颜色,还可渲染深度、语义等信号,天空区域使用高分辨率立方体贴图映射天空颜色并与渲染颜色混合。
- 训练方法
- 跟踪姿态优化:将跟踪姿态视为可学习参数,在转换矩阵中添加可学习的变换,直接获取梯度,避免额外计算。
- 损失函数:联合优化场景表示、天空立方体贴图和跟踪姿态,损失函数包括颜色重建损失、深度损失、天空监督损失、语义损失和正则化项,用于去除浮动和增强分解效果。
- Street Gaussians表示
- 实验结果
- 实验设置
- 数据集:在Waymo Open和KITTI数据集上进行实验,Waymo数据集选择8个序列,每4帧选1帧为测试帧,输入图像降为1066×1600,KITTI和Virtual KITTI 2数据集遵循MARS的设置。
- 基线方法:与NSG、MARS、3D Gaussians、EmerNeRF比较,前两者用真实对象轨迹训练和评估,后两者用官方代码和特定设置运行。
- 对比结果:在渲染质量和速度上与基线方法比较,采用PSNR、SSIM和LPIPS评估,对移动对象计算PSNR*,模型在所有指标上表现最佳,渲染速度比基于NeRF的方法快两个数量级,定性结果显示其他方法存在模糊、失真或伪影,而该方法能生成高质量视图。
- 消融实验:验证算法设计选择,优化跟踪姿态可提高质量,4D球谐模型能细化渲染质量,尤其是在对象与环境光照交互时,结合LiDAR点云可增强结果,恢复更准确的场景几何形状。
- 应用展示:可应用于场景编辑(如车辆旋转、平移、交换)、对象分解(生成高质量分解结果)和语义分割(语义图性能优于Video - K - Net)。
- 实验设置
- 结论与展望
- 研究成果总结:提出Street Gaussians场景表示,将背景和前景车辆分别建模为点云,可实现场景编辑和实时渲染,性能与使用精确真实姿态相当,通过实验验证了方法的有效性。
- 研究局限与未来方向:方法限于重建刚性动态场景,依赖现成跟踪器的召回率,仍需逐场景优化;未来可考虑采用更复杂动态场景建模方法处理非刚性对象,通过2D跟踪获取连续轨迹改善跟踪问题,探索前馈方式预测通用3D高斯。
代码解读
初始化
- 场景初始化从点云开始,
- 需要对每个移动障碍物进行track,得到每个障碍物在每个相机下的轨迹。
模型设计
整个模型的结构可以参考下图或者**StreetGaussianModel
**的setup_functions
函数,主要包括背景,前景和天空三类;
def setup_functions(self):
obj_tracklets = self.metadata['obj_tracklets']
obj_info = self.metadata['obj_meta']
tracklet_timestamps = self.metadata['tracklet_timestamps']
camera_timestamps = self.metadata['camera_timestamps']
self.model_name_id = bidict()
self.obj_list = []
self.models_num = 0
self.obj_info = obj_info
# Build background model
if self.include_background:
self.background = GaussianModelBkgd(
model_name='background',
scene_center=self.metadata['scene_center'],
scene_radius=self.metadata['scene_radius'],
sphere_center=self.metadata['sphere_center'],
sphere_radius=self.metadata['sphere_radius'],
)
self.model_name_id['background'] = 0
self.models_num += 1
# Build object model
if self.include_obj:
for track_id, obj_meta in self.obj_info.items():
model_name = f'obj_{track_id:03d}'
setattr(self, model_name, GaussianModelActor(model_name=model_name, obj_meta=obj_meta))
self.model_name_id[model_name] = self.models_num
self.obj_list.append(model_name)
self.models_num += 1
# Build sky model
if self.include_sky:
self.sky_cubemap = SkyCubeMap()
else:
self.sky_cubemap = None
# Build actor model 动态物体的pose优化
if self.include_obj:
self.actor_pose = ActorPose(obj_tracklets, tracklet_timestamps, camera_timestamps, obj_info)
else:
self.actor_pose = None
# Build color correction 未启用,类似于曝光补偿
if self.use_color_correction:
self.color_correction = ColorCorrection(self.metadata)
else:
self.color_correction = None
# Build pose correction, 未启用,位姿优化
if self.use_pose_correction:
self.pose_correction = PoseCorrection(self.metadata)
else:
self.pose_correction = None
静态背景
静态背景的表示还是基础的3DGS来进行建模,在基础的3DGS属性上面增加了一个语义(semantic)属性,为了后面的3D语义特征构建(本文暂未使用,也许是长线计划);
模型代码在**GaussianModelBkgd
类中,继承自GaussianModel
**类,和作者论文中一致,这是一个基础的高斯类,这里不再详细展开
动态前景
对每一个动态物体都构建一个高斯模型,模型代码在**GaussianModelActor
**
- 傅里叶变换
def IDFT(time, dim):
if isinstance(time, float):
time = torch.tensor(time)
t = time.view(-1, 1).float()
idft = torch.zeros(t.shape[0], dim)
indices = torch.arange(dim)
even_indices = indices[::2]
odd_indices = indices[1::2]
idft[:, even_indices] = torch.cos(torch.pi * t * even_indices)
idft[:, odd_indices] = torch.sin(torch.pi * t * (odd_indices + 1))
return idft
# 这里在获取颜色的时候调用
def get_features_fourier(self, frame=0):
normalized_frame = (frame - self.start_frame) / (self.end_frame - self.start_frame)
time = self.fourier_scale * normalized_frame
idft_base = IDFT(time, self.fourier_dim)[0].cuda()
features_dc = self._features_dc # [N, C, 3]
features_dc = torch.sum(features_dc * idft_base[..., None], dim=1, keepdim=True) # [N, 1, 3]
features_rest = self._features_rest # [N, sh, 3]
features = torch.cat([features_dc, features_rest], dim=1) # [N, (sh + 1) * C, 3]
return features
- 动态物体的位姿优化,开启
opt_track
的选项后在**ActorPose
**模块中会开始动态物体的位姿优化,在训练过程中优化动态物体的平移和旋转两个参数;在OmniRe的论文中也有,这也是目前动态3DGS重建的主流方向,不过动态物体的位姿处理也有很多方法,可以放在数据前处理阶段用传统的slam等方式,但是在训练中优化的方法都认为可以获得更好的重建效果;
self.opt_track = cfg.model.nsg.opt_track
if self.opt_track:
self.opt_trans = nn.Parameter(torch.zeros_like(self.input_trans)).requires_grad_(True)
# [num_frames, max_obj, [dx, dy, dz]]
self.opt_rots = nn.Parameter(torch.zeros_like(self.input_rots[..., :1])).requires_grad_(True)
# [num_frames, max_obj, [dtheta]
- 由于object的位姿和背景是不同的,而且opt track的作用也会优化位姿,所以论文对obj的gaussian做了一些刚体变换,也就是公式2;对应
StreetGaussianModel
的get_xyz
和get_rotation
两个函数;这里只贴上核心处理部分
def get_xyz(self):
......
if len(self.graph_obj_list) > 0:
xyzs_local = []
for i, obj_name in enumerate(self.graph_obj_list):
obj_model: GaussianModelActor = getattr(self, obj_name)
xyz_local = obj_model.get_xyz
xyzs_local.append(xyz_local)
xyzs_local = torch.cat(xyzs_local, dim=0)
xyzs_local = xyzs_local.clone()
# 这里的filp是借鉴的AutoSplat论文中的对称建模的方法,行人会关闭
xyzs_local[self.flip_mask, self.flip_axis] *= -1
obj_rots = quaternion_to_matrix(self.obj_rots)
xyzs_obj = torch.einsum('bij, bj -> bi', obj_rots, xyzs_local) + self.obj_trans
xyzs.append(xyzs_obj)
xyzs = torch.cat(xyzs, dim=0)
return xyzs
def get_rotation(self):
......
if len(self.graph_obj_list) > 0:
rotations_local = []
for i, obj_name in enumerate(self.graph_obj_list):
obj_model: GaussianModelActor = getattr(self, obj_name)
rotation_local = obj_model.get_rotation
rotations_local.append(rotation_local)
rotations_local = torch.cat(rotations_local, dim=0)
rotations_local = rotations_local.clone()
rotations_local[self.flip_mask] = quaternion_raw_multiply(self.flip_matrix, rotations_local[self.flip_mask])
rotations_obj = quaternion_raw_multiply(self.obj_rots, rotations_local)
rotations_obj = torch.nn.functional.normalize(rotations_obj)
rotations.append(rotations_obj)
rotations = torch.cat(rotations, dim=0)
return rotations
loss设计
这部分代码在train.py
的training
函数,写的很清楚,就不再贴代码了;这几个loss项和regulation也是目前三维重建中常见的设计;
这里对代码里面加入的几个loss但是论文中没有提及的,做个讨论;
- 颜色修正:颜色的正则项是在对相机的仿射变换矩阵进行修正,感觉有点像appearance embedding的操作,进行一定的曝光补偿;
# color correction loss,
if optim_args.lambda_color_correction > 0 and gaussians.use_color_correction:
color_correction_reg_loss = gaussians.color_correction.regularization_loss(viewpoint_cam)
scalar_dict['color_correction_reg_loss'] = color_correction_reg_loss.item()
loss += optim_args.lambda_color_correction * color_correction_reg_loss
- 位置loss:这部分是需要开启位置校正的情况下,在训练中修正每个相机到world系的位姿,这里的正则是将位姿的修正值最小化;
# pose correction loss
if optim_args.lambda_pose_correction > 0 and gaussians.use_pose_correction:
pose_correction_reg_loss = gaussians.pose_correction.regularization_loss()
scalar_dict['pose_correction_reg_loss'] = pose_correction_reg_loss.item()
loss += optim_args.lambda_pose_correction * pose_correction_reg_loss
- scale扁平化,主轴的平均绝对值不要逼近0,同时保证s2和s3两个比例接近,避免场景扭曲;
def scale_flatten_loss(self):
scales = self.get_scaling
sorted_scales = torch.sort(scales, dim=1, descending=False).values
s1, s2, s3 = sorted_scales[:, 0], sorted_scales[:, 1], sorted_scales[:, 2]
s1 = torch.clamp(s1, 0, 30)
s2 = torch.clamp(s2, 1e-5, 30)
s3 = torch.clamp(s3, 1e-5, 30)
scale_flatten_loss = torch.abs(s1).mean()
scale_flatten_loss += torch.abs(s2 / s3 + s3 / s2 - 2.).mean()
return scale_flatten_loss
# scale flatten loss
if optim_args.lambda_scale_flatten > 0:
scale_flatten_loss = gaussians.background.scale_flatten_loss()
scalar_dict['scale_flatten_loss'] = scale_flatten_loss.item()
loss += optim_args.lambda_scale_flatten * scale_flatten_loss
- 不透明度loss,利用二维交叉熵损失鼓励不透明度趋近于0或者1,避免中间状态
# opacity sparse loss
if optim_args.lambda_opacity_sparse > 0:
opacity = gaussians.get_opacity
opacity = opacity.clamp(1e-6, 1-1e-6)
log_opacity = opacity * torch.log(opacity)
log_one_minus_opacity = (1-opacity) * torch.log(1 - opacity)
sparse_loss = -1 * (log_opacity + log_one_minus_opacity)[visibility_filter].mean()
scalar_dict['opacity_sparse_loss'] = sparse_loss.item()
loss += optim_args.lambda_opacity_sparse * sparse_loss
- 法线监督:和GaussianPro文章中提到的法线监督很像,估计也是借鉴了一下,这种法线针对于路面或者深度确定的平面应该效果很好,这个真值比较重要,需要置信度比较高的数据,这个其实和文章中的lidar depth应该关系比较紧密,因为lidar的深度数据是比较可信的,如果能够用这种数据计算法线进行监督应该比较好的方法;
# normal loss
if optim_args.lambda_normal_mono > 0 and 'mono_normal' in viewpoint_cam.meta and 'normals' in render_pkg:
if sky_mask is None:
normal_mask = mask
else:
normal_mask = torch.logical_and(mask, ~sky_mask)
normal_mask = normal_mask.squeeze(0)
normal_mask[:50] = False
normal_gt = viewpoint_cam.meta['mono_normal'].permute(1, 2, 0).cuda() # [H, W, 3]
R_c2w = viewpoint_cam.world_view_transform[:3, :3]
normal_gt = torch.matmul(normal_gt, R_c2w.T) # to world space
normal_pred = render_pkg['normals'].permute(1, 2, 0) # [H, W, 3]
normal_l1_loss = torch.abs(normal_pred[normal_mask] - normal_gt[normal_mask]).mean()
normal_cos_loss = (1. - torch.sum(normal_pred[normal_mask] * normal_gt[normal_mask], dim=-1)).mean()
scalar_dict['normal_l1_loss'] = normal_l1_loss.item()
scalar_dict['normal_cos_loss'] = normal_cos_loss.item()
normal_loss = normal_l1_loss + normal_cos_loss
loss += optim_args.lambda_normal_mono * normal_loss
另外有LightWheel实验室对整个算法进行了重构,采用nerfstudio的框架,链接放在此处:https://github.com/LightwheelAI/street-gaussians-ns,感兴趣的同学可以看看;