Merge pull request #18 from columbia-ai-robotics/cchi/bug_fix_unet1d
incorporated change from PR #10
This commit is contained in:
commit
749db2ce9c
@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
|
Loading…
x
Reference in New Issue
Block a user