VAE mesh decoder training data #236

Open
lan-creator opened this issue Apr 11, 2025 · 0 comments
@lan-creator

Hello, I noticed that you use a normal map to calculate the color loss, as stated in the paper:

a normal map N_m directly derived from the mesh,

However, when I tried to find the corresponding code, I found that you don't use the normal map rendered from the ground-truth mesh in structured_latent_vae_mesh_dec.py:line 213; instead, you use the normal_map passed in from the dataloader, which doesn't provide any normal map data. So I wonder whether the normal map reconstruction loss is actually applied in the color loss. Thank you for your reply!

def geometry_losses(
    self,
    reps: List[MeshExtractResult],
    mesh: List[Dict],
    normal_map: torch.Tensor,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
):
    with torch.no_grad():
        gt_meshes = []
        for i in range(len(reps)):
            gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
            gt_meshes.append(gt_mesh)
        target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
        target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)

    terms = edict(geo_loss=0.0)
    if self.lambda_tsdf > 0:
        tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
        terms['tsdf_loss'] = tsdf_loss
        terms['geo_loss'] += tsdf_loss * self.lambda_tsdf

    return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
    buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
    success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
    if success_mask.sum() != 0:
        for k, v in buffer.items():
            buffer[k] = v[success_mask]
        for k, v in target.items():
            target[k] = v[success_mask]

        terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
        if self.depth_loss_type == 'l1':
            terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
        elif self.depth_loss_type == 'smooth_l1':
            terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
        else:
            raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
        terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
        terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
        if self.use_color and normal_map is not None:
            terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
            terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color

    return terms
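
For clarity, here is a minimal sketch of what I understood the paper to describe (my own assumption, not the repository's code): supervising the decoded normal_map against the normal rendered from the ground-truth mesh, i.e. target['normal'] above, rather than a dataloader-provided image. The function name color_normal_loss and its arguments are hypothetical, and a plain L1 stands in for the repository's _perceptual_loss.

import torch
import torch.nn.functional as F

def color_normal_loss(pred_normal_map: torch.Tensor,
                      gt_mesh_normal: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical: compare the decoded normal_map with the normal rendered
    # from the GT mesh, restricted to the foreground mask. L1 is used here
    # only as a stand-in for the perceptual loss.
    return F.l1_loss(pred_normal_map * mask, gt_mesh_normal * mask)

In geometry_losses this would roughly correspond to replacing normal_map[success_mask] with target['normal'] in the self.use_color branch, but I may be misunderstanding the intended data flow.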

def collate_fn(batch):
    pack = {}
    coords = []
    for i, b in enumerate(batch):
        coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
    coords = torch.cat(coords)
    feats = torch.cat([b['feats'] for b in batch])
    pack['latents'] = SparseTensor(
        coords=coords,
        feats=feats,
    )

    # collate other data
    keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
    for k in keys:
        if isinstance(batch[0][k], torch.Tensor):
            pack[k] = torch.stack([b[k] for b in batch])
        elif isinstance(batch[0][k], list):
            pack[k] = sum([b[k] for b in batch], [])
        else:
            pack[k] = [b[k] for b in batch]
    return pack
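
As a sanity check (my own sketch, assuming the collate_fn above and its SparseTensor dependency are importable), a dataset item would need to return a 'normal_map' tensor for the generic tensor branch to stack it into pack['normal_map']; the key and the shapes below are made up for illustration, and this is exactly the data I could not find the dataloader producing.

import torch

# Hypothetical dataset item; 'normal_map' is the key I could not find
# being produced by the dataloader.
item = {
    'coords': torch.randint(0, 64, (128, 3), dtype=torch.int32),
    'feats': torch.randn(128, 8),
    'normal_map': torch.rand(4, 3, 512, 512),  # per-view GT normal images
}

pack = collate_fn([item, item])
print(pack['normal_map'].shape)  # torch.Size([2, 4, 3, 512, 512])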
