incorporated change from PR #10
This commit is contained in:
parent
0d00e02b45
commit
dd2cbac9fa
@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
|
# The correct condition should be:
|
||||||
|
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||||
|
# However this change will break compatibility with published checkpoints.
|
||||||
|
# Therefore it is left as a comment.
|
||||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||||
x = x + h_local[1]
|
x = x + h_local[1]
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user