diff --git a/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py b/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py index b5aae29..114d844 100644 --- a/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py +++ b/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py @@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy): # condition through impainting this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(B, T, -1) + # reshape back to B, To, Do + nobs_features = nobs_features.reshape(B, To, -1) shape = (B, T, Da+Do) cond_data = torch.zeros(size=shape, device=device, dtype=dtype) cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) diff --git a/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py b/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py index 8f8c54f..cb2d848 100644 --- a/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py +++ b/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py @@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy): # condition through impainting this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do + # reshape back to B, To, Do nobs_features = nobs_features.reshape(B, To, -1) cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype) cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)