[Feature Request] Make objective modules compatible with dictionaries #12

@vmoens

The TensorDict class simplifies passing data across processes and designing general classes that are oblivious to the keys used in a specific algorithm (e.g. whether or not an action_log_prob / hidden_state key should be expected).
However, introducing new classes can prevent users from copy-pasting and re-using modules (see here). We should make sure that TensorDict is used only when absolutely necessary, that is, in situations where all the content of a dictionary is treated in the same way:

  • indexing
  • reshaping
  • sending from worker to worker, device to device
  • concatenation / stacking

In general, TensorDict should be used for high-level classes: Agent, DataCollector, possibly probabilistic operator modules.
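
For comparison, here is a rough sketch of what the whole-dictionary operations listed above look like on a plain dict of tensors. The helper names (`index_all`, `stack_all`, `to_device_all`) and the keys are hypothetical and only serve the illustration; they are not part of the library.

```python
from typing import Dict, List

import torch


def index_all(data: Dict[str, torch.Tensor], idx) -> Dict[str, torch.Tensor]:
    """Apply the same index to every tensor in the dictionary."""
    return {key: value[idx] for key, value in data.items()}


def stack_all(items: List[Dict[str, torch.Tensor]], dim: int = 0) -> Dict[str, torch.Tensor]:
    """Stack a list of dictionaries that share the same keys, key by key."""
    return {key: torch.stack([d[key] for d in items], dim=dim) for key in items[0]}


def to_device_all(data: Dict[str, torch.Tensor], device) -> Dict[str, torch.Tensor]:
    """Move every tensor in the dictionary to the given device."""
    return {key: value.to(device) for key, value in data.items()}


# Hypothetical keys and shapes, for illustration only.
sample = {"observation": torch.zeros(10, 3), "reward": torch.zeros(10, 1)}
first_step = index_all(sample, 0)
stacked = stack_all([sample, sample])
on_cpu = to_device_all(sample, "cpu")
```

Each operation needs its own comprehension over the keys, which is the kind of repetition a TensorDict-like container removes in high-level classes.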

Objectives should not require TensorDicts in general.

However, in some cases they may need to check the trajectory length (1st dimension), the batch size (0th dimension), or even the device. An option in those cases would be to infer these from a specific tensor in the dictionary (e.g. reward?).
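
A minimal sketch of that option, assuming a `reward` entry with layout `[batch, time, ...]` is always present. The helper name `infer_batch_info` and the `ref_key` argument are hypothetical, not an existing API.

```python
from typing import Dict, Tuple

import torch


def infer_batch_info(
    data: Dict[str, torch.Tensor], ref_key: str = "reward"
) -> Tuple[int, int, torch.device]:
    """Read (batch_size, trajectory_length, device) from one reference tensor.

    Assumes the reference tensor is laid out as [batch, time, ...].
    """
    if ref_key not in data:
        raise KeyError(f"reference key {ref_key!r} not found in input dictionary")
    ref = data[ref_key]
    if ref.dim() < 2:
        raise ValueError(
            f"expected {ref_key!r} to have at least 2 dimensions, got {ref.dim()}"
        )
    return ref.shape[0], ref.shape[1], ref.device


# Hypothetical batch: 4 trajectories of 10 steps each.
batch = {
    "reward": torch.zeros(4, 10, 1),
    "action_log_prob": torch.zeros(4, 10, 1),
}
batch_size, traj_len, device = infer_batch_info(batch)
```

The drawback of this approach is that it ties every objective to one agreed-upon key, so the choice of reference tensor would need to be documented or made configurable.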

Plan

  • Test and fix modules so that they all accept a dictionary as input.
  • Update the type annotations accordingly.
  • Currently, modules return a TensorDict, but they could just as well return a regular dict.
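
As an illustration of where the plan leads, an objective could be typed against a generic mapping and return a plain dict. The sketch below is not one of the existing modules: the function `value_loss` and the keys `state_value`, `value_target` and `loss_value` are made up for the example.

```python
from typing import Dict, Mapping

import torch


def value_loss(data: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Hypothetical objective: mean squared error between a value estimate and its target.

    Only key lookup is required from the input, so a regular dict works.
    """
    error = data["state_value"] - data["value_target"]
    return {"loss_value": error.pow(2).mean()}


# Hypothetical inputs: every estimate is off by exactly 1.
out = value_loss(
    {
        "state_value": torch.ones(4, 10, 1),
        "value_target": torch.zeros(4, 10, 1),
    }
)
```

Because the function only relies on `data[key]`, any container exposing that lookup could presumably be passed in unchanged, which keeps the objective easy to copy into code that does not use TensorDict.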
