incorporated change from PR #10

This commit is contained in:
Cheng Chi 2023-09-07 19:02:13 -04:00
parent 0d00e02b45
commit dd2cbac9fa

View File

@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
x = resnet2(x, global_feature)