Merge pull request #19 from columbia-ai-robotics/cchi/fix_transformer_impainting
fixed T->To based on suggestion from Dominique-Yiu
This commit is contained in:
commit
68eef44d3e
@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
|
|||||||
# condition through impainting
|
# condition through impainting
|
||||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||||
nobs_features = self.obs_encoder(this_nobs)
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
# reshape back to B, T, Do
|
# reshape back to B, To, Do
|
||||||
nobs_features = nobs_features.reshape(B, T, -1)
|
nobs_features = nobs_features.reshape(B, To, -1)
|
||||||
shape = (B, T, Da+Do)
|
shape = (B, T, Da+Do)
|
||||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
|
|||||||
# condition through impainting
|
# condition through impainting
|
||||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||||
nobs_features = self.obs_encoder(this_nobs)
|
nobs_features = self.obs_encoder(this_nobs)
|
||||||
# reshape back to B, T, Do
|
# reshape back to B, To, Do
|
||||||
nobs_features = nobs_features.reshape(B, To, -1)
|
nobs_features = nobs_features.reshape(B, To, -1)
|
||||||
cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
|
cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
|
||||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user