Description
The TensorDict class simplifies passing data across processes and designing general classes that are oblivious to the keys used in a specific algorithm (e.g. whether or not an action_log_prob / hidden_state key should be expected).
However, introducing new classes can prevent users from copy-pasting and reusing modules (see here). We should make sure that TensorDict is used only when absolutely necessary. These cases include situations where all the content of a dictionary will be treated in a similar way:
- indexing
- reshaping
- sending from worker to worker, device to device
- concatenation / stacking
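As a rough illustration of the operations listed above, here is what each one looks like on a plain dict of tensors. The keys, shapes and the "cpu" device are assumptions chosen for the example, not taken from the library. Each operation needs its own loop over the keys, which is the repetition a container class is meant to hide.

```python
import torch

# Hypothetical batch of 4 transitions stored in a plain dict of tensors.
data = {
    "observation": torch.zeros(4, 3),
    "reward": torch.ones(4, 1),
}

# Indexing: take the first two entries of every tensor.
indexed = {key: value[:2] for key, value in data.items()}

# Reshaping: split the leading dimension of size 4 into (2, 2).
reshaped = {
    key: value.reshape(2, 2, *value.shape[1:]) for key, value in data.items()
}

# Sending to a device ("cpu" here so the sketch runs anywhere).
moved = {key: value.to("cpu") for key, value in data.items()}

# Stacking two dictionaries that share the same keys along a new leading dimension.
stacked = {key: torch.stack([data[key], data[key]], dim=0) for key in data}
```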
In general, TensorDict should be used for high-level classes: Agent, DataCollector, and possibly probabilistic operator modules.
Objectives should not require TensorDicts in general.
However, in some cases they may need to check the trajectory length (1st dimension) or the batch size (0th dimension), or even the device. An option in those cases would be to infer those from a specific tensor in the dictionary (e.g. reward?)
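A minimal sketch of that option, assuming the dictionary holds a "reward" tensor laid out as (batch, time, ...). The helper name infer_batch_info and the layout are hypothetical, introduced only for illustration.

```python
import torch


def infer_batch_info(data, key="reward"):
    """Return (batch_size, trajectory_length, device) read from data[key].

    Assumes the reference tensor is laid out as (batch, time, ...).
    trajectory_length is None when the tensor has a single dimension.
    """
    reference = data[key]
    batch_size = reference.shape[0]
    trajectory_length = reference.shape[1] if reference.dim() > 1 else None
    return batch_size, trajectory_length, reference.device


# Hypothetical data: 8 trajectories of 5 steps each.
data = {
    "observation": torch.zeros(8, 5, 3),
    "reward": torch.zeros(8, 5, 1),
}
batch_size, trajectory_length, device = infer_batch_info(data)
```

One drawback of this approach is that the objective then depends on a particular key being present, so the reference key is left as a parameter here.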
Plan
- Test and fix modules such that they all accept a dictionary as input.
- Update the type annotations accordingly.
- Currently, modules return a TensorDict, but we could just as well return a regular dict.
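The plan could look roughly like the following sketch: a module whose forward accepts any mapping of tensors and returns a regular dict. The class name DictPolicy and the "observation" / "action" keys are assumptions made for the example and do not correspond to existing modules.

```python
from typing import Dict, Mapping

import torch
from torch import nn


class DictPolicy(nn.Module):
    """Hypothetical module that reads "observation" and writes "action"."""

    def __init__(self, obs_dim: int, action_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(obs_dim, action_dim)

    def forward(self, data: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Copy the input into a regular dict so the caller's mapping is untouched.
        out = dict(data)
        out["action"] = self.linear(data["observation"])
        return out


policy = DictPolicy(obs_dim=3, action_dim=2)
output = policy({"observation": torch.zeros(4, 3)})
```

Typing the input as Mapping rather than a specific container keeps the module usable with a plain dict, which is what makes it easy to copy into another code base.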