[QUESTION] How to reset only certain nested parts of a key with TensorDictPrimer? #2053
To clarify the "end up losing the other nested keys" part (what happens when I specify only the key I want), here is the observation spec printed before and after adding my TensorDictPrimer transform:
Observation spec before:
Observation spec after:
The hidden_spec I am passing to TensorDictPrimer:
For the record, I can work around this issue by specifying a different key for the hidden states, e.g. ("agents_hs", "hidden_state"), which avoids overwriting the original obs_spec or zeroing out other fields at the same nesting level. I would still like to know whether this problem can be avoided without moving the hidden state to a separate key.
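The thread does not establish how TensorDictPrimer writes its values on reset, so the following is only a toy model in plain Python, not TorchRL code (`reset_replace_subtrees` and `reset_leaves_only` are hypothetical helpers, and the nested dicts stand in for tensordicts). It contrasts a reset that replaces a whole top-level entry, which reproduces the lost-siblings symptom, with one that writes only the leaves the primer declares. It also shows why the ("agents_hs", "hidden_state") workaround is safe either way: the primer never touches the "agents" entry.

```python
import copy

# Toy model, NOT TorchRL code: observations are plain nested dicts, and
# "resetting with a primer" means writing the primer's default values into them.


def reset_replace_subtrees(obs, primer_defaults):
    """Replace each top-level entry named by the primer wholesale (hypothetical)."""
    out = copy.deepcopy(obs)
    for key, value in primer_defaults.items():
        # Any sibling keys that lived under `key` in `obs` are dropped here.
        out[key] = copy.deepcopy(value)
    return out


def reset_leaves_only(obs, primer_defaults):
    """Write only the leaf entries the primer declares, keeping siblings (hypothetical)."""
    out = copy.deepcopy(obs)

    def merge(dst, src):
        for key, value in src.items():
            if isinstance(value, dict):
                # Descend into the existing sub-dict instead of replacing it.
                merge(dst.setdefault(key, {}), value)
            else:
                dst[key] = copy.deepcopy(value)

    merge(out, primer_defaults)
    return out


# Hidden state stored under the shared "agents" key (the problematic setup).
obs = {"agents": {"edge_index": [[0, 1], [1, 0]], "hidden_state": [0.7, -0.2]}}
shared = {"agents": {"hidden_state": [0.0, 0.0]}}

# Hidden state stored under a separate key, as in the workaround.
obs_ws = {"agents": {"edge_index": [[0, 1], [1, 0]]}, "agents_hs": {"hidden_state": [0.7, -0.2]}}
separate = {"agents_hs": {"hidden_state": [0.0, 0.0]}}

lost = reset_replace_subtrees(obs, shared)      # "edge_index" is gone
kept = reset_leaves_only(obs, shared)           # "edge_index" survives
workaround = reset_replace_subtrees(obs_ws, separate)  # "agents" untouched
```

In this model, keeping the primer's keys disjoint from the keys the environment populates itself is what makes the workaround robust, regardless of how the reset merges values.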
On it, sorry for the delay.
Hi, I have an observation spec for a multi-agent environment which looks like this:
Here, the key ("agents", "edge_index") is a special field that I populate once upon creating the env and never want to change.
My problem is that I would like to add a recurrent policy, which requires tracking the hidden state for each agent. I read the Recurrent DQN tutorial, but the LSTMModule's make_tensordict_primer() does not quite work for me, as it is designed for the single-agent case.
So I tried to write a custom TensorDictPrimer transform, like so:
However, I notice that on environment resets this TensorDictPrimer overwrites every field in this spec with 0s. I tried setting the TensorDictPrimer's input keys to only the ("agents", "hidden_state") key I want to zero out, but when I do so, I end up losing the other nested keys under "agents" on reset.
Am I misunderstanding the usage of TensorDictPrimer? Any help would be appreciated.