Fix 'to_torch' dtype/device forwarding for Batch over Batch. (#68)
* Fix Batch.to_torch not updating dtype/device of already-converted data.
* Fix dtype/device forwarding by to_torch for Batch over Batch.
* Add unit test checking recursive dtype/device forwarding in to_torch.
* Batch unit test: check data access using both dict style and class style.
* Fix utils to_torch dtype/device forwarding; add unit tests.
* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Co-authored-by: n+e <463003665@qq.com>
parent 529a4cf44c
commit 1fce527c77
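
A minimal usage sketch (assuming tianshou at this revision) of the behavior the fix pins down: before this change, a second call to to_torch with a new dtype left already-converted tensors, and anything inside a nested Batch, untouched; afterwards the dtype and device are forwarded recursively.

    import numpy as np
    import torch

    from tianshou.data import Batch

    # Nested Batch mixing a numpy array with an already-converted tensor.
    batch = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(d=torch.ones((1,), dtype=torch.float64)),
    )
    batch.to_torch()                     # numpy -> torch, dtype preserved (float64)
    batch.to_torch(dtype=torch.float32)  # must also retype existing tensors,
                                         # including those in the nested batch.b
    assert batch.a.dtype == torch.float32
    assert batch.b.d.dtype == torch.float32  # failed before this fix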
@@ -3,11 +3,12 @@ import pickle
 import torch
 import numpy as np
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 
 
 def test_batch():
     batch = Batch(obs=[0], np=np.zeros([3, 4]))
     assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
     batch.append(batch)
@@ -31,6 +32,45 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_over_batch_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    batch.to_torch()
+    assert isinstance(batch.a, torch.Tensor)
+    assert isinstance(batch.b.c, torch.Tensor)
+    assert isinstance(batch.b.d, torch.Tensor)
+    assert batch.a.dtype == torch.float64
+    assert batch.b.c.dtype == torch.float64
+    assert batch.b.d.dtype == torch.float64
+    batch.to_torch(dtype=torch.float32)
+    assert batch.a.dtype == torch.float32
+    assert batch.b.c.dtype == torch.float32
+    assert batch.b.d.dtype == torch.float32
+
+
+def test_utils_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    a_torch_float = to_torch(batch.a, dtype=torch.float32)
+    assert a_torch_float.dtype == torch.float32
+    a_torch_double = to_torch(batch.a, dtype=torch.float64)
+    assert a_torch_double.dtype == torch.float64
+    batch_torch_float = to_torch(batch, dtype=torch.float32)
+    assert batch_torch_float.a.dtype == torch.float32
+    assert batch_torch_float.b.c.dtype == torch.float32
+    assert batch_torch_float.b.d.dtype == torch.float32
+
+
 def test_batch_pickle():
     batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                   np=np.zeros([3, 4]))
@@ -42,14 +82,12 @@ def test_batch_pickle():
 
 def test_batch_from_to_numpy_without_copy():
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
-    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
+    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
     batch.to_torch()
-    assert isinstance(batch["a"], torch.Tensor)
-    assert isinstance(batch["b"]["c"], torch.Tensor)
     batch.to_numpy()
-    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_new = batch.a.__array_interface__['data'][0]
+    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
     assert a_mem_addr_new == a_mem_addr_orig
     assert c_mem_addr_new == c_mem_addr_orig
 
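
For reference, Batch supports dict-style and attribute-style access interchangeably, which is what the reworked assertions above rely on; a minimal sketch:

    from tianshou.data import Batch

    batch = Batch(obs=[0])
    assert batch.obs == batch["obs"]  # both styles return the same data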
@@ -183,19 +183,36 @@ class Batch:
 
     def to_torch(self,
                  dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int] = 'cpu'
+                 device: Union[str, int, torch.device] = 'cpu'
                  ) -> None:
         """Change all numpy.ndarray to torch.Tensor. This is an inplace
         operation.
         """
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+
         for k, v in self.__dict__.items():
             if isinstance(v, np.ndarray):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
                 self.__dict__[k] = v
+            if isinstance(v, torch.Tensor):
+                if dtype is not None and v.dtype != dtype:
+                    must_update_tensor = True
+                elif v.device.type != device.type:
+                    must_update_tensor = True
+                elif device.index is not None and \
+                        device.index != v.device.index:
+                    must_update_tensor = True
+                else:
+                    must_update_tensor = False
+                if must_update_tensor:
+                    if dtype is not None:
+                        v = v.type(dtype)
+                    self.__dict__[k] = v.to(device)
             elif isinstance(v, Batch):
-                v.to_torch()
+                v.to_torch(dtype, device)
 
     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
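
The must_update_tensor logic above compares device.index only when the target index is not None: torch.device('cuda') carries index None, while a tensor actually placed on the default GPU reports cuda:0, so a plain equality test would flag a spurious mismatch and trigger a needless copy. A small illustration of this standard PyTorch behavior (the CUDA branch assumes a GPU build):

    import torch

    target = torch.device('cuda')        # index is None: "any/default GPU"
    print(target.index)                  # None
    if torch.cuda.is_available():
        t = torch.zeros(1, device=target)
        print(t.device)                  # cuda:0 -- a concrete index
        assert t.device != target        # naive comparison reports a false mismatch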
@@ -28,9 +28,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
         x = torch.from_numpy(x).to(device)
         if dtype is not None:
             x = x.type(dtype)
+    if isinstance(x, torch.Tensor):
+        if dtype is not None:
+            x = x.type(dtype)
+        x = x.to(device)
     elif isinstance(x, dict):
         for k, v in x.items():
             x[k] = to_torch(v, dtype, device)
     elif isinstance(x, Batch):
-        x.to_torch()
+        x.to_torch(dtype, device)
     return x
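
With this change, the standalone converter applies dtype/device to bare tensors and forwards both through dicts and Batch objects. A quick sketch of the recursive dict path, mirroring the new unit tests:

    import numpy as np
    import torch

    from tianshou.data import to_torch

    data = {'x': np.zeros((2,), dtype=np.float64),
            'y': {'z': torch.ones((2,), dtype=torch.float64)}}
    out = to_torch(data, dtype=torch.float32)
    assert out['x'].dtype == torch.float32
    assert out['y']['z'].dtype == torch.float32  # nested dicts handled via recursion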