From c52bac42eef4d95791440b052959802831e339f9 Mon Sep 17 00:00:00 2001 From: Cheng Chi Date: Sat, 9 Sep 2023 12:51:49 -0400 Subject: [PATCH] fixed T->To based on suggestion from Dominique-Yiu --- .../policy/diffusion_transformer_hybrid_image_policy.py | 4 ++-- diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py b/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py index b5aae29..114d844 100644 --- a/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py +++ b/diffusion_policy/policy/diffusion_transformer_hybrid_image_policy.py @@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy): # condition through impainting this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(B, T, -1) + # reshape back to B, To, Do + nobs_features = nobs_features.reshape(B, To, -1) shape = (B, T, Da+Do) cond_data = torch.zeros(size=shape, device=device, dtype=dtype) cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) diff --git a/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py b/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py index 8f8c54f..cb2d848 100644 --- a/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py +++ b/diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py @@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy): # condition through impainting this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do + # reshape back to B, To, Do nobs_features = nobs_features.reshape(B, To, -1) cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype) cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)