Description
🚀 Feature
GridRaySampler creates the grid in its constructor and registers it with register_buffer:
self.register_buffer("_xy_grid", _xy_grid)
However, we should pass persistent=False so that the canonical grid is not saved in the state_dict.
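A minimal sketch of the difference, assuming PyTorch 1.10 or newer (for the indexing argument of torch.meshgrid). ToyGridSampler and can_load_across_sizes are hypothetical names for illustration, not PyTorch3D's GridRaySampler. With a persistent grid, _xy_grid appears in state_dict() and a strict load fails when the resolution changes. With a non-persistent grid, the grid is rebuilt in the constructor and the load succeeds.

```python
import torch
from torch import nn


class ToyGridSampler(nn.Module):
    """Hypothetical stand-in for a grid-based ray sampler (not PyTorch3D code)."""

    def __init__(self, height: int, width: int, persistent: bool = False):
        super().__init__()
        # A learnable parameter, so the state_dict holds more than the grid.
        self.scale = nn.Parameter(torch.ones(1))
        # Canonical xy grid in [-1, 1] with shape (height, width, 2),
        # rebuilt from (height, width) every time the module is constructed.
        ys = torch.linspace(-1.0, 1.0, height)
        xs = torch.linspace(-1.0, 1.0, width)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        xy_grid = torch.stack([grid_x, grid_y], dim=-1)
        # persistent=False keeps _xy_grid as a buffer (it still follows
        # .to()/.cuda()) but leaves it out of state_dict().
        self.register_buffer("_xy_grid", xy_grid, persistent=persistent)


def can_load_across_sizes(persistent: bool) -> bool:
    """Save at 64x64, then try a strict load into a 128x128 instance."""
    trained = ToyGridSampler(64, 64, persistent=persistent)
    evaluator = ToyGridSampler(128, 128, persistent=persistent)
    try:
        evaluator.load_state_dict(trained.state_dict())
        return True
    except RuntimeError:  # raised on size mismatch for a persistent _xy_grid
        return False


print(sorted(ToyGridSampler(4, 6, persistent=True).state_dict()))   # ['_xy_grid', 'scale']
print(sorted(ToyGridSampler(4, 6, persistent=False).state_dict()))  # ['scale']
print(can_load_across_sizes(persistent=True))   # False
print(can_load_across_sizes(persistent=False))  # True
```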
Motivation
We want the grid to be regenerated in code at construction time instead of being loaded from the state_dict. This would let us choose the rendered image size when constructing the sampler. Currently, loading a saved state_dict fails with a size mismatch on _xy_grid if the image size has changed. This feature is needed because training and evaluation commonly run at different resolutions.
Pitch
Change
self.register_buffer("_xy_grid", _xy_grid)
to
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
BC Breaking!
Note: this breaks loading of existing state_dicts. The new module no longer saves _xy_grid, but previously saved state_dicts still contain that key, so a strict load reports it as unexpected. A simple workaround is to call load_state_dict(state_dict, strict=False).
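A sketch of two ways to load a checkpoint saved before the change, assuming the grid is registered with persistent=False. NewSampler and drop_legacy_grid are hypothetical names, and old_state is a hand-built stand-in for a real checkpoint. Option 1 is the strict=False workaround; option 2 removes only the stale _xy_grid key, so strict loading can still catch any other missing or unexpected keys.

```python
import torch
from torch import nn


class NewSampler(nn.Module):
    """Hypothetical module after the change: _xy_grid is non-persistent."""

    def __init__(self, height: int, width: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.register_buffer("_xy_grid", torch.zeros(height, width, 2), persistent=False)


def drop_legacy_grid(state_dict, key="_xy_grid"):
    """Hypothetical helper: return a copy of state_dict without `key`."""
    return {k: v for k, v in state_dict.items() if k != key}


# Simulated checkpoint saved before the change: it still carries "_xy_grid".
old_state = {"scale": torch.full((1,), 2.0), "_xy_grid": torch.zeros(8, 8, 2)}

# Option 1: non-strict loading skips the stale key (and any other mismatched keys).
model_a = NewSampler(8, 8)
model_a.load_state_dict(old_state, strict=False)

# Option 2: drop only the legacy key and keep strict key checking for the rest.
model_b = NewSampler(16, 16)  # a different resolution loads fine too
model_b.load_state_dict(drop_legacy_grid(old_state))

print(model_a.scale.item(), model_b.scale.item())  # 2.0 2.0
```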
What are your thoughts on BC breaking?