diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index eafd7b3..b2da847 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -3,11 +3,12 @@ import pickle
 import torch
 import numpy as np
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 
 
 def test_batch():
     batch = Batch(obs=[0], np=np.zeros([3, 4]))
+    assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
     batch.append(batch)
@@ -31,6 +32,45 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_over_batch_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    batch.to_torch()
+    assert isinstance(batch.a, torch.Tensor)
+    assert isinstance(batch.b.c, torch.Tensor)
+    assert isinstance(batch.b.d, torch.Tensor)
+    assert batch.a.dtype == torch.float64
+    assert batch.b.c.dtype == torch.float64
+    assert batch.b.d.dtype == torch.float64
+    batch.to_torch(dtype=torch.float32)
+    assert batch.a.dtype == torch.float32
+    assert batch.b.c.dtype == torch.float32
+    assert batch.b.d.dtype == torch.float32
+
+
+def test_utils_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    a_torch_float = to_torch(batch.a, dtype=torch.float32)
+    assert a_torch_float.dtype == torch.float32
+    a_torch_double = to_torch(batch.a, dtype=torch.float64)
+    assert a_torch_double.dtype == torch.float64
+    batch_torch_float = to_torch(batch, dtype=torch.float32)
+    assert batch_torch_float.a.dtype == torch.float32
+    assert batch_torch_float.b.c.dtype == torch.float32
+    assert batch_torch_float.b.d.dtype == torch.float32
+
+
 def test_batch_pickle():
     batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                   np=np.zeros([3, 4]))
@@ -42,14 +82,12 @@ def test_batch_pickle():
 
 def test_batch_from_to_numpy_without_copy():
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
-    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
+    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
     batch.to_torch()
-    assert isinstance(batch["a"], torch.Tensor)
-    assert isinstance(batch["b"]["c"], torch.Tensor)
     batch.to_numpy()
-    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_new = batch.a.__array_interface__['data'][0]
+    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
     assert a_mem_addr_new == a_mem_addr_orig
     assert c_mem_addr_new == c_mem_addr_orig
 
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index e0129f7..38905c4 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -183,19 +183,36 @@ class Batch:
 
     def to_torch(self,
                  dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int] = 'cpu'
+                 device: Union[str, int, torch.device] = 'cpu'
                  ) -> None:
         """Change all numpy.ndarray to torch.Tensor. This is an inplace
         operation.
""" + if not isinstance(device, torch.device): + device = torch.device(device) + for k, v in self.__dict__.items(): if isinstance(v, np.ndarray): v = torch.from_numpy(v).to(device) if dtype is not None: v = v.type(dtype) self.__dict__[k] = v + if isinstance(v, torch.Tensor): + if dtype is not None and v.dtype != dtype: + must_update_tensor = True + elif v.device.type != device.type: + must_update_tensor = True + elif device.index is not None and \ + device.index != v.device.index: + must_update_tensor = True + else: + must_update_tensor = False + if must_update_tensor: + if dtype is not None: + v = v.type(dtype) + self.__dict__[k] = v.to(device) elif isinstance(v, Batch): - v.to_torch() + v.to_torch(dtype, device) def append(self, batch: 'Batch') -> None: """Append a :class:`~tianshou.data.Batch` object to current batch.""" diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 8494074..62dfb46 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -28,9 +28,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x = torch.from_numpy(x).to(device) if dtype is not None: x = x.type(dtype) + if isinstance(x, torch.Tensor): + if dtype is not None: + x = x.type(dtype) + x = x.to(device) elif isinstance(x, dict): for k, v in x.items(): x[k] = to_torch(v, dtype, device) elif isinstance(x, Batch): - x.to_torch() + x.to_torch(dtype, device) return x