【计算机图形学】3DIT的训练数据总结
3D Implicit Transporter用的是PartNet-Mobility数据集,我想用Shape2Motion数据集,但是3DIT是所有类别的数据扔一起训练的,为了避免到时候弄3DIT数据忘记了,我在这里记录一些点云数据训练的东西。方便之后用3DIT来训练BaseLine
1. 训练文件
./core/datasets/splits/train_partnet.txt
可以看出,所有数据是混合训练的,具体每个数据的意思说明如下:./core/datasets/train_dataset.py
中:obj_name, articulate_type, instance_name, name_index = self.filenames[index].split(' ')
,可以看出每一行的意思是:物体类别名,铰接类型(R为旋转T为平移),实例名称,第x帧点云
2.点云数据分析
- 读取点云文件后,pc的shape为
[31,]
,可以理解为对一个物体采31帧 - 对于
pc[i]
,其shape为[n,4]
,n
是采点的数量,前3维是物体的(x,y,z)
坐标,最后一维尚不明确
3. train_dataset.py,训练文件
- 注意对于合成数据,看的类是
class Articulated_Obj_Syn(BaseDataset)
- 对于
BaseDataset
,主要注意的就是定义了:
self.on_occupancy_num = config_public_params["on_occupancy_num"] # 表面点
self.off_occupancy_num = config_public_params["off_occupancy_num"] # 离面点
__getitem__
操作第一步是get_input
,首先基于一个中间索引(middle)
生成前后随机两帧的索引start, end
,加载了该物体点云后,通过这两帧随机索引,获取pc[start]
和pc[end]
,实际上就是起始帧的点云和结束帧的点云
# 返回这个点云的pre、middle、last三个点云的点坐标
# 物体名称、实例名称、铰接物体类型
return data[name_index_1][:, :3].astype(np.float), \
data[name_index][:, :3].astype(np.float), \
data[name_index_2][:, :3].astype(np.float), \
None, None, \
obj_name, instance_name, articulate_type
- 第二步是
prepare_input_data()
,主要就是根据config中的指定数量来采样点云
# sample
# print(f"self.max_down_sample:{self.max_down_sample}") # 1
pc_start = self.prepare_input_data(pc_start_ori, self.max_down_sample)
pc_middle = self.prepare_input_data(pc_middle_ori, 1)
pc_end = self.prepare_input_data(pc_end_ori, self.max_down_sample)
- 第三步是数据增强,主要是做旋转、扰动,能处理旋转是因为在Pipeline里使用PointNet&PointNet++做了处理,即使旋转了也没关系
if self.config_aug["do_aug"]: # 默认为True
# rotate
z_angle = np.random.uniform() * self.config_aug["rotate_angle"] / 180.0 * (np.pi)
angles_2d = [0, 0, z_angle]
# pre,middle,final在z轴随机旋转
pc_start = atomic_rotate(pc_start, angles_2d)
pc_middle = atomic_rotate(pc_middle, angles_2d) # 被裁后的点云
pc_middle_ori = atomic_rotate(pc_middle_ori, angles_2d) # 原始点云
pc_end = atomic_rotate(pc_end, angles_2d)
# jitter (Gaussian noise) -> 高斯噪声随机扰动
sigma, clip = self.config_aug["sigma"], self.config_aug["clip"]
jitter1 = np.clip(sigma * np.random.uniform(pc_start.shape[0], 3), -1 * clip, clip)
jitter2 = np.clip(sigma * np.random.uniform(pc_middle.shape[0], 3), -1 * clip, clip)
jitter3 = np.clip(sigma * np.random.uniform(pc_end.shape[0], 3), -1 * clip, clip)
pc_start += jitter1
pc_middle += jitter2
pc_end += jitter3
- 第四步对点云数据归一化、缩放
# normalize
bound_max = np.maximum(pc_start.max(0), pc_middle.max(0), pc_end.max(0)) # 最大点云
bound_min = np.minimum(pc_start.min(0), pc_middle.min(0), pc_end.min(0)) # 最小点云
center = (bound_min + bound_max) / 2 # 求中间值?
scale = (bound_max - bound_min).max() # / (1 + self.padding)
# 缩放
pc_start = (pc_start - center) / scale
pc_middle = (pc_middle - center) / scale
pc_middle_ori = (pc_middle_ori - center) / scale
pc_end = (pc_end - center) / scale
- 第五步是生成查询点云的occupancy label,就是表面点取1,然后随机采样config里指定的里面点为0
occup_coords, occup_labels = self.prepare_occupancy_data(pc_middle_ori, self.sampling_mode)
↓↓↓
def prepare_occupancy_data(self, pcd_data, sampling_mode='random'):
# print(f"sampling_mode:{sampling_mode}") # random
can_repeat = True if self.on_occupancy_num > pcd_data.shape[0] else False
# 采一些表面点
rand_idcs_on = np.random.choice(pcd_data.shape[0],
size=self.on_occupancy_num,
replace=can_repeat)
on_surface_coords = pcd_data[rand_idcs_on]
# 表面点的occupancy label置True
on_surface_labels = np.ones(self.on_occupancy_num)
if sampling_mode == 'random':
# 在bound范围内采离心点~
# self.bound默认是0.5,相当于是在边长为1的立方体内做采样
off_surface_x = np.random.uniform(-self.bound, self.bound,
size=(self.off_occupancy_num, 1))
off_surface_y = np.random.uniform(-self.bound, self.bound,
size=(self.off_occupancy_num, 1))
off_surface_z = np.random.uniform(-self.bound, self.bound,
size=(self.off_occupancy_num, 1))
off_surface_coords = np.concatenate((off_surface_x, off_surface_y, off_surface_z),
axis=1)
# 非表面点置false
off_surface_labels = np.zeros(self.off_occupancy_num)
else:
grid = make_3d_grid([-0.5, 0.5, 15], [-0.5, 0.5, 15], [-0.5, 0.5, 15])
rand_idcs_on = np.random.choice(grid.shape[0],
size=self.off_occupancy_num)
off_surface_coords = grid[rand_idcs_on]
off_surface_labels = np.zeros(self.off_occupancy_num)
coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
labels = np.concatenate((on_surface_labels, off_surface_labels), axis=0)
# 打乱输入点云
rix = np.random.permutation(coords.shape[0])
coords = coords[rix]
labels = labels[rix]
return coords, labels
10.第六步返回结果
4. 采样点云的数据范围
点云xyz应该是在[-1, +1]之间