-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_cond_ds.py
More file actions
264 lines (227 loc) · 11 KB
/
create_cond_ds.py
File metadata and controls
264 lines (227 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# Based on demo.ipynb from https://github.com/PKU-EPIC/GAPartNet commit c8d4ad2
# Original code licensed under CC BY-NC 4.0 https://creativecommons.org/licenses/by-nc/4.0/
# Modifications Copyright (c) 2025 University of Augsburg (Author: Jens Kreber), licensed under CC BY-NC 4.0 https://creativecommons.org/licenses/by-nc/4.0/
import argparse
from collections import defaultdict
import glob
import sys, os, os.path as osp
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
import json
import random
import numpy as np
from PIL import Image
sys.path.append('GAPartNet')
from GAPartNet.structure.gapartnet import ObjIns
from GAPartNet.structure.gapartnet import save_point_cloud_to_ply
from GAPartNet.structure.utils import read_pcs_from_ply
from GAPartNet.dataset.render_tools.utils.config_utils import BACKGROUND_RGB
from GAPartNet.dataset.render_tools.utils.read_utils import read_joints_from_urdf_file
from GAPartNet.dataset.render_tools.utils.render_utils import get_cam_pos_fix, set_all_scene, render_rgb_image, render_depth_map, add_background_color_for_image, get_camera_pos_mat
# Root of the extracted PartNet-Mobility dataset (one subdirectory per model id).
PARTNET_DATASET_PATH = "data/partnet-mobility-v0/dataset"
# Number of points kept after downsampling each rendered point cloud.
N_POINTS = 1000
# (theta, phi, distance) camera pose used when fixed_cam=True in get_pc.
FIXED_CAM_POSE = (60, 180, 4)
def get_pc(model_id, category, rng: np.random.Generator, height, width, fixed_cam=False, base_path=None, add_name='', use_gt_pc=False):
    """Render one view of a PartNet-Mobility model and extract its point cloud.

    Renders RGB / depth / instance-segmentation images for a sampled (or
    fixed) camera pose, saves them under ``base_path`` (``rgb/``, ``seg/``,
    ``depth/``, ``pc/`` subdirectories), rejects views that touch the image
    border or cover too little of the frame, and returns the downsampled
    point cloud transformed into world coordinates.

    Args:
        model_id: PartNet-Mobility model directory name (int or str).
        category: object category label, passed through to ``ObjIns``.
        rng: numpy Generator used to sample the camera pose.
        height, width: render resolution in pixels.
        fixed_cam: if True, use ``FIXED_CAM_POSE`` instead of sampling.
        base_path: output directory; ``rgb``/``seg``/``depth`` subdirs must
            already exist (``pc`` is created here).
        add_name: file stem shared by all artifacts saved for this view.
        use_gt_pc: load the ground-truth 10k-point PLY instead of
            back-projecting the rendered depth map.

    Returns:
        ``(True, pc_world)`` on success, where ``pc_world`` has shape
        (N_POINTS, 3); ``(False, None)`` if the view was rejected.
    """
    model_path = os.path.join(PARTNET_DATASET_PATH, str(model_id))
    # Camera-pose sampling ranges (theta/phi in degrees, distance in scene units).
    cam_range = {
        'theta_min': 30.0,
        'theta_max': 80.0,
        'phi_min': 120.0,
        'phi_max': 240.0,
        'distance_min': 1,  # smallest value from GAPartNet: 3
        'distance_max': 6,  # biggest value in GAPartNet: 5.5
    }
    if fixed_cam:
        cam_pose = FIXED_CAM_POSE
    else:
        cam_pose = rng.uniform(
            [cam_range['theta_min'], cam_range['phi_min'], cam_range['distance_min']],
            [cam_range['theta_max'], cam_range['phi_max'], cam_range['distance_max']])
    camera_pos = get_cam_pos_fix(  # theta, phi, distance
        cam_pose[0], cam_pose[1], cam_pose[2],
    )
    # Pose every articulated joint at qpos 0 (its default/closed state).
    joints_dict = read_joints_from_urdf_file(model_path, 'mobility.urdf')
    joint_qpos = {}
    for joint_name in joints_dict:
        joint_type = joints_dict[joint_name]['type']
        if joint_type == 'prismatic' or joint_type == 'revolute':
            joint_limit = joints_dict[joint_name]['limit']
            joint_qpos[joint_name] = np.zeros_like(joint_limit[0])
        elif joint_type == 'fixed':
            joint_qpos[joint_name] = 0.0  # ! the qpos of fixed joint must be 0.0
        elif joint_type == 'continuous':
            joint_qpos[joint_name] = 0
        else:
            raise ValueError(f'Unknown joint type {joint_type}')
    # engine/robot handles are not needed below; keep the unpack but mark unused.
    scene, camera, _engine, _robot = set_all_scene(data_path=model_path, urdf_file='mobility.urdf', cam_pos=camera_pos, width=width, height=height, use_raytracing=False, joint_qpos_dict=joint_qpos)
    rgb_image = render_rgb_image(camera=camera)
    depth_map = render_depth_map(camera=camera)
    # Map every visual id in the segmentation texture back to its link name.
    available_link_names = []
    vis_id_to_link_name = {}
    for articulation in scene.get_all_articulations():
        for link in articulation.get_links():
            link_name = link.get_name()
            available_link_names.append(link_name)
            for visual in link.get_visual_bodies():
                visual_id = visual.get_visual_id()
                vis_id_to_link_name[visual_id] = link_name
    seg_labels = camera.get_uint32_texture("Segmentation")
    seg_labels_by_visual_id = seg_labels[..., 0].astype(np.uint16)  # H x W, each pixel's visual id
    height, width = seg_labels_by_visual_id.shape
    # Instance-id map: -2 = not yet assigned, later remapped so -1 = background.
    ins_seg_map = np.ones((height, width), dtype=np.int32) * (-2)
    valid_linkName_to_instId_mapping = {}
    part_ins_cnt = 0
    for link_name in available_link_names:
        # Vectorized membership test over all visual ids of this link
        # (replaces a per-id accumulation loop with a Python-level double sum).
        link_vis_ids = [v for v, n in vis_id_to_link_name.items() if n == link_name]
        mask = np.isin(seg_labels_by_visual_id, link_vis_ids)
        area = int(np.count_nonzero(mask))
        if area == 0:
            continue  # link has no visible pixels in this view
        ins_seg_map[mask] = part_ins_cnt
        valid_linkName_to_instId_mapping[link_name] = part_ins_cnt
        part_ins_cnt += 1
    eps = 1e-6
    # Pixels with (near-)zero depth rendered no geometry: mark as background.
    empty_mask = abs(depth_map) < eps
    ins_seg_map[empty_mask] = -1  # do not use -2 as bg
    n_other = (ins_seg_map == -2).sum()
    if n_other > 0:
        print('Warning: found others in segmentation!', n_other)
    ins_seg_map[ins_seg_map == -2] = -1
    # Colorize the instance map; index -1 hits the appended transparent entry.
    cmap = plt.get_cmap('tab10')
    colors = np.array([cmap(i) for i in range(part_ins_cnt)] + [[0, 0, 0, 0.]])  # -1 is transparent
    seg_image = (colors[ins_seg_map] * 255).astype(np.uint8)
    Image.fromarray(seg_image).save(osp.join(base_path, 'seg', add_name + '.png'))
    # Depth visualization with a fixed normalization range of [3, 5].
    dcmap = plt.get_cmap('plasma')
    norm = Normalize(3., 5.)
    depth_normed = norm(depth_map)
    dcolors = dcmap(depth_normed)
    dcolors[empty_mask] = 0
    Image.fromarray((dcolors * 255).astype(np.uint8)).save(osp.join(base_path, 'depth', add_name + '.png'))
    camera_intrinsic, world2camera_rotation, camera2world_translation = get_camera_pos_mat(camera)
    rgb_image = add_background_color_for_image(rgb_image, depth_map, BACKGROUND_RGB)
    Image.fromarray(rgb_image).save(osp.join(base_path, 'rgb', add_name + '.png'))
    # --- view filtering heuristics ---
    arg_where_content = np.argwhere(~empty_mask)
    min_height, min_width = np.min(arg_where_content, axis=0)  # the idx axis
    max_height, max_width = np.max(arg_where_content, axis=0)  # the idx axis
    frac_height, frac_width = (max_height - min_height) / (height - 1), (max_width - min_width) / (width - 1)
    if min_height == 0 or max_height == (height - 1) or min_width == 0 or max_width == (width - 1):
        # Object touches the image border, so it is (potentially) clipped.
        print('rejecting sample for model id', model_id, 'since boundary edges are hit!', min_height, max_height, min_width, max_width)
        return False, None
    elif frac_height < 0.5 and frac_width < 0.5:
        print('rejecting sample for model id', model_id, 'because not enough of the image was covered!', frac_height, frac_width)
        return False, None
    obj = ObjIns(
        name='_unused_',
        cate=category,
        image=rgb_image,
        depth=depth_map,
        K=camera_intrinsic,
        world2camera_rotation=world2camera_rotation,
        camera2world_translation=camera2world_translation,
        image_reso=(rgb_image.shape[0], rgb_image.shape[1]),
    )
    if use_gt_pc:
        # Load the curated ground-truth point cloud; flip y/z to match the
        # renderer's coordinate convention.
        points, obj.pcs_all_rgb = read_pcs_from_ply(osp.join(model_path, 'point_sample', 'ply-10000.ply'))
        points[:, 2] = -points[:, 2]
        points[:, 1] = -points[:, 1]
        obj.pcs_all_xyz = points
    else:
        obj.get_pc()  # back-project the depth map into a camera-frame cloud
    obj.get_downsampled_pc(N_POINTS)
    pc_cam = obj.pcs_xyz
    # Camera frame -> world frame, then permute axes to the target convention.
    pc_world = (world2camera_rotation[None, :, :] @ pc_cam[:, :, None])[:, :, 0] + camera2world_translation
    pc_world = np.stack([-pc_world[:, 1], pc_world[:, 2], -pc_world[:, 0]], axis=1)  # transform
    os.makedirs(osp.join(base_path, 'pc'), exist_ok=True)
    save_point_cloud_to_ply(pc_world, obj.pcs_rgb * 255, osp.join(base_path, 'pc', add_name + '.ply'))
    return True, pc_world
SPLIT_FILE = 'resource/partnet_m_split.json'
def run():
    """CLI entry point: build a conditioned point-cloud dataset.

    Samples ``--n_per_model`` camera views per model of the chosen split
    (restricted to the model ids present in the GT directory), saves the
    per-view rgb/seg/depth/pc artifacts under ``--dir``, and finally writes
    the stacked point clouds (``pcs.npy``), categories and model ids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--limit_n_models', type=int, default=None, help='limit number of used models')
    parser.add_argument('--n_per_model', type=int, default=5, help='num cam poses per object')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dir', type=str, required=True)
    parser.add_argument('--use_gt_pc', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--cats', type=str, default=None, help='comma-separated category filter')
    conf = parser.parse_args()
    rng = np.random.default_rng(conf.seed)
    random.seed(conf.seed)  # for pcl subsampling
    if os.path.exists(conf.dir):
        print('Warning, path already exists!')
    os.makedirs(conf.dir, exist_ok=True)
    os.makedirs(osp.join(conf.dir, 'rgb'), exist_ok=True)
    os.makedirs(osp.join(conf.dir, 'seg'), exist_ok=True)
    os.makedirs(osp.join(conf.dir, 'depth'), exist_ok=True)
    if conf.cats is not None:
        filtered_cats = conf.cats.strip().split(',')
    # Read the category -> split -> model-id mapping from the split file.
    split_names = ['train', 'val', 'test']
    splits = {split: set() for split in split_names}
    with open(SPLIT_FILE, 'r') as f:
        split_data = json.load(f)
    model_to_cat = {}
    orig_cat_counts = defaultdict(int)
    for cat, cat_splits in split_data.items():
        if conf.cats is not None and cat not in filtered_cats:
            continue
        for split in split_names:
            this_model_ids = [int(q) for q in cat_splits[split]]
            splits[split].update(this_model_ids)
            for model_id in this_model_ids:
                model_to_cat[model_id] = cat
                orig_cat_counts[cat] += 1
    the_split = splits[conf.split]
    # Restrict to model ids that are actually used in the GT sets.
    gt_model_paths = glob.glob(f'log/GT{conf.split}/PCL/*.npz')
    gt_model_ids = set(int(os.path.basename(p)[:-4]) for p in gt_model_paths)
    print('the gt dir contains', len(gt_model_ids), 'model ids')
    if not len(gt_model_ids - the_split) == 0:
        # The GT directory references models outside the requested split.
        raise ValueError()
    the_split = gt_model_ids
    model_ids = np.array(list(the_split))
    model_ids = rng.permutation(model_ids)
    n_models = len(model_ids) if conf.limit_n_models is None else conf.limit_n_models
    print('using', n_models, 'out of', len(model_ids), 'models available')
    pcs = []
    cats = []
    used_model_ids = []
    i_total = 0
    i_model = 0
    n_target_total = n_models * conf.n_per_model
    while i_total < n_target_total:
        if i_model >= len(model_ids):
            # Guard: models can be skipped below (20 failed tries with
            # --limit_n_models set); previously exhausting the pool crashed
            # with an IndexError on model_ids[i_model].
            raise RuntimeError(f'Ran out of models after collecting {i_total}/{n_target_total} samples')
        model_id = model_ids[i_model]
        cat = model_to_cat[model_id]
        i_cam = 0
        i_tries_model = 0
        pcs_this_model = []
        # Collect n_per_model accepted views; allow up to 20 rejected camera
        # samples before giving up on this model.
        while i_cam < conf.n_per_model:
            succ, pc = get_pc(model_id, cat, rng, height=800, width=800, fixed_cam=False, base_path=conf.dir, add_name=f"{i_total + i_cam:04}", use_gt_pc=conf.use_gt_pc)
            if not succ:
                i_tries_model += 1
                if i_tries_model >= 20:
                    if not conf.limit_n_models:
                        raise RuntimeError(f"Failed sampling for model id {model_id} after {i_tries_model} tries!")
                    else:
                        print(f"Failed sampling for model id {model_id} after {i_tries_model} tries!")
                    break
                continue
            pcs_this_model.append(pc.astype(np.float32))
            i_cam += 1
        # Only commit the model when all n_per_model views succeeded.
        if len(pcs_this_model) == conf.n_per_model:
            pcs.extend(pcs_this_model)
            cats.extend([cat] * conf.n_per_model)
            used_model_ids.extend([model_id] * conf.n_per_model)
            i_total += conf.n_per_model
        i_model += 1
    pcs = np.stack(pcs)
    np.save(osp.join(conf.dir, 'pcs.npy'), pcs)
    with open(osp.join(conf.dir, 'categories.txt'), 'w') as f:
        for cat in cats:
            f.write(cat + '\n')
    with open(osp.join(conf.dir, 'model_ids.txt'), 'w') as f:
        for model_id in used_model_ids:
            f.write(f"{model_id}\n")
# Script entry point.
if __name__ == '__main__':
    run()