Render image with some constant colors described in the OBJ file #1109

Closed
Nan2018 opened this issue Mar 9, 2022 · 9 comments

Nan2018 commented Mar 9, 2022

❓ Questions on how to use PyTorch3D

This is basically the same question as #280, but I'm asking for a concrete example.

I was able to create the Materials object from aux.material_colors, but passing it to the shader causes a RuntimeError when rendering (the same error occurs with renderer(materials=materials)). My feeling is that faces.materials_idx is missing from the pipeline. As noted in this comment, faces.materials_idx indexes into the material properties for each face, but I couldn't find a place to pass it to the renderer.

Here is how I create the materials:

ambient_color = []
diffuse_color = []
specular_color = []
shininess = []
for colors in aux.material_colors.values():
    ambient_color.append(colors['ambient_color'])
    diffuse_color.append(colors['diffuse_color'])
    specular_color.append(colors['specular_color'])
    shininess.append(colors['shininess'])
                     
ambient_color = torch.stack(ambient_color)
diffuse_color = torch.stack(diffuse_color)
specular_color = torch.stack(specular_color)
shininess = torch.stack(shininess)
print(ambient_color.shape, diffuse_color.shape, specular_color.shape, shininess.shape)
mat = Materials(ambient_color, diffuse_color, specular_color, shininess.squeeze(), device=device)

Here is how I render an image:

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights,
        materials=mat
    )
)
images = renderer(mesh, lights=lights, cameras=cameras)

And here is the RuntimeError:

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
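
For context, the same error can be reproduced in plain PyTorch whenever .view() is asked to flatten a non-contiguous tensor; a minimal sketch, unrelated to the mesh data:

import torch

x = torch.randn(2, 3, 4).permute(1, 0, 2)  # permute returns a non-contiguous view
# x.view(-1, 4)       # raises the same "view size is not compatible ..." RuntimeError
y = x.reshape(-1, 4)  # reshape falls back to copying the data, so it succeeds
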
bottler (Contributor) commented Mar 9, 2022

Can you give the whole stack trace?

Nan2018 (Author) commented Mar 10, 2022

Full traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_2518/3614903963.py in <module>
     14 
     15 # Re render the mesh, passing in keyword arguments for the modified components.
---> 16 images = renderer(mesh, lights=lights, cameras=cameras)
     17 
     18 plt.figure(figsize=(10, 10))

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/mesh/renderer.py in forward(self, meshes_world, **kwargs)
     58         """
     59         fragments = self.rasterizer(meshes_world, **kwargs)
---> 60         images = self.shader(fragments, meshes_world, **kwargs)
     61 
     62         return images

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/mesh/shader.py in forward(self, fragments, meshes, **kwargs)
    153             lights=lights,
    154             cameras=cameras,
--> 155             materials=materials,
    156         )
    157         znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/mesh/shading.py in phong_shading(meshes, fragments, lights, cameras, materials, texels)
     89     )
     90     ambient, diffuse, specular = _apply_lighting(
---> 91         pixel_coords, pixel_normals, lights, cameras, materials
     92     )
     93     colors = (ambient + diffuse) * texels + specular

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/mesh/shading.py in _apply_lighting(points, normals, lights, cameras, materials)
     35         points=points,
     36         camera_position=cameras.get_camera_center(),
---> 37         shininess=materials.shininess,
     38     )
     39     ambient_color = materials.ambient_color * lights.ambient_color

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/lighting.py in specular(self, normals, points, camera_position, shininess)
    284             direction=direction,
    285             camera_position=camera_position,
--> 286             shininess=shininess,
    287         )
    288 

/opt/conda/lib/python3.7/site-packages/pytorch3d/renderer/lighting.py in specular(points, normals, direction, color, camera_position, shininess)
    130     expand_dims = (-1,) + (1,) * len(points_dims)
    131     if direction.shape != normals.shape:
--> 132         direction = direction.view(expand_dims + (3,))
    133     if color.shape != normals.shape:
    134         color = color.view(expand_dims + (3,))

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

bottler (Contributor) commented Mar 14, 2022

Something weird is happening here. What is lights? (Actually, the more of your code you give, the easier it might be for us to work out what is wrong.)

bottler self-assigned this Mar 14, 2022
Nan2018 (Author) commented Mar 14, 2022

Here is the code:

import torch
from pytorch3d.io import load_obj, load_objs_as_meshes
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, RasterizationSettings,
    PointLights, Materials, MeshRenderer, MeshRasterizer, SoftPhongShader,
)

verts, faces_idx, aux = load_obj(obj_filename, create_texture_atlas=True, device=device)
faces = faces_idx.verts_idx

# this will cause "ValueError: Meshes does not have textures"
# mesh = Meshes(
#     verts=[verts.to(device)],   
#     faces=[faces.to(device)],
# )

mesh = load_objs_as_meshes([obj_filename], device=device)

ambient_color = []
diffuse_color = []
specular_color = []
shininess = []
for colors in aux.material_colors.values():
    ambient_color.append(colors['ambient_color'])
    diffuse_color.append(colors['diffuse_color'])
    specular_color.append(colors['specular_color'])
    shininess.append(colors['shininess'])
ambient_color = torch.stack(ambient_color)
diffuse_color = torch.stack(diffuse_color)
specular_color = torch.stack(specular_color)
shininess = torch.stack(shininess)
mat = Materials(ambient_color, diffuse_color, specular_color, shininess.squeeze(), device=device)

raster_settings = RasterizationSettings(
    image_size=512, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

lights = PointLights(device=device, location=[[0.0, 0.0, 5.0]])
R, T = look_at_view_transform(3, 30, 0) 
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights,
        materials=mat
    )
)

images = renderer(mesh)

bottler (Contributor) commented Mar 15, 2022

Whatever has gone wrong here is quite concerning and it might need some work to debug. If you can help, that would be great; otherwise, please supply us with everything we need to reproduce the error (the data, the full code, and the versions of PyTorch, PyTorch3D, and CUDA). There's an assumption in the lighting code that certain tensors will be contiguous, which allows it to save memory by knowing that reshaping tensors won't need to copy data, and one of those assumptions has failed in this case.

It would be useful to work backwards from the error and look at the strides and shapes of the tensors that go into it. This could be done in a debugger or by adding print statements. Specifically, direction must not be contiguous. It came from the line direction = location - points. The location comes from the location in the lights, and the points come from the following line in phong_shading in shading.py.

    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )

At least one of these two things is probably not contiguous, and the next question would be why.
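
For example, a rough sketch of that print-debugging (the exact local variable names inside lighting.py may differ):

    # temporary prints just before the failing .view() call in
    # pytorch3d/renderer/lighting.py, to see which tensor lost contiguity
    print("points    ", points.shape, points.stride(), points.is_contiguous())
    print("direction ", direction.shape, direction.stride(), direction.is_contiguous())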

Nan2018 (Author) commented Mar 17, 2022

Data: test_obj.tar.gz
Versions:

python: 3.7.12
torch: 1.10.0
pytorch3d: 0.6.1
nvcc: Cuda compilation tools, release 11.0, V11.0.221, Build cuda_11.0_bu.TC445_37.28845127_0

Full code is in the comment above.
Screenshot in Blender: [Screen Shot 2022-03-16 at 10 39 55 PM]

bottler added, then removed, the "bug: Something isn't working" label Mar 17, 2022
bottler (Contributor) commented Mar 17, 2022

There's a mismatch in batch sizes. The materials (ambient_color, diffuse_color, specular_color, shininess) each contain 23 items / a batch size of 23. len(mat) is 23. But there is only one mesh.
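
If per-face material colors from the .mtl are what you're after, one way around the mismatch is to skip the Materials batch entirely and bake those colors into a texture atlas at load time, along the lines of the textured-mesh tutorial. A sketch (assuming the .mtl colors load cleanly into the atlas):

from pytorch3d.io import load_obj
from pytorch3d.renderer import TexturesAtlas
from pytorch3d.structures import Meshes

verts, faces, aux = load_obj(
    obj_filename, create_texture_atlas=True, texture_atlas_size=4, device=device
)
atlas = aux.texture_atlas  # (F, 4, 4, 3): a small color grid per face
mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.verts_idx.to(device)],
    textures=TexturesAtlas(atlas=[atlas.to(device)]),
)
images = renderer(mesh)  # the shader's Materials can then stay at its default batch size of 1

Alternatively, collapsing the 23 materials into a single entry (so that len(mat) == len(mesh)) avoids the crash, at the cost of losing the per-material distinction.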

github-actions (bot) commented:

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label Apr 17, 2022
github-actions (bot) commented:

This issue was closed because it has been stalled for 5 days with no activity.
