-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Description
First of all, thank you for the great library!
I'm using PyTorch3D (nightly build 0.2.0 on Ubuntu 16.04) to render some models from the ShapeNet dataset. I noticed that the rendered result changes when I change the distance between the camera and the object. You can find the code, sample results, and the .obj file below. (Note: I implemented a custom shader called SoftFlatShader, which is flat_shading followed by softmax blending; its code is also included below.)
```python
import os
import torch
import torch.nn as nn
from skimage.io import imread, imsave
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import Textures
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PerspectiveCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
HardPhongShader,
HardFlatShader,
SoftSilhouetteShader,
SoftGouraudShader,
HardGouraudShader,
BlendParams,
softmax_rgb_blend
)
from pytorch3d.renderer.mesh.shading import flat_shading
class SoftFlatShader(nn.Module):
    """Custom shader: per-face ``flat_shading`` followed by ``softmax_rgb_blend``.

    Lights, materials and blend parameters fall back to ``PointLights``,
    ``Materials`` and ``BlendParams`` defaults when not given. Cameras may be
    supplied either here or as a ``cameras=`` keyword in the forward call.
    """

    def __init__(
        self,
        device="cpu",
        cameras=None,
        lights=None,
        materials=None,
        blend_params=None,
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """Shade ``meshes`` for the given rasterizer ``fragments`` and blend.

        Keyword overrides ``cameras``, ``lights``, ``materials`` and
        ``blend_params`` take precedence over the values set at construction.

        Raises:
            ValueError: if no cameras were given at construction or in kwargs.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            # Adjacent literals instead of a backslash continuation, so the
            # message carries no embedded indentation and names this class.
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of SoftFlatShader"
            )
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        return softmax_rgb_blend(colors, fragments, blend_params)
def load_untextured_mesh(mesh_path, device):
    """Load an OBJ file as a one-element ``Meshes`` with plain white vertex colors.

    The auxiliary OBJ data (materials/textures) returned by ``load_obj`` is
    ignored; every vertex is given the RGB color (1, 1, 1).
    """
    verts, faces, _aux = load_obj(mesh_path, device=device)
    face_verts = faces.verts_idx
    # Batched per-vertex colors: leading batch axis of size 1, all ones.
    white = torch.ones_like(verts).unsqueeze(0)
    vertex_texture = Textures(verts_rgb=white.to(device))
    return Meshes(
        verts=[verts.to(device)],
        faces=[face_verts.to(device)],
        textures=vertex_texture,
    )
def render_mesh(
    mesh,
    R,
    T,
    device,
    img_size=512,
    *,
    bin_size=None,
    max_faces_per_bin=None,
):
    """Render ``mesh`` with ``SoftFlatShader`` and a light at the camera center.

    Args:
        mesh: the ``Meshes`` object to render.
        R, T: camera rotation/translation (e.g. from ``look_at_view_transform``).
        device: torch device for the camera, lights and shader.
        img_size: output image side length in pixels.
        bin_size: forwarded to ``RasterizationSettings``. ``None`` keeps the
            library default (coarse-to-fine binning); ``0`` selects naive
            rasterization.
        max_faces_per_bin: forwarded to ``RasterizationSettings``; ``None``
            keeps the library default.

    Returns:
        The blended image tensor produced by the ``MeshRenderer``.

    Note:
        With binned rasterization, a dense mesh that covers only a few pixels
        can overflow ``max_faces_per_bin``. Faces in an overflowing bin may
        then be dropped, which shows up as rectangular holes filled with the
        background color. Passing ``bin_size=0`` or a larger
        ``max_faces_per_bin`` avoids this.
    """
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, degrees=False, fov=0.7)
    raster_settings = RasterizationSettings(
        image_size=img_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=bin_size,
        max_faces_per_bin=max_faces_per_bin,
    )
    # Place the point light at the camera so visible faces are lit head-on.
    lights = PointLights(device=device, location=cameras.get_camera_center())
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftFlatShader(device=device, cameras=cameras, lights=lights),
    )
    return renderer(mesh)
def render_depth(
    mesh,
    R,
    T,
    device,
    img_size=512,
    *,
    bin_size=None,
    max_faces_per_bin=None,
):
    """Rasterize ``mesh`` and return the raw rasterizer fragments.

    The camera and rasterization settings match ``render_mesh``, so the
    fragments (including ``zbuf``) line up pixel-for-pixel with the rendered
    image.

    Args:
        mesh: the ``Meshes`` object to rasterize.
        R, T: camera rotation/translation (e.g. from ``look_at_view_transform``).
        device: torch device for the camera.
        img_size: output side length in pixels.
        bin_size: forwarded to ``RasterizationSettings``; ``None`` keeps the
            library default, ``0`` selects naive rasterization.
        max_faces_per_bin: forwarded to ``RasterizationSettings``; ``None``
            keeps the library default.

    Returns:
        The ``Fragments`` output of ``MeshRasterizer``.
    """
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, degrees=False, fov=0.7)
    raster_settings = RasterizationSettings(
        image_size=img_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=bin_size,
        max_faces_per_bin=max_faces_per_bin,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    return rasterizer(mesh)
if __name__ == "__main__":
    # Script entry point: render an RGB image and a depth map of one mesh.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Set paths
    DATA_DIR = "./data"
    obj_file_name = os.path.join(DATA_DIR, "model_normalized.obj")
    dist_to_cam = 1  # change between 0.6 and 1 to see the results

    with torch.no_grad():
        mesh = load_untextured_mesh(obj_file_name, device)
        R, T = look_at_view_transform(dist=dist_to_cam, elev=0, azim=180)
        rendered_image = render_mesh(mesh, R, T, device)
        fragments = render_depth(mesh, R, T, device)

    # Convert to uint8 explicitly. imsave rejects float images whose values
    # fall outside [-1, 1], and raw depth values can exceed 1.
    rgb = rendered_image[0, ..., :3].clamp(0.0, 1.0)
    imsave(
        os.path.join(DATA_DIR, "renderedCar.png"),
        (rgb * 255).round().to(torch.uint8).cpu().numpy(),
    )

    # pytorch3d fills zbuf with -1 at pixels where no face was hit.
    zbuf = fragments.zbuf[0, ..., 0]
    hit = zbuf >= 0
    depth_img = torch.zeros_like(zbuf)
    if hit.any():
        z = zbuf[hit]
        span = (z.max() - z.min()).clamp_min(1e-8)
        # Nearest surface -> 1.0, farthest -> 0.2; background stays 0 so the
        # farthest surface remains distinguishable from empty pixels.
        depth_img[hit] = 1.0 - 0.8 * (z - z.min()) / span
    imsave(
        os.path.join(DATA_DIR, "depthCar.png"),
        (depth_img * 255).round().to(torch.uint8).cpu().numpy(),
    )
```
When dist_to_cam is 1, the rendering looks wrong (there is a white rectangle on the windshield). Why does this happen?
This is the output for dist_to_cam=0.6:
This is the model file:
model_normalized.zip