Alexis DUBURCQ 865ef6c693
Improve to_torch/to_numpy converters (#147)
* Enable converting list/tuple back and forth from/to numpy/torch.

* Add fallbacks.

* Fix PEP8

* Update unit tests.

* Type annotation. Robust dtype check.

* Lists of objects are converted element by element; other lists are converted as a single tensor.

* Improve robustness of _to_array_with_correct_type

* Add unit tests.

* Do not catch exceptions at the _to_array_with_correct_type level.

* Use _parse_value

* Fix PEP8

* Fix _parse_value list output type fallback.

* Catch torch exception.

* Do not convert torch tensor during fallback.

* Improve unit tests.

* Add unit tests.

* Fix missing import.

* Remove support for numpy arrays of tensors in the Batch value parser.

* Forbid numpy arrays of tensors.

* Fix PEP8.

* Fix comment.

* Reduce _parse_value branch number.

* Fix None value.

* Forward error message for debugging purpose.

* Fix _is_scalar.

* More specific try/catch blocks.

* Fix exception chaining.

* Fix PEP8.

* Fix _is_scalar.

* Fix missing corner case.

* Fix PEP8.

* Allow Batch empty key.

* Fix multi-dim array datatype check.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
2020-07-21 16:47:56 +08:00

70 lines
2.3 KiB
Python

import torch
import numpy as np
from numbers import Number
from typing import Union, Optional
from tianshou.data.batch import _parse_value, Batch
def to_numpy(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
             ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
    """Convert ``x`` into an object that holds no ``torch.Tensor``.

    * ``torch.Tensor``: detached, moved to CPU and turned into an ndarray.
    * ``dict``: every value is converted recursively; the dict is updated
      in place and the same object is returned.
    * ``Batch``: converted in place via ``Batch.to_numpy``; the same object
      is returned.
    * ``list`` / ``tuple``: parsed as a whole with ``_parse_value`` and then
      converted. If that raises ``TypeError``, each element is converted on
      its own and a ``list`` is returned.
    * anything else is passed through ``np.asanyarray``.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    if isinstance(x, dict):
        # Mutating existing keys while iterating is safe: the size never changes.
        for key, value in x.items():
            x[key] = to_numpy(value)
        return x
    if isinstance(x, Batch):
        x.to_numpy()
        return x
    if isinstance(x, (list, tuple)):
        try:
            return to_numpy(_parse_value(x))
        except TypeError:
            # Not convertible as a single array: fall back to per-element.
            return [to_numpy(item) for item in x]
    # Fallback for ndarrays, scalars and any other object.
    return np.asanyarray(x)
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
             dtype: Optional[torch.dtype] = None,
             device: Union[str, int, torch.device] = 'cpu'
             ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
    """Return an object without np.ndarray.

    Conversion rules:

    * ``torch.Tensor``: cast to ``dtype`` (if given), then moved to ``device``.
    * ``dict``: every value is converted recursively; the dict is updated in
      place and the same object is returned.
    * ``Batch``: converted in place via ``Batch.to_torch(dtype, device)``.
    * numpy / Python scalars: wrapped with ``np.asanyarray`` and converted to
      a 0-dim tensor.
    * ``list`` / ``tuple``: parsed as a whole with ``_parse_value`` and then
      converted. If that raises ``TypeError``, each element is converted on
      its own and a ``list`` is returned.
    * anything else: passed through ``np.asanyarray`` and converted when its
      dtype is boolean or numeric.

    :param x: the object to convert.
    :param dtype: target torch dtype, or ``None`` to keep the inferred one.
    :param device: device the resulting tensors are moved to.
    :raises TypeError: if a leaf value does not have a boolean or numeric
        dtype (e.g. strings or arbitrary Python objects).
    """
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        x = x.to(device)
    elif isinstance(x, dict):
        for k, v in x.items():
            x[k] = to_torch(v, dtype, device)
    elif isinstance(x, Batch):
        x.to_torch(dtype, device)
    elif isinstance(x, (np.number, np.bool_, Number)):
        x = to_torch(np.asanyarray(x), dtype, device)
    elif isinstance(x, (list, tuple)):
        try:
            x = to_torch(_parse_value(x), dtype, device)
        except TypeError:
            # Not convertible as a single array: fall back to per-element.
            x = [to_torch(e, dtype, device) for e in x]
    else:  # fallback
        x = np.asanyarray(x)
        if issubclass(x.dtype.type, (np.bool_, np.number)):
            # torch.from_numpy shares memory with the array and rejects
            # negative strides (e.g. ``a[::-1]`` or ``np.flip(a)``) with a
            # ValueError, so such views are first copied into a fresh
            # C-contiguous array. Other arrays are still shared zero-copy.
            if any(stride < 0 for stride in x.strides):
                x = x.copy()
            x = torch.from_numpy(x).to(device)
            if dtype is not None:
                x = x.type(dtype)
        else:
            raise TypeError(f"object {x} cannot be converted to torch.")
    return x
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray],
                y: torch.Tensor
                ) -> Union[dict, Batch, torch.Tensor]:
    """Convert ``x`` to torch so that it matches the reference tensor ``y``.

    This is shorthand for ``to_torch(x, dtype=y.dtype, device=y.device)``:
    the result carries ``y``'s dtype and lives on ``y``'s device.

    :param x: the object to convert (see ``to_torch`` for supported inputs).
    :param y: reference tensor whose ``dtype`` and ``device`` are copied.
    """
    assert isinstance(y, torch.Tensor)
    target_dtype, target_device = y.dtype, y.device
    return to_torch(x, dtype=target_dtype, device=target_device)