Fix to_numpy. (#73)

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-06-04 16:32:05 +02:00 committed by GitHub
parent 7bf202f195
commit 66be5641b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -95,7 +95,7 @@ class Batch:
k__ = '_' + k + '@' + k_
self.__dict__[k__] = v_
else:
self.__dict__[k] = kwargs[k]
self.__dict__[k] = v
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
@ -180,7 +180,7 @@ class Batch:
"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.cpu().numpy()
self.__dict__[k] = v.detach().cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()