Fix 'to_torch' dtype/device forwarding for Batch over Batch. (#68)

* Fix Batch to_torch method not updating dtype/device of already converted data.

* Fix dtype/device not being forwarded by to_torch for Batch over Batch.

* Add unit test to check recursive to_torch dtype/device forwarding.

* Check in the Batch unit tests that data can be accessed using both dict-style and attribute-style syntax.

* Fix utils to_torch dtype/device forwarding. Add unit tests.

* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Co-authored-by: n+e <463003665@qq.com>
Alexis DUBURCQ 2020-05-30 15:40:31 +02:00 committed by GitHub
parent 529a4cf44c
commit 1fce527c77
3 changed files with 69 additions and 10 deletions
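In short: Batch.to_torch previously skipped values that were already torch.Tensor and called nested batches without forwarding its arguments, so a dtype/device request could silently fail to apply to parts of the tree. A minimal sketch of the fixed behavior (illustrative repro, not part of the diff):

import numpy as np
import torch

from tianshou.data import Batch

# A nested batch mixing a numpy array with an already-converted tensor.
batch = Batch(a=np.zeros(3), b=Batch(c=torch.zeros(3, dtype=torch.float64)))

# Previously, tensors were skipped and nested Batch objects were converted
# with default arguments, so `b.c` stayed float64. With the fix, dtype and
# device are applied recursively to every leaf.
batch.to_torch(dtype=torch.float32, device='cpu')
assert batch.a.dtype == torch.float32
assert batch.b.c.dtype == torch.float32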

File: test_batch.py

@@ -3,11 +3,12 @@ import pickle
 import torch
 import numpy as np
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch


 def test_batch():
     batch = Batch(obs=[0], np=np.zeros([3, 4]))
+    assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
     batch.append(batch)
@@ -31,6 +32,45 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0


+def test_batch_over_batch_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    batch.to_torch()
+    assert isinstance(batch.a, torch.Tensor)
+    assert isinstance(batch.b.c, torch.Tensor)
+    assert isinstance(batch.b.d, torch.Tensor)
+    assert batch.a.dtype == torch.float64
+    assert batch.b.c.dtype == torch.float64
+    assert batch.b.d.dtype == torch.float64
+    batch.to_torch(dtype=torch.float32)
+    assert batch.a.dtype == torch.float32
+    assert batch.b.c.dtype == torch.float32
+    assert batch.b.d.dtype == torch.float32
+
+
+def test_utils_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    a_torch_float = to_torch(batch.a, dtype=torch.float32)
+    assert a_torch_float.dtype == torch.float32
+    a_torch_double = to_torch(batch.a, dtype=torch.float64)
+    assert a_torch_double.dtype == torch.float64
+    batch_torch_float = to_torch(batch, dtype=torch.float32)
+    assert batch_torch_float.a.dtype == torch.float32
+    assert batch_torch_float.b.c.dtype == torch.float32
+    assert batch_torch_float.b.d.dtype == torch.float32
+
+
 def test_batch_pickle():
     batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                   np=np.zeros([3, 4]))
@@ -42,14 +82,12 @@ def test_batch_pickle():

 def test_batch_from_to_numpy_without_copy():
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
-    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
+    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
     batch.to_torch()
-    assert isinstance(batch["a"], torch.Tensor)
-    assert isinstance(batch["b"]["c"], torch.Tensor)
     batch.to_numpy()
-    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_new = batch.a.__array_interface__['data'][0]
+    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
     assert a_mem_addr_new == a_mem_addr_orig
     assert c_mem_addr_new == c_mem_addr_orig
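(The address equality this test requires holds because torch.from_numpy wraps the numpy array's existing buffer rather than copying it, and the reverse conversion can return a view of that same memory, so the numpy -> torch -> numpy round trip is copy-free.)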

File: tianshou/data/batch.py

@@ -183,19 +183,36 @@ class Batch:
     def to_torch(self,
                  dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int] = 'cpu'
+                 device: Union[str, int, torch.device] = 'cpu'
                  ) -> None:
         """Change all numpy.ndarray to torch.Tensor. This is an inplace
         operation.
         """
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+
         for k, v in self.__dict__.items():
             if isinstance(v, np.ndarray):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
                 self.__dict__[k] = v
+            if isinstance(v, torch.Tensor):
+                if dtype is not None and v.dtype != dtype:
+                    must_update_tensor = True
+                elif v.device.type != device.type:
+                    must_update_tensor = True
+                elif device.index is not None and \
+                        device.index != v.device.index:
+                    must_update_tensor = True
+                else:
+                    must_update_tensor = False
+                if must_update_tensor:
+                    if dtype is not None:
+                        v = v.type(dtype)
+                    self.__dict__[k] = v.to(device)
             elif isinstance(v, Batch):
-                v.to_torch()
+                v.to_torch(dtype, device)

     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""

File: tianshou/data/utils.py

@@ -28,9 +28,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
         x = torch.from_numpy(x).to(device)
         if dtype is not None:
             x = x.type(dtype)
+    if isinstance(x, torch.Tensor):
+        if dtype is not None:
+            x = x.type(dtype)
+        x = x.to(device)
     elif isinstance(x, dict):
         for k, v in x.items():
             x[k] = to_torch(v, dtype, device)
     elif isinstance(x, Batch):
-        x.to_torch()
+        x.to_torch(dtype, device)
     return x