diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index cde6e6f..f498a4d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -95,7 +95,7 @@ class Batch: k__ = '_' + k + '@' + k_ self.__dict__[k__] = v_ else: - self.__dict__[k] = kwargs[k] + self.__dict__[k] = v def __getstate__(self): """Pickling interface. Only the actual data are serialized @@ -180,7 +180,7 @@ class Batch: """ for k, v in self.__dict__.items(): if isinstance(v, torch.Tensor): - self.__dict__[k] = v.cpu().numpy() + self.__dict__[k] = v.detach().cpu().numpy() elif isinstance(v, Batch): v.to_numpy()