Skip to content

Commit

Permalink
typo in VAE encoder's joint projection
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyzhaozh committed Jun 23, 2023
1 parent f74f0db commit dfe6c7f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion detr/models/detr_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_action_proj(qpos) # (bs, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
Expand Down

0 comments on commit dfe6c7f

Please sign in to comment.