From dd2cbac9fa83d264487194919da34d5302f76cf1 Mon Sep 17 00:00:00 2001 From: Cheng Chi Date: Thu, 7 Sep 2023 19:02:13 -0400 Subject: [PATCH] incorporated change from PR #10 --- diffusion_policy/model/diffusion/conditional_unet1d.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/diffusion_policy/model/diffusion/conditional_unet1d.py b/diffusion_policy/model/diffusion/conditional_unet1d.py index 0e17bbf..fea22f5 100644 --- a/diffusion_policy/model/diffusion/conditional_unet1d.py +++ b/diffusion_policy/model/diffusion/conditional_unet1d.py @@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module): for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): x = torch.cat((x, h.pop()), dim=1) x = resnet(x, global_feature) + # The correct condition should be: + # if idx == (len(self.up_modules)-1) and len(h_local) > 0: + # However this change will break compatibility with published checkpoints. + # Therefore it is left as a comment. if idx == len(self.up_modules) and len(h_local) > 0: x = x + h_local[1] x = resnet2(x, global_feature)