Alexis DUBURCQ 1fce527c77
Fix 'to_tensor' dtype/device forwarding for Batch over Batch. (#68)
* Fix Batch to_torch method not updating dtype/device of already converted data.

* Fix dtype/device not forwarded by to_tensor for Batch over Batch.

* Add Unit test to check to_torch dtype/device recursive forwarding.

* Batch UT check accessing data using both dict and class style.

* Fix utils to_tensor dtype/device forwarding. Add Unit tests.

* Fix UT.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Co-authored-by: n+e <463003665@qq.com>
2020-05-30 21:40:31 +08:00

41 lines
1.2 KiB
Python

import torch
import numpy as np
from typing import Union, Optional
from tianshou.data import Batch
def to_numpy(x: Union[
        torch.Tensor, dict, Batch, np.ndarray]) -> Union[
        dict, Batch, np.ndarray]:
    """Return an object without torch.Tensor.

    A ``torch.Tensor`` is detached, moved to the CPU and returned as a
    ``np.ndarray``. A ``dict`` has each of its values converted recursively;
    the values are reassigned on the same dict, which is then returned.
    A ``Batch`` has its own ``to_numpy()`` method called (the return value of
    that call is ignored) and the same ``Batch`` object is returned.
    Anything else is returned untouched.
    """
    if isinstance(x, torch.Tensor):
        # detach() drops the autograd graph, cpu() brings the data to host
        # memory so that numpy() can expose it.
        return x.detach().cpu().numpy()
    if isinstance(x, dict):
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for key, value in x.items():
            x[key] = to_numpy(value)
        return x
    if isinstance(x, Batch):
        x.to_numpy()
    return x
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
             dtype: Optional[torch.dtype] = None,
             device: Union[str, int] = 'cpu'
             ) -> Union[dict, Batch, torch.Tensor]:
    """Return an object without np.ndarray.

    :param x: the object to convert. A ``np.ndarray`` is wrapped as a
        ``torch.Tensor``; a ``torch.Tensor`` is cast/moved; a ``dict`` has
        each of its values converted recursively and reassigned on the same
        dict; a ``Batch`` has its own ``to_torch(dtype, device)`` method
        called (the return value of that call is ignored). Any other object
        is returned untouched.
    :param dtype: target dtype applied to every tensor; ``None`` keeps the
        dtype each tensor already has.
    :param device: target device passed to ``Tensor.to`` for every tensor.

    :return: the converted tensor, or the same ``dict`` / ``Batch`` object
        that was passed in.
    """
    if isinstance(x, np.ndarray):
        # Only wrap the array here. The tensor branch below is the single
        # place where dtype and device are applied, so an ndarray input is
        # no longer cast and moved twice.
        x = torch.from_numpy(x)
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        x = x.to(device)
    elif isinstance(x, dict):
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for k, v in x.items():
            x[k] = to_torch(v, dtype, device)
    elif isinstance(x, Batch):
        x.to_torch(dtype, device)
    return x