
GridRaySampler should not persist the grid. #802

@PeterL1n


🚀 Feature

GridRaySampler creates the grid in its constructor and registers it with register_buffer. However, it should pass the persistent=False argument so that the canonical grid is not saved in the state_dict.

self.register_buffer("_xy_grid", _xy_grid)
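A minimal sketch of what persistent=False changes. It uses a hypothetical ToyGridModule in place of GridRaySampler, and the (H, W, 2) grid layout is an assumption, not the actual PyTorch3D code. Both variants keep _xy_grid registered as a buffer, but only the persistent one writes it into state_dict():

```python
import torch
from torch import nn


class ToyGridModule(nn.Module):
    """Hypothetical stand-in for GridRaySampler; only the buffer registration matters."""

    def __init__(self, height: int, width: int, persistent: bool) -> None:
        super().__init__()
        # Assumed grid layout: (H, W, 2) xy coordinates in [-1, 1].
        ys = torch.linspace(-1.0, 1.0, height)
        xs = torch.linspace(-1.0, 1.0, width)
        grid_x = xs.expand(height, width)           # x varies along columns
        grid_y = ys[:, None].expand(height, width)  # y varies along rows
        xy_grid = torch.stack([grid_x, grid_y], dim=-1)
        # persistent=False: the tensor is still a buffer (moved by .to(), .cuda(), ...)
        # but it is left out of state_dict().
        self.register_buffer("_xy_grid", xy_grid, persistent=persistent)


persisted = ToyGridModule(4, 4, persistent=True)
transient = ToyGridModule(4, 4, persistent=False)
print(list(persisted.state_dict().keys()))              # ['_xy_grid']
print(list(transient.state_dict().keys()))              # []
print([name for name, _ in transient.named_buffers()])  # ['_xy_grid']
print(transient._xy_grid.shape)                         # torch.Size([4, 4, 2])
```

Because the non-persistent grid is still a registered buffer, moving the module with .to(device) continues to move the grid along with it; only serialization changes.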

Motivation

We want the grid to be regenerated in code every time rather than loaded from the state_dict. This allows the image render size to be changed at construction time. Currently, a saved state_dict no longer matches the module if the image size changes, so loading it fails. This feature is needed because it is common to train and evaluate at different resolutions.
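To illustrate the mismatch, here is a sketch with a hypothetical ToySampler built at 64x64 for training and 128x128 for evaluation. The helper make_xy_grid and the grid layout are assumptions for illustration only. With a persistent grid, loading the training state_dict into the evaluation sampler fails on the buffer's shape; with a non-persistent grid there is nothing to load, so the evaluation sampler keeps the grid it built for its own resolution:

```python
import torch
from torch import nn


def make_xy_grid(height: int, width: int) -> torch.Tensor:
    """Hypothetical helper: an (H, W, 2) grid of xy coordinates in [-1, 1]."""
    ys = torch.linspace(-1.0, 1.0, height)
    xs = torch.linspace(-1.0, 1.0, width)
    grid_x = xs.expand(height, width)           # x varies along columns
    grid_y = ys[:, None].expand(height, width)  # y varies along rows
    return torch.stack([grid_x, grid_y], dim=-1)


class ToySampler(nn.Module):
    """Hypothetical stand-in for a grid-based ray sampler."""

    def __init__(self, height: int, width: int, persistent: bool) -> None:
        super().__init__()
        self.register_buffer("_xy_grid", make_xy_grid(height, width), persistent=persistent)


def load_train_into_eval(persistent: bool) -> str:
    """Save a 64x64 sampler's state_dict and load it into a 128x128 sampler."""
    train_sampler = ToySampler(64, 64, persistent)
    eval_sampler = ToySampler(128, 128, persistent)
    try:
        eval_sampler.load_state_dict(train_sampler.state_dict())
    except RuntimeError:
        # Persistent grid: load_state_dict reports a size mismatch for _xy_grid.
        return "mismatch"
    # Non-persistent grid: the 128x128 grid built in __init__ is kept.
    assert eval_sampler._xy_grid.shape == (128, 128, 2)
    return "loaded"


print(load_train_into_eval(persistent=True))   # mismatch
print(load_train_into_eval(persistent=False))  # loaded
```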

Pitch

Change

self.register_buffer("_xy_grid", _xy_grid)

to

self.register_buffer("_xy_grid", _xy_grid, persistent=False)

BC Breaking!

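Note: This change breaks loading of existing state_dicts. The new module no longer includes _xy_grid in its state_dict, but previously saved state_dicts still contain that key, so load_state_dict with the default strict=True fails on the unexpected key. A simple workaround is to call load_state_dict(state_dict, strict=False). What are your thoughts on breaking backward compatibility (BC) here?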
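A sketch of the BC situation, assuming a hypothetical NewSampler that already uses persistent=False and an old checkpoint that still carries an _xy_grid entry (the zero-filled 8x8 grid is made-up data). It shows strict loading rejecting the stale key, the strict=False workaround from above, and a narrower alternative, not part of the proposal, that drops only the stale grid keys so strict checking still applies to everything else:

```python
import torch
from torch import nn


class NewSampler(nn.Module):
    """Hypothetical sampler after the proposed change: the grid is non-persistent."""

    def __init__(self, height: int, width: int) -> None:
        super().__init__()
        ys = torch.linspace(-1.0, 1.0, height)
        xs = torch.linspace(-1.0, 1.0, width)
        grid = torch.stack([xs.expand(height, width), ys[:, None].expand(height, width)], dim=-1)
        self.register_buffer("_xy_grid", grid, persistent=False)


def drop_stale_grid_keys(state_dict, suffix="_xy_grid"):
    """Hypothetical helper: remove old grid entries (possibly under a module prefix)."""
    return {k: v for k, v in state_dict.items() if not k.endswith(suffix)}


# A checkpoint saved before the change still contains the grid (made-up values).
old_state_dict = {"_xy_grid": torch.zeros(8, 8, 2)}
sampler = NewSampler(16, 16)

# 1) Strict loading (the default) rejects the unexpected _xy_grid key.
try:
    sampler.load_state_dict(old_state_dict)
    strict_ok = True
except RuntimeError:
    strict_ok = False

# 2) The workaround from the issue: strict=False ignores the extra key.
sampler.load_state_dict(old_state_dict, strict=False)

# 3) Narrower alternative: drop only the stale keys, keep strict checking.
sampler.load_state_dict(drop_stale_grid_keys(old_state_dict))

print(strict_ok)              # False
print(sampler._xy_grid.shape) # torch.Size([16, 16, 2]); the stale grid is never copied in
```

The trade-off: strict=False also hides genuinely missing or unexpected keys elsewhere in a model, while the filtering variant only forgives the grid. Matching on the "_xy_grid" suffix is an assumption about how the key appears when the sampler is nested inside a larger model.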

Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
