diff --git a/diffusion_policy/model/diffusion/conditional_unet1d.py b/diffusion_policy/model/diffusion/conditional_unet1d.py index 0e17bbf..fea22f5 100644 --- a/diffusion_policy/model/diffusion/conditional_unet1d.py +++ b/diffusion_policy/model/diffusion/conditional_unet1d.py @@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module): for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): x = torch.cat((x, h.pop()), dim=1) x = resnet(x, global_feature) + # The correct condition should be: + # if idx == (len(self.up_modules)-1) and len(h_local) > 0: + # However this change will break compatibility with published checkpoints. + # Therefore it is left as a comment. if idx == len(self.up_modules) and len(h_local) > 0: x = x + h_local[1] x = resnet2(x, global_feature)