Improve to_torch/to_numpy converters (#147)
* Enable converting list/tuple back and forth from/to numpy/torch. * Add fallbacks. * Fix PEP8 * Update unit tests. * Type annotation. Robust dtype check. * List of object are converted individually, as a single tensor otherwise. * Improve robustness of _to_array_with_correct_type * Add unit tests. * Do not catch exception at _to_array_with_correct_type level. * Use _parse_value * Fix PEP8 * Fix _parse_value list output type fallback. * Catch torch exception. * Do not convert torch tensor during fallback. * Improve unit tests. * Add unit tests. * FIx missing import * Remove support of numpy arrays of tensors for Batch value parser. * Forbid numpy arrays of tensors. * Fix PEP8. * Fix comment. * Reduce _parse_value branch number. * Fix None value. * Forward error message for debugging purpose. * Fix _is_scalar. * More specific try/catch blocks. * Fix exception chaining. * Fix PEP8. * Fix _is_scalar. * Fix missing corner case. * Fix PEP8. * Allow Batch empty key. * Fix multi-dim array datatype check. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
8c32d99c65
commit
865ef6c693
@ -3,8 +3,9 @@ import copy
|
|||||||
import pickle
|
import pickle
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
from tianshou.data import Batch, to_torch
|
from tianshou.data import Batch, to_torch, to_numpy
|
||||||
|
|
||||||
|
|
||||||
def test_batch():
|
def test_batch():
|
||||||
@ -28,8 +29,19 @@ def test_batch():
|
|||||||
assert b.a == 3
|
assert b.a == 3
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
Batch({1: 2})
|
Batch({1: 2})
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
|
||||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||||
|
Batch(a=[])
|
||||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||||
assert batch.obs == batch["obs"]
|
assert batch.obs == batch["obs"]
|
||||||
batch.obs = [1]
|
batch.obs = [1]
|
||||||
@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch():
|
|||||||
assert batch.b.d.dtype == torch.float32
|
assert batch.b.d.dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
def test_utils_to_torch():
|
def test_utils_to_torch_numpy():
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
a=np.float64(1.0),
|
a=np.float64(1.0),
|
||||||
b=Batch(
|
b=Batch(
|
||||||
@ -323,8 +335,37 @@ def test_utils_to_torch():
|
|||||||
assert batch_torch_float.a.dtype == torch.float32
|
assert batch_torch_float.a.dtype == torch.float32
|
||||||
assert batch_torch_float.b.c.dtype == torch.float32
|
assert batch_torch_float.b.c.dtype == torch.float32
|
||||||
assert batch_torch_float.b.d.dtype == torch.float32
|
assert batch_torch_float.b.d.dtype == torch.float32
|
||||||
array_list = [float('nan'), 1.0]
|
data_list = [float('nan'), 1]
|
||||||
assert to_torch(array_list).dtype == torch.float64
|
data_list_torch = to_torch(data_list)
|
||||||
|
assert data_list_torch.dtype == torch.float64
|
||||||
|
data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
|
||||||
|
data_list_2_torch = to_torch(data_list_2)
|
||||||
|
assert data_list_2_torch.shape == (2, 3, 3)
|
||||||
|
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
|
||||||
|
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
|
||||||
|
data_list_3_torch = to_torch(data_list_3)
|
||||||
|
assert isinstance(data_list_3_torch, list)
|
||||||
|
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
|
||||||
|
assert all(starmap(np.allclose,
|
||||||
|
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
|
||||||
|
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
|
||||||
|
data_list_4_torch = to_torch(data_list_4)
|
||||||
|
assert isinstance(data_list_4_torch, list)
|
||||||
|
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
|
||||||
|
assert all(starmap(np.allclose,
|
||||||
|
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
|
||||||
|
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
|
||||||
|
data_list_5_torch = to_torch(data_list_5)
|
||||||
|
assert isinstance(data_list_5_torch, list)
|
||||||
|
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
|
||||||
|
data_array = np.random.rand(3, 2, 2)
|
||||||
|
data_empty_tensor = to_torch(data_array[[]])
|
||||||
|
assert isinstance(data_empty_tensor, torch.Tensor)
|
||||||
|
assert data_empty_tensor.shape == (0, 2, 2)
|
||||||
|
data_empty_array = to_numpy(data_empty_tensor)
|
||||||
|
assert isinstance(data_empty_array, np.ndarray)
|
||||||
|
assert data_empty_array.shape == (0, 2, 2)
|
||||||
|
assert np.allclose(to_numpy(to_torch(data_array)), data_array)
|
||||||
|
|
||||||
|
|
||||||
def test_batch_pickle():
|
def test_batch_pickle():
|
||||||
@ -432,7 +473,7 @@ if __name__ == '__main__':
|
|||||||
test_batch()
|
test_batch()
|
||||||
test_batch_over_batch()
|
test_batch_over_batch()
|
||||||
test_batch_over_batch_to_torch()
|
test_batch_over_batch_to_torch()
|
||||||
test_utils_to_torch()
|
test_utils_to_torch_numpy()
|
||||||
test_batch_pickle()
|
test_batch_pickle()
|
||||||
test_batch_from_to_numpy_without_copy()
|
test_batch_from_to_numpy_without_copy()
|
||||||
test_batch_standard_compatibility()
|
test_batch_standard_compatibility()
|
||||||
|
@ -4,6 +4,7 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
from collections.abc import Collection
|
||||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||||
|
|
||||||
# Disable pickle warning related to torch, since it has been removed
|
# Disable pickle warning related to torch, since it has been removed
|
||||||
@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool:
|
|||||||
# 3. python object rather than dict / Batch / tensor
|
# 3. python object rather than dict / Batch / tensor
|
||||||
# the check of dict / Batch is omitted because this only checks a value.
|
# the check of dict / Batch is omitted because this only checks a value.
|
||||||
# a dict / Batch will eventually check their values
|
# a dict / Batch will eventually check their values
|
||||||
value = np.asanyarray(value)
|
if isinstance(value, torch.Tensor):
|
||||||
return value.size == 1 and not value.shape
|
return value.numel() == 1 and not value.shape
|
||||||
|
else:
|
||||||
|
value = np.asanyarray(value)
|
||||||
|
return value.size == 1 and not value.shape
|
||||||
|
|
||||||
|
|
||||||
def _is_number(value: Any) -> bool:
|
def _is_number(value: Any) -> bool:
|
||||||
@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool:
|
|||||||
def _to_array_with_correct_type(v: Any) -> np.ndarray:
|
def _to_array_with_correct_type(v: Any) -> np.ndarray:
|
||||||
# convert the value to np.ndarray
|
# convert the value to np.ndarray
|
||||||
# convert to np.object data type if neither bool nor number
|
# convert to np.object data type if neither bool nor number
|
||||||
|
# raises an exception if array's elements are tensors themself
|
||||||
v = np.asanyarray(v)
|
v = np.asanyarray(v)
|
||||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||||
v = v.astype(np.object)
|
v = v.astype(np.object)
|
||||||
if v.dtype == np.object and not v.shape:
|
if v.dtype == np.object:
|
||||||
# scalar ndarray with np.object data type is very annoying
|
# scalar ndarray with np.object data type is very annoying
|
||||||
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
|
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
|
||||||
# a is not array([{}, {}], dtype=object), and a[0]={} results in
|
# a is not array([{}, {}], dtype=object), and a[0]={} results in
|
||||||
# something very strange:
|
# something very strange:
|
||||||
# array([{}, array({}, dtype=object)], dtype=object)
|
# array([{}, array({}, dtype=object)], dtype=object)
|
||||||
v = v.item(0)
|
if not v.shape:
|
||||||
|
v = v.item(0)
|
||||||
|
elif any(isinstance(e, (np.ndarray, torch.Tensor))
|
||||||
|
for e in v.reshape(-1)):
|
||||||
|
raise ValueError("Numpy arrays of tensors are not supported yet.")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
@ -113,25 +122,29 @@ def _assert_type_keys(keys):
|
|||||||
|
|
||||||
|
|
||||||
def _parse_value(v: Any):
|
def _parse_value(v: Any):
|
||||||
if isinstance(v, (list, tuple, np.ndarray)):
|
if isinstance(v, dict):
|
||||||
if not isinstance(v, np.ndarray) and \
|
|
||||||
all(isinstance(e, torch.Tensor) for e in v):
|
|
||||||
v = torch.stack(v)
|
|
||||||
return v
|
|
||||||
v_ = _to_array_with_correct_type(v)
|
|
||||||
if v_.dtype == np.object and _is_batch_set(v):
|
|
||||||
v = Batch(v) # list of dict / Batch
|
|
||||||
else:
|
|
||||||
# normal data list (main case)
|
|
||||||
# or actually a data list with objects
|
|
||||||
v = v_
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
v = Batch(v)
|
v = Batch(v)
|
||||||
elif isinstance(v, (Batch, torch.Tensor)):
|
elif isinstance(v, (Batch, torch.Tensor)):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# scalar case, convert to ndarray
|
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
|
||||||
v = _to_array_with_correct_type(v)
|
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
|
||||||
|
try:
|
||||||
|
return torch.stack(v)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise TypeError("Batch does not support non-stackable iterable"
|
||||||
|
" of torch.Tensor as unique value yet.") from e
|
||||||
|
try:
|
||||||
|
v_ = _to_array_with_correct_type(v)
|
||||||
|
except ValueError as e:
|
||||||
|
raise TypeError("Batch does not support heterogeneous list/tuple"
|
||||||
|
" of tensors as unique value yet.") from e
|
||||||
|
if _is_batch_set(v):
|
||||||
|
v = Batch(v) # list of dict / Batch
|
||||||
|
else:
|
||||||
|
# None, scalar, normal data list (main case)
|
||||||
|
# or an actual list of objects
|
||||||
|
v = v_
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,12 +3,12 @@ import numpy as np
|
|||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
|
|
||||||
from tianshou.data import Batch
|
from tianshou.data.batch import _parse_value, Batch
|
||||||
|
|
||||||
|
|
||||||
def to_numpy(x: Union[
|
def to_numpy(x: Union[
|
||||||
torch.Tensor, dict, Batch, np.ndarray]) -> Union[
|
Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
|
||||||
dict, Batch, np.ndarray]:
|
Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||||
"""Return an object without torch.Tensor."""
|
"""Return an object without torch.Tensor."""
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
x = x.detach().cpu().numpy()
|
x = x.detach().cpu().numpy()
|
||||||
@ -17,13 +17,20 @@ def to_numpy(x: Union[
|
|||||||
x[k] = to_numpy(v)
|
x[k] = to_numpy(v)
|
||||||
elif isinstance(x, Batch):
|
elif isinstance(x, Batch):
|
||||||
x.to_numpy()
|
x.to_numpy()
|
||||||
|
elif isinstance(x, (list, tuple)):
|
||||||
|
try:
|
||||||
|
x = to_numpy(_parse_value(x))
|
||||||
|
except TypeError:
|
||||||
|
x = [to_numpy(e) for e in x]
|
||||||
|
else: # fallback
|
||||||
|
x = np.asanyarray(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
device: Union[str, int, torch.device] = 'cpu'
|
device: Union[str, int, torch.device] = 'cpu'
|
||||||
) -> Union[dict, Batch, torch.Tensor]:
|
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||||
"""Return an object without np.ndarray."""
|
"""Return an object without np.ndarray."""
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
@ -36,14 +43,19 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
|||||||
x.to_torch(dtype, device)
|
x.to_torch(dtype, device)
|
||||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||||
x = to_torch(np.asanyarray(x), dtype, device)
|
x = to_torch(np.asanyarray(x), dtype, device)
|
||||||
elif isinstance(x, list) and len(x) > 0 and \
|
elif isinstance(x, (list, tuple)):
|
||||||
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
|
try:
|
||||||
x = to_torch(np.asanyarray(x), dtype, device)
|
x = to_torch(_parse_value(x), dtype, device)
|
||||||
elif isinstance(x, np.ndarray) and \
|
except TypeError:
|
||||||
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
x = [to_torch(e, dtype, device) for e in x]
|
||||||
x = torch.from_numpy(x).to(device)
|
else: # fallback
|
||||||
if dtype is not None:
|
x = np.asanyarray(x)
|
||||||
x = x.type(dtype)
|
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
||||||
|
x = torch.from_numpy(x).to(device)
|
||||||
|
if dtype is not None:
|
||||||
|
x = x.type(dtype)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user