Improve to_torch/to_numpy converters (#147)
* Enable converting list/tuple back and forth from/to numpy/torch. * Add fallbacks. * Fix PEP8 * Update unit tests. * Type annotation. Robust dtype check. * List of object are converted individually, as a single tensor otherwise. * Improve robustness of _to_array_with_correct_type * Add unit tests. * Do not catch exception at _to_array_with_correct_type level. * Use _parse_value * Fix PEP8 * Fix _parse_value list output type fallback. * Catch torch exception. * Do not convert torch tensor during fallback. * Improve unit tests. * Add unit tests. * FIx missing import * Remove support of numpy arrays of tensors for Batch value parser. * Forbid numpy arrays of tensors. * Fix PEP8. * Fix comment. * Reduce _parse_value branch number. * Fix None value. * Forward error message for debugging purpose. * Fix _is_scalar. * More specific try/catch blocks. * Fix exception chaining. * Fix PEP8. * Fix _is_scalar. * Fix missing corner case. * Fix PEP8. * Allow Batch empty key. * Fix multi-dim array datatype check. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
8c32d99c65
commit
865ef6c693
@ -3,8 +3,9 @@ import copy
|
||||
import pickle
|
||||
import pytest
|
||||
import numpy as np
|
||||
from itertools import starmap
|
||||
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.data import Batch, to_torch, to_numpy
|
||||
|
||||
|
||||
def test_batch():
|
||||
@ -28,8 +29,19 @@ def test_batch():
|
||||
assert b.a == 3
|
||||
with pytest.raises(AssertionError):
|
||||
Batch({1: 2})
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||
Batch(a=[])
|
||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||
assert batch.obs == batch["obs"]
|
||||
batch.obs = [1]
|
||||
@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch():
|
||||
assert batch.b.d.dtype == torch.float32
|
||||
|
||||
|
||||
def test_utils_to_torch():
|
||||
def test_utils_to_torch_numpy():
|
||||
batch = Batch(
|
||||
a=np.float64(1.0),
|
||||
b=Batch(
|
||||
@ -323,8 +335,37 @@ def test_utils_to_torch():
|
||||
assert batch_torch_float.a.dtype == torch.float32
|
||||
assert batch_torch_float.b.c.dtype == torch.float32
|
||||
assert batch_torch_float.b.d.dtype == torch.float32
|
||||
array_list = [float('nan'), 1.0]
|
||||
assert to_torch(array_list).dtype == torch.float64
|
||||
data_list = [float('nan'), 1]
|
||||
data_list_torch = to_torch(data_list)
|
||||
assert data_list_torch.dtype == torch.float64
|
||||
data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
|
||||
data_list_2_torch = to_torch(data_list_2)
|
||||
assert data_list_2_torch.shape == (2, 3, 3)
|
||||
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
|
||||
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
|
||||
data_list_3_torch = to_torch(data_list_3)
|
||||
assert isinstance(data_list_3_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
|
||||
assert all(starmap(np.allclose,
|
||||
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
|
||||
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
|
||||
data_list_4_torch = to_torch(data_list_4)
|
||||
assert isinstance(data_list_4_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
|
||||
assert all(starmap(np.allclose,
|
||||
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
|
||||
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
|
||||
data_list_5_torch = to_torch(data_list_5)
|
||||
assert isinstance(data_list_5_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
|
||||
data_array = np.random.rand(3, 2, 2)
|
||||
data_empty_tensor = to_torch(data_array[[]])
|
||||
assert isinstance(data_empty_tensor, torch.Tensor)
|
||||
assert data_empty_tensor.shape == (0, 2, 2)
|
||||
data_empty_array = to_numpy(data_empty_tensor)
|
||||
assert isinstance(data_empty_array, np.ndarray)
|
||||
assert data_empty_array.shape == (0, 2, 2)
|
||||
assert np.allclose(to_numpy(to_torch(data_array)), data_array)
|
||||
|
||||
|
||||
def test_batch_pickle():
|
||||
@ -432,7 +473,7 @@ if __name__ == '__main__':
|
||||
test_batch()
|
||||
test_batch_over_batch()
|
||||
test_batch_over_batch_to_torch()
|
||||
test_utils_to_torch()
|
||||
test_utils_to_torch_numpy()
|
||||
test_batch_pickle()
|
||||
test_batch_from_to_numpy_without_copy()
|
||||
test_batch_standard_compatibility()
|
||||
|
@ -4,6 +4,7 @@ import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from collections.abc import Collection
|
||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool:
|
||||
# 3. python object rather than dict / Batch / tensor
|
||||
# the check of dict / Batch is omitted because this only checks a value.
|
||||
# a dict / Batch will eventually check their values
|
||||
value = np.asanyarray(value)
|
||||
return value.size == 1 and not value.shape
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.numel() == 1 and not value.shape
|
||||
else:
|
||||
value = np.asanyarray(value)
|
||||
return value.size == 1 and not value.shape
|
||||
|
||||
|
||||
def _is_number(value: Any) -> bool:
|
||||
@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool:
|
||||
def _to_array_with_correct_type(v: Any) -> np.ndarray:
|
||||
# convert the value to np.ndarray
|
||||
# convert to np.object data type if neither bool nor number
|
||||
# raises an exception if array's elements are tensors themself
|
||||
v = np.asanyarray(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
if v.dtype == np.object and not v.shape:
|
||||
if v.dtype == np.object:
|
||||
# scalar ndarray with np.object data type is very annoying
|
||||
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
|
||||
# a is not array([{}, {}], dtype=object), and a[0]={} results in
|
||||
# something very strange:
|
||||
# array([{}, array({}, dtype=object)], dtype=object)
|
||||
v = v.item(0)
|
||||
if not v.shape:
|
||||
v = v.item(0)
|
||||
elif any(isinstance(e, (np.ndarray, torch.Tensor))
|
||||
for e in v.reshape(-1)):
|
||||
raise ValueError("Numpy arrays of tensors are not supported yet.")
|
||||
return v
|
||||
|
||||
|
||||
@ -113,25 +122,29 @@ def _assert_type_keys(keys):
|
||||
|
||||
|
||||
def _parse_value(v: Any):
|
||||
if isinstance(v, (list, tuple, np.ndarray)):
|
||||
if not isinstance(v, np.ndarray) and \
|
||||
all(isinstance(e, torch.Tensor) for e in v):
|
||||
v = torch.stack(v)
|
||||
return v
|
||||
v_ = _to_array_with_correct_type(v)
|
||||
if v_.dtype == np.object and _is_batch_set(v):
|
||||
v = Batch(v) # list of dict / Batch
|
||||
else:
|
||||
# normal data list (main case)
|
||||
# or actually a data list with objects
|
||||
v = v_
|
||||
elif isinstance(v, dict):
|
||||
if isinstance(v, dict):
|
||||
v = Batch(v)
|
||||
elif isinstance(v, (Batch, torch.Tensor)):
|
||||
pass
|
||||
else:
|
||||
# scalar case, convert to ndarray
|
||||
v = _to_array_with_correct_type(v)
|
||||
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
|
||||
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
|
||||
try:
|
||||
return torch.stack(v)
|
||||
except RuntimeError as e:
|
||||
raise TypeError("Batch does not support non-stackable iterable"
|
||||
" of torch.Tensor as unique value yet.") from e
|
||||
try:
|
||||
v_ = _to_array_with_correct_type(v)
|
||||
except ValueError as e:
|
||||
raise TypeError("Batch does not support heterogeneous list/tuple"
|
||||
" of tensors as unique value yet.") from e
|
||||
if _is_batch_set(v):
|
||||
v = Batch(v) # list of dict / Batch
|
||||
else:
|
||||
# None, scalar, normal data list (main case)
|
||||
# or an actual list of objects
|
||||
v = v_
|
||||
return v
|
||||
|
||||
|
||||
|
@ -3,12 +3,12 @@ import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.data.batch import _parse_value, Batch
|
||||
|
||||
|
||||
def to_numpy(x: Union[
|
||||
torch.Tensor, dict, Batch, np.ndarray]) -> Union[
|
||||
dict, Batch, np.ndarray]:
|
||||
Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
|
||||
Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""Return an object without torch.Tensor."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
@ -17,13 +17,20 @@ def to_numpy(x: Union[
|
||||
x[k] = to_numpy(v)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_numpy()
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_numpy(_parse_value(x))
|
||||
except TypeError:
|
||||
x = [to_numpy(e) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
return x
|
||||
|
||||
|
||||
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
||||
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = 'cpu'
|
||||
) -> Union[dict, Batch, torch.Tensor]:
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""Return an object without np.ndarray."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
if dtype is not None:
|
||||
@ -36,14 +43,19 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
||||
x.to_torch(dtype, device)
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, list) and len(x) > 0 and \
|
||||
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, np.ndarray) and \
|
||||
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_torch(_parse_value(x), dtype, device)
|
||||
except TypeError:
|
||||
x = [to_torch(e, dtype, device) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
else:
|
||||
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user