Merge pull request #18 from columbia-ai-robotics/cchi/bug_fix_unet1d

incorporated change from PR #10
This commit is contained in:
Cheng Chi 2023-09-07 16:03:03 -07:00 committed by GitHub
commit 749db2ce9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1) x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature) x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0: if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1] x = x + h_local[1]
x = resnet2(x, global_feature) x = resnet2(x, global_feature)