Skip to content

Unexpected rendering result wrt to the object's distance to the camera #348

@unlugi

Description

@unlugi

First of all, thank you for the great library!

I'm using Pytorch3D (nightly build 0.2.0 on Ubuntu 16.04) to render some models from the ShapeNet dataset. I noticed that the rendering result changes when I change the distance of the camera to the object. You can find the code, sample results and the .obj file below. (Note: I have implemented a custom shader called SoftFlatShader which is flat_shading followed by softmax blending: you can also find it below.)

`
import os
import torch
import torch.nn as nn
from skimage.io import imread, imsave

from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import Textures
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PerspectiveCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
HardPhongShader,
HardFlatShader,
SoftSilhouetteShader,
SoftGouraudShader,
HardGouraudShader,
BlendParams,
softmax_rgb_blend
)

from pytorch3d.renderer.mesh.shading import flat_shading

class SoftFlatShader(nn.Module):

def __init__(
    self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None):
    super().__init__()
    self.lights = lights if lights is not None else PointLights(device=device)
    self.materials = (
        materials if materials is not None else Materials(device=device)
    )
    self.cameras = cameras
    self.blend_params = blend_params if blend_params is not None else BlendParams()

def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
    cameras = kwargs.get("cameras", self.cameras)
    if cameras is None:
        msg = "Cameras must be specified either at initialization \
            or in the forward pass of HardFlatShader"
        raise ValueError(msg)
    texels = meshes.sample_textures(fragments)
    lights = kwargs.get("lights", self.lights)
    materials = kwargs.get("materials", self.materials)
    blend_params = kwargs.get("blend_params", self.blend_params)
    colors = flat_shading(
        meshes=meshes,
        fragments=fragments,
        texels=texels,
        lights=lights,
        cameras=cameras,
        materials=materials,
    )
    images = softmax_rgb_blend(colors, fragments, blend_params)
    return images

def load_untextured_mesh(mesh_path, device):

verts, faces_idx, _ = load_obj(mesh_path, device=device)
faces = faces_idx.verts_idx
# Initialize each vertex to be white in color
verts_rgb = torch.ones_like(verts)[None]
textures = Textures(verts_rgb=verts_rgb.to(device))
mesh_no_texture = Meshes(verts=[verts.to(device)],
                         faces=[faces.to(device)],
                         textures=textures)
return  mesh_no_texture

def render_mesh(mesh, R, T, device, img_size=512):

cameras = FoVPerspectiveCameras(device=device, R=R, T=T, degrees=False, fov=0.7)
# cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=0.7)
raster_settings = RasterizationSettings(image_size=img_size,
                                        blur_radius=0.0,
                                        faces_per_pixel=1)

lights = PointLights(device=device, location=cameras.get_camera_center())

renderer = MeshRenderer(rasterizer=MeshRasterizer(cameras=cameras,
                                                  raster_settings=raster_settings),
                        shader=SoftFlatShader(device=device,
                                               cameras=cameras,
                                               lights=lights
                                              )
                        )
images = renderer(mesh)
return images

def render_depth(mesh, R, T, device, img_size=512):

cameras = FoVPerspectiveCameras(device=device, R=R, T=T, degrees=False, fov=0.7)
raster_settings = RasterizationSettings(image_size=img_size,
                                        blur_radius=0.0,
                                        faces_per_pixel=1)

rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
rasterizer_output = rasterizer(mesh)
return rasterizer_output

if name == "main":

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Set paths
DATA_DIR = "./data"
obj_file_name = os.path.join(DATA_DIR, "model_normalized.obj")
dist_to_cam = 1 # change between 0.6 and 1 to see the results

with torch.no_grad():
    mesh = load_untextured_mesh(obj_file_name, device)
    R, T = look_at_view_transform(dist=dist_to_cam , elev=0, azim=180)
    rendered_image = render_mesh(mesh, R, T, device)
    rendered_depth = render_depth(mesh, R, T, device)

imsave(os.path.join(DATA_DIR, "renderedCar.png"), rendered_image[0, ..., :3].cpu().numpy())
imsave(os.path.join(DATA_DIR,"depthCar.png"), rendered_depth[1][0, ..., 0].cpu().numpy())

`

When the dist_to_cam parameter is 1, the rendering result is weird(white rectangle on the windshield). I was wondering why this happens?
renderedCar
depthCar

This is the output for dist_to_cam=0.6:
renderedCar-close
depthCar-close

This is the model file:
model_normalized.zip

Metadata

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions