
[Paper Walkthrough] PVG: Periodic Vibration Gaussian

PVG: Periodic Vibration Gaussian


Project page: https://fudan-zvg.github.io/PVG/

Paper: https://arxiv.org/abs/2311.18561

Source code: https://github.com/fudan-zvg/PVG


Paper summary

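Periodic Vibration Gaussian (PVG) is mentioned frequently in the 3DGS literature and is often cited and used as an experimental baseline. Its metrics are also strong: at the time of release it reached PSNR ≥ 30 (although PSNR ≥ 30 has since become common in 3D reconstruction work).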

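PVG adds periodic and temporal attributes on top of vanilla 3DGS, so it handles dynamic scenes fairly well and can model scenes whose state changes over time. DeformableGS, which we covered earlier, also introduces a time attribute. The main difference between the two is that DeformableGS feeds a time encoding through a small network to modify the Gaussian attributes, whereas PVG takes the direct route: it defines the Gaussian attributes as explicit functions of time and fits them by gradient-based optimization with automatic differentiation. Loosely speaking, I would say DeformableGS leans toward a learning-based method and PVG toward a traditional optimization scheme.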

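Source code walkthrough

Model initialization

Once again the main thing to look at is how the Gaussians are constructed; most methods add their new features here by defining their own Gaussian attributes.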

def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
    # ...... omitted ......
    # Timestamp handling: initialize randomly when the point cloud has no timestamps
    if pcd.time is None or pcd.time.shape[0] != fused_point_cloud.shape[0]:
        if pcd.time is None:
            time = (np.random.rand(pcd.points.shape[0], 1) * 1.2 - 0.1) * (
                    self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
        else:
            time = pcd.time

        if self.t_init < 1:
            random_times = (torch.rand(fused_point_cloud.shape[0]-pcd.points.shape[0], 1, device="cuda") * 1.2 - 0.1) * (
                    self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
            pts_times = torch.from_numpy(time.copy()).float().cuda()
            fused_times = torch.cat([pts_times, random_times], dim=0)
        else:
            fused_times = torch.full_like(fused_point_cloud[..., :1],
                                            0.5 * (self.time_duration[1] + self.time_duration[0]))
    else:
        # Points in pts_sph are assigned the midpoint of the time range
        fused_times = torch.from_numpy(np.asarray(pcd.time.copy())).cuda().float()
        fused_times_sh = torch.full_like(pts_sph[..., :1], 0.5 * (self.time_duration[1] + self.time_duration[0]))
        fused_times = torch.cat([fused_times, fused_times_sh], dim=0)

    print("Number of points at initialization : ", fused_point_cloud.shape[0])

    dist2 = torch.clamp_min(distCUDA2(fused_point_cloud), 0.0000001)
    scales = self.scaling_inverse_activation(torch.sqrt(dist2))[..., None].repeat(1, 3)

    rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
    rots[:, 0] = 1
    
    dist_t = torch.full_like(fused_times, (self.time_duration[1] - self.time_duration[0])*self.t_init)
    # Initialize the temporal scale
    scales_t = self.scaling_t_inverse_activation(torch.sqrt(dist_t))
    # Initialize the velocity to zero
    velocity = torch.full((fused_point_cloud.shape[0], 3), 0., device="cuda")
    
    opacities = inverse_sigmoid(0.01 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

    self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
    self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
    self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
    self._scaling = nn.Parameter(scales.requires_grad_(True))
    self._rotation = nn.Parameter(rots.requires_grad_(True))
    self._opacity = nn.Parameter(opacities.requires_grad_(True))
    self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
    # Newly added attributes
    # Life peak \tau
    self._t = nn.Parameter(fused_times.requires_grad_(True))
    # Scale along the time axis
    self._scaling_t = nn.Parameter(scales_t.requires_grad_(True))
    # Velocity v
    self._velocity = nn.Parameter(velocity.requires_grad_(True))
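
To make the random timestamp branch concrete, here is a minimal pure-Python sketch of the same arithmetic: timestamps are drawn uniformly from the time range padded by 10% on each side, and the temporal scale starts from sqrt(duration * t_init). The function names and the example values for the time range and t_init are my own illustrative assumptions; the inverse activation applied afterwards is left out because it is not shown in the excerpt.

```python
import math
import random

def sample_init_time(time_duration, rng=random):
    """Draw one timestamp uniformly from the time range padded by 10% per side."""
    t0, t1 = time_duration
    u = rng.random() * 1.2 - 0.1            # u lies in [-0.1, 1.1)
    return u * (t1 - t0) + t0

def init_temporal_std(time_duration, t_init):
    """Value handed to the inverse activation: sqrt(duration * t_init)."""
    t0, t1 = time_duration
    return math.sqrt((t1 - t0) * t_init)

duration = (0.0, 1.0)                       # hypothetical normalized time range
print([round(sample_init_time(duration), 3) for _ in range(5)])
print(init_temporal_std(duration, t_init=0.25))
```

The padding means some Gaussians start with a life peak slightly outside the observed time range, which is visible directly in the `* 1.2 - 0.1` term of the original code.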

Model training

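  • Randomness of the time shift: this is the temporal smoothing mentioned in the paper, applied with probability 0.5 over 1.5 frames. In the code, the shift is drawn uniformly from ±1.5 × scene.time_interval with probability args.lambda_self_supervision.
# training function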
if np.random.random() < args.lambda_self_supervision:
    time_shift = 3*(np.random.random() - 0.5) * scene.time_interval
else:
    time_shift = None
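
The sampling rule above can be isolated in a few lines. This is a sketch with a hypothetical helper name (`sample_time_shift`); `prob` stands in for args.lambda_self_supervision and `time_interval` for scene.time_interval.

```python
import random

def sample_time_shift(time_interval, prob, rng=random):
    """Return a shift in [-1.5, 1.5) * time_interval with probability `prob`, else None."""
    if rng.random() < prob:
        return 3 * (rng.random() - 0.5) * time_interval
    return None

rng = random.Random(0)                      # fixed seed only to make the demo repeatable
shifts = [sample_time_shift(0.1, 0.5, rng) for _ in range(8)]
print(shifts)
```

Roughly half of the entries come back as None (no shift), and the rest never exceed 1.5 frame intervals in magnitude.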
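  • Effect of the periodic and temporal attributes on the means and opacity: before rasterization, the periodic attributes are applied to means3D and opacity.
# render function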
if time_shift is not None:
    means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp-time_shift)
    means3D = means3D + pc.get_inst_velocity * time_shift
    marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp-time_shift)
else:
    means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp)
    marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp)
# Corresponds to Eq. 7
opacity = opacity * marginal_t
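
One way to read the shifted branch: the mean is evaluated at an earlier (or later) time and then advanced by velocity × shift, so for a point moving at constant velocity the shifted and unshifted branches agree. The sketch below checks that on a scalar toy example; `shifted_mean`, `linear` and the numbers are illustrative assumptions, not code from the repository.

```python
def shifted_mean(position_at, inst_velocity, timestamp, time_shift=None):
    """Mirror of the two render branches for a single scalar coordinate."""
    if time_shift is None:
        return position_at(timestamp)
    # Evaluate at the shifted time, then move forward linearly by the shift.
    return position_at(timestamp - time_shift) + inst_velocity * time_shift

x0, v = 2.0, 0.5                            # hypothetical start position and velocity

def linear(t):
    """Toy trajectory with constant velocity v."""
    return x0 + v * t

direct = shifted_mean(linear, v, timestamp=1.0)
shifted = shifted_mean(linear, v, timestamp=1.0, time_shift=0.3)
print(direct, shifted)                      # the two agree up to floating-point rounding
```

For the actual sinusoidal trajectory the two branches only agree approximately, which is presumably why the shift is kept within 1.5 frames.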
  • Implementation of the periodic vibration function, in the gaussian_model file
# Corresponds to Eq. 6 in the paper
def get_xyz_SHM(self, t):
    a = 1/self.T * np.pi * 2
    return self._xyz + self._velocity*torch.sin((t-self._t)*a)/a

# Corresponds to an equation in the paper (number not identified)
@property
def get_inst_velocity(self):
    return self._velocity*torch.exp(-self.get_scaling_t/self.T/2*self.velocity_decay)

# Corresponds to the exp term of Eq. 7 in the paper
def get_marginal_t(self, timestamp):
    return torch.exp(-0.5 * (self.get_t - timestamp) ** 2 / self.get_scaling_t ** 2)
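
To get a feel for these three functions, the following scalar re-implementation uses only the standard library. It is a sketch, not the repository code: `cycle` stands in for self.T, `tau` for self._t, `scale_t` for get_scaling_t, and all numeric values are made up for illustration.

```python
import math

def xyz_shm(xyz, velocity, tau, cycle, t):
    """Position at time t: xyz + velocity * sin(a * (t - tau)) / a, with a = 2*pi / cycle."""
    a = 2 * math.pi / cycle
    return xyz + velocity * math.sin((t - tau) * a) / a

def marginal_t(tau, scale_t, t):
    """Temporal Gaussian weight exp(-0.5 * (tau - t)^2 / scale_t^2) that multiplies the opacity."""
    return math.exp(-0.5 * (tau - t) ** 2 / scale_t ** 2)

def inst_velocity(velocity, scale_t, cycle, velocity_decay=1.0):
    """Velocity damped by exp(-scale_t / cycle / 2 * velocity_decay)."""
    return velocity * math.exp(-scale_t / cycle / 2 * velocity_decay)

xyz, vel, tau, cycle, scale_t = 1.0, 2.0, 0.3, 0.2, 0.1
for t in (0.3, 0.31, 0.35, 0.5):
    print(t, xyz_shm(xyz, vel, tau, cycle, t), marginal_t(tau, scale_t, t))
```

At t = tau the point sits exactly at xyz with full opacity weight; the displacement peaks at velocity * cycle / (2π) a quarter cycle later, while the opacity weight falls off as t moves away from tau.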
    
  • Density control of the Gaussians

On top of vanilla 3DGS, an average-gradient threshold on the temporal attribute \tau is added to control densification of the Gaussians.

Pruning of the Gaussians is essentially the same as in the original implementation, so it is not repeated here.

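# clone function
if time_clone:
    # Temporal threshold: compare the 2-norm of the gradient with the threshold
    selected_time_mask = torch.where(torch.norm(grads_t, dim=-1) >= grad_t_threshold, True, False)
    # Limit on the temporal scaling: cloning still requires the scale to be at or below the threshold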
    extend_thresh = self.percent_dense
    selected_time_mask = torch.logical_and(selected_time_mask,
                                           torch.max(self.get_scaling_t, dim=1).values <= extend_thresh)
    selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)

# split function
if time_split:
    # The gradient is compared with the threshold directly
    padded_grad_t = torch.zeros((n_init_points), device="cuda")
    padded_grad_t[:grads_t.shape[0]] = grads_t.squeeze()
    selected_time_mask = torch.where(padded_grad_t >= grad_t_threshold, True, False)
    extend_thresh = self.percent_dense
    # Splitting requires the scale to exceed the threshold, same as the original
    selected_time_mask = torch.logical_and(selected_time_mask,
                                           torch.max(self.get_scaling_t, dim=1).values > extend_thresh)
    if joint_sample:
        selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)
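
The temporal part of the clone/split decision boils down to one gradient test plus a scale test that goes in opposite directions for the two operations. Below is a small list-based sketch of just that logic; the helper name and the sample numbers are assumptions for illustration, and the OR with the spatial mask (selected_pts_mask) is left out.

```python
def temporal_densify_masks(grad_t_norms, scales_t, grad_t_threshold, extend_thresh):
    """Return (clone_mask, split_mask) as lists of booleans, one entry per Gaussian."""
    clone_mask, split_mask = [], []
    for g, s in zip(grad_t_norms, scales_t):
        high_grad = g >= grad_t_threshold
        clone_mask.append(high_grad and s <= extend_thresh)   # small temporal scale: clone
        split_mask.append(high_grad and s > extend_thresh)    # large temporal scale: split
    return clone_mask, split_mask

grads = [0.5, 0.5, 0.01]                    # hypothetical gradient norms w.r.t. t
scales = [0.005, 0.2, 0.005]                # hypothetical temporal scales
clone_mask, split_mask = temporal_densify_masks(
    grads, scales, grad_t_threshold=0.1, extend_thresh=0.01)
print(clone_mask)   # [True, False, False]
print(split_mask)   # [False, True, False]
```

Because the scale tests are complementary, a Gaussian can be picked by the temporal clone rule or the temporal split rule, but never by both.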

Loss function setup

Following the paper, the training function implements several loss terms.

# LiDAR depth loss (L1 on inverse depth, with optional exponential decay over iterations)
if args.lambda_lidar > 0:
    assert viewpoint_cam.pts_depth is not None
    pts_depth = viewpoint_cam.pts_depth.cuda()

    mask = pts_depth > 0
    loss_lidar =  torch.abs(1 / (pts_depth[mask] + 1e-5) - 1 / (depth[mask] + 1e-5)).mean()
    if args.lidar_decay > 0:
        iter_decay = np.exp(-iteration / 8000 * args.lidar_decay)
    else:
        iter_decay = 1
    log_dict['loss_lidar'] = loss_lidar.item()
    loss += iter_decay * args.lambda_lidar * loss_lidar

# Temporal regularization on t_map, taken from the rendered feature; the negative sign means minimizing it increases |t_map|
if args.lambda_t_reg > 0:
    loss_t_reg = -torch.abs(t_map).mean()
    log_dict['loss_t_reg'] = loss_t_reg.item()
    loss += args.lambda_t_reg * loss_t_reg

# Velocity regularization on v_map, taken from the rendered feature; an L1 penalty that keeps velocities small
if args.lambda_v_reg > 0:
    loss_v_reg = torch.abs(v_map).mean()
    log_dict['loss_v_reg'] = loss_v_reg.item()
    loss += args.lambda_v_reg * loss_v_reg

# Inverse-depth smoothness
if args.lambda_inv_depth > 0:
    inverse_depth = 1 / (depth + 1e-5)
    loss_inv_depth = kornia.losses.inverse_depth_smoothness_loss(inverse_depth[None], gt_image[None])
    log_dict['loss_inv_depth'] = loss_inv_depth.item()
    loss = loss + args.lambda_inv_depth * loss_inv_depth

# Velocity smoothness (edge-aware): penalizes velocity changes where the image itself is smooth
if args.lambda_v_smooth > 0:
    loss_v_smooth = kornia.losses.inverse_depth_smoothness_loss(v_map[None], gt_image[None])
    log_dict['loss_v_smooth'] = loss_v_smooth.item()
    loss = loss + args.lambda_v_smooth * loss_v_smooth

# Sky opacity loss: pushes the accumulated alpha toward 0 in pixels covered by sky_mask
if args.lambda_sky_opa > 0:
    o = alpha.clamp(1e-6, 1-1e-6)
    sky = sky_mask.float()
    loss_sky_opa = (-sky * torch.log(1 - o)).mean()
    log_dict['loss_sky_opa'] = loss_sky_opa.item()
    loss = loss + args.lambda_sky_opa * loss_sky_opa

# Opacity entropy loss: -o*log(o) vanishes at o = 0 and o = 1, so minimizing it pushes alpha toward those extremes
if args.lambda_opacity_entropy > 0:
    o = alpha.clamp(1e-6, 1 - 1e-6)
    loss_opacity_entropy =  -(o*torch.log(o)).mean()
    log_dict['loss_opacity_entropy'] = loss_opacity_entropy.item()
    loss = loss + args.lambda_opacity_entropy * loss_opacity_entropy
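
To see how a few of these terms behave numerically, here is a list-based sketch of the LiDAR term with its decay factor, the sky opacity term and the opacity entropy term. The function names and sample values are illustrative assumptions; the empty-mask guard in `lidar_term` is a choice made for this sketch only and is not present in the excerpt above.

```python
import math

def lidar_term(pts_depth, depth, eps=1e-5):
    """Mean absolute inverse-depth error over pixels with a valid LiDAR depth (> 0)."""
    pairs = [(p, d) for p, d in zip(pts_depth, depth) if p > 0]
    if not pairs:                           # sketch-only guard for an empty mask
        return 0.0
    return sum(abs(1 / (p + eps) - 1 / (d + eps)) for p, d in pairs) / len(pairs)

def lidar_decay_factor(iteration, lidar_decay):
    """exp(-iteration / 8000 * lidar_decay) when decay is enabled, otherwise 1."""
    return math.exp(-iteration / 8000 * lidar_decay) if lidar_decay > 0 else 1.0

def clamp(a, lo=1e-6):
    return min(max(a, lo), 1 - lo)

def sky_opacity_term(alpha, sky_mask):
    """Mean of -sky * log(1 - alpha): large when sky pixels are opaque."""
    return sum(-float(s) * math.log(1 - clamp(a)) for a, s in zip(alpha, sky_mask)) / len(alpha)

def opacity_entropy_term(alpha):
    """Mean of -alpha * log(alpha): zero at alpha = 0 or 1, largest near alpha = 1/e."""
    return sum(-clamp(a) * math.log(clamp(a)) for a in alpha) / len(alpha)

print(lidar_term([0.0, 2.0, 4.0], [5.0, 2.0, 5.0]))
print(lidar_decay_factor(8000, 1.0))
print(sky_opacity_term([0.9, 0.9], [1, 0]), sky_opacity_term([0.1, 0.9], [1, 0]))
print(opacity_entropy_term([0.5]), opacity_entropy_term([0.99]))
```

The sky term drops as alpha in the sky pixel goes from 0.9 to 0.1, and the entropy term is much smaller for a nearly opaque pixel than for a half-transparent one, which matches the direction each loss pushes the rendered alpha.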


