Fix 'to_tensor' dtype/device forwarding for Batch over Batch. (#68)
* Fix Batch to_torch method not updating dtype/device of already converted data.
* Fix dtype/device forwarding by to_tensor for Batch over Batch.
* Add unit test to check to_torch dtype/device recursive forwarding.
* Batch unit test checks accessing data using both dict and class style.
* Fix utils to_tensor dtype/device forwarding. Add unit tests.
* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Co-authored-by: n+e <463003665@qq.com>
Parent: 529a4cf44c
Commit: 1fce527c77
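For context, a minimal usage sketch of the behavior this fix targets (illustrative only, not part of the diff): calling to_torch on a Batch that nests another Batch should forward the requested dtype and device all the way down, including to tensors that were already converted with a different dtype.

    import numpy as np
    import torch
    from tianshou.data import Batch

    # A Batch nesting another Batch, mixing numpy arrays and torch tensors.
    batch = Batch(a=np.zeros((3,), dtype=np.float64),
                  b=Batch(c=torch.zeros((3,), dtype=torch.float64)))

    # dtype/device are forwarded into the nested Batch and applied to
    # tensors whose dtype/device do not match the request.
    batch.to_torch(dtype=torch.float32, device='cpu')
    assert batch.a.dtype == batch.b.c.dtype == torch.float32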
@@ -3,11 +3,12 @@ import pickle
 import torch
 import numpy as np
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 
 
 def test_batch():
     batch = Batch(obs=[0], np=np.zeros([3, 4]))
+    assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
     batch.append(batch)
@@ -31,6 +32,45 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_over_batch_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    batch.to_torch()
+    assert isinstance(batch.a, torch.Tensor)
+    assert isinstance(batch.b.c, torch.Tensor)
+    assert isinstance(batch.b.d, torch.Tensor)
+    assert batch.a.dtype == torch.float64
+    assert batch.b.c.dtype == torch.float64
+    assert batch.b.d.dtype == torch.float64
+    batch.to_torch(dtype=torch.float32)
+    assert batch.a.dtype == torch.float32
+    assert batch.b.c.dtype == torch.float32
+    assert batch.b.d.dtype == torch.float32
+
+
+def test_utils_to_torch():
+    batch = Batch(
+        a=np.ones((1,), dtype=np.float64),
+        b=Batch(
+            c=np.ones((1,), dtype=np.float64),
+            d=torch.ones((1,), dtype=torch.float64)
+        )
+    )
+    a_torch_float = to_torch(batch.a, dtype=torch.float32)
+    assert a_torch_float.dtype == torch.float32
+    a_torch_double = to_torch(batch.a, dtype=torch.float64)
+    assert a_torch_double.dtype == torch.float64
+    batch_torch_float = to_torch(batch, dtype=torch.float32)
+    assert batch_torch_float.a.dtype == torch.float32
+    assert batch_torch_float.b.c.dtype == torch.float32
+    assert batch_torch_float.b.d.dtype == torch.float32
+
+
 def test_batch_pickle():
     batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                   np=np.zeros([3, 4]))
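The new tests exercise dtype forwarding on CPU; device forwarding can be checked the same way when CUDA is available. A hedged sketch (not part of the diff, guarded so it is a no-op without a GPU):

    import numpy as np
    import torch
    from tianshou.data import Batch

    if torch.cuda.is_available():
        batch = Batch(a=np.ones((1,), dtype=np.float64),
                      b=Batch(c=torch.ones((1,), dtype=torch.float64)))
        batch.to_torch(dtype=torch.float32, device='cuda')
        # the device (and dtype) request reaches the nested Batch as well
        assert batch.a.device.type == batch.b.c.device.type == 'cuda'
        assert batch.b.c.dtype == torch.float32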
@@ -42,14 +82,12 @@ def test_batch_pickle():
 
 def test_batch_from_to_numpy_without_copy():
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
-    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
+    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
     batch.to_torch()
-    assert isinstance(batch["a"], torch.Tensor)
-    assert isinstance(batch["b"]["c"], torch.Tensor)
     batch.to_numpy()
-    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
-    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    a_mem_addr_new = batch.a.__array_interface__['data'][0]
+    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
     assert a_mem_addr_new == a_mem_addr_orig
     assert c_mem_addr_new == c_mem_addr_orig
 
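For readers wondering why test_batch_from_to_numpy_without_copy compares raw memory addresses: on CPU, torch.from_numpy and Tensor.numpy both share the underlying buffer, so a numpy -> torch -> numpy round trip should keep the same data pointer. A standalone sketch of that property (illustrative, independent of Batch):

    import numpy as np
    import torch

    a = np.ones((1,))
    addr_orig = a.__array_interface__['data'][0]

    t = torch.from_numpy(a)   # shares memory with `a`, no copy
    b = t.numpy()             # CPU tensor -> ndarray, still no copy

    assert b.__array_interface__['data'][0] == addr_orig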
@@ -183,19 +183,36 @@ class Batch:
 
     def to_torch(self,
                  dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int] = 'cpu'
+                 device: Union[str, int, torch.device] = 'cpu'
                  ) -> None:
         """Change all numpy.ndarray to torch.Tensor. This is an inplace
         operation.
         """
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+
         for k, v in self.__dict__.items():
             if isinstance(v, np.ndarray):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
                 self.__dict__[k] = v
+            if isinstance(v, torch.Tensor):
+                if dtype is not None and v.dtype != dtype:
+                    must_update_tensor = True
+                elif v.device.type != device.type:
+                    must_update_tensor = True
+                elif device.index is not None and \
+                        device.index != v.device.index:
+                    must_update_tensor = True
+                else:
+                    must_update_tensor = False
+                if must_update_tensor:
+                    if dtype is not None:
+                        v = v.type(dtype)
+                    self.__dict__[k] = v.to(device)
             elif isinstance(v, Batch):
-                v.to_torch()
+                v.to_torch(dtype, device)
 
     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
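The core of the Batch.to_torch change is deciding whether an already converted torch.Tensor still needs to be touched. A standalone restatement of that check (an illustrative helper, not part of the library API):

    from typing import Optional

    import torch


    def needs_update(v: torch.Tensor,
                     dtype: Optional[torch.dtype],
                     device: torch.device) -> bool:
        """Mirror of the must_update_tensor decision in Batch.to_torch."""
        if dtype is not None and v.dtype != dtype:
            return True   # requested dtype differs
        if v.device.type != device.type:
            return True   # device type differs (cpu vs cuda)
        if device.index is not None and device.index != v.device.index:
            return True   # explicit device index differs (e.g. cuda:0 vs cuda:1)
        return False


    # A float64 CPU tensor must be updated when float32 is requested.
    assert needs_update(torch.ones(1, dtype=torch.float64),
                        torch.float32, torch.device('cpu'))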
@@ -28,9 +28,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
         x = torch.from_numpy(x).to(device)
         if dtype is not None:
             x = x.type(dtype)
+    if isinstance(x, torch.Tensor):
+        if dtype is not None:
+            x = x.type(dtype)
+        x = x.to(device)
     elif isinstance(x, dict):
         for k, v in x.items():
             x[k] = to_torch(v, dtype, device)
     elif isinstance(x, Batch):
-        x.to_torch()
+        x.to_torch(dtype, device)
     return x
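The utils-level to_torch also recurses through plain dicts and nested Batch objects, forwarding dtype/device on the way down. A short usage sketch (illustrative, reusing the imports from the tests above):

    import numpy as np
    import torch
    from tianshou.data import Batch, to_torch

    data = {'x': np.ones((2,), dtype=np.float64),
            'nested': Batch(y=torch.ones((2,), dtype=torch.float64))}

    out = to_torch(data, dtype=torch.float32, device='cpu')
    assert out['x'].dtype == torch.float32          # ndarray converted
    assert out['nested'].y.dtype == torch.float32   # forwarded into the Batch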