Fix to_torch converters (#111)

* Fix to_torch converters.

* to_torch now convert any object Torch Tensor-compatible.

* Fix linter.

* Fix Batch to_torch to convert any Torch Tensor-compatible data.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-07 12:40:55 +02:00 committed by GitHub
parent 8913bf36b1
commit 69caf89908
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 11 deletions

View File

@ -149,9 +149,9 @@ def test_batch_cat_and_stack():
def test_batch_over_batch_to_torch():
batch = Batch(
a=np.ones((1,), dtype=np.float64),
a=np.float64(1.0),
b=Batch(
c=np.ones((1,), dtype=np.float64),
c=np.ones((1,), dtype=np.float32),
d=torch.ones((1,), dtype=torch.float64)
)
)
@ -160,7 +160,7 @@ def test_batch_over_batch_to_torch():
assert isinstance(batch.b.c, torch.Tensor)
assert isinstance(batch.b.d, torch.Tensor)
assert batch.a.dtype == torch.float64
assert batch.b.c.dtype == torch.float64
assert batch.b.c.dtype == torch.float32
assert batch.b.d.dtype == torch.float64
batch.to_torch(dtype=torch.float32)
assert batch.a.dtype == torch.float32
@ -170,9 +170,9 @@ def test_batch_over_batch_to_torch():
def test_utils_to_torch():
batch = Batch(
a=np.ones((1,), dtype=np.float64),
a=np.float64(1.0),
b=Batch(
c=np.ones((1,), dtype=np.float64),
c=np.ones((1,), dtype=np.float32),
d=torch.ones((1,), dtype=torch.float64)
)
)
@ -184,6 +184,8 @@ def test_utils_to_torch():
assert batch_torch_float.a.dtype == torch.float32
assert batch_torch_float.b.c.dtype == torch.float32
assert batch_torch_float.b.d.dtype == torch.float32
array_list = [float('nan'), 1.0]
assert to_torch(array_list).dtype == torch.float64
def test_batch_pickle():

View File

@ -446,12 +446,14 @@ class Batch:
device = torch.device(device)
for k, v in self.items():
if isinstance(v, (np.generic, np.ndarray)):
if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
if isinstance(v, (np.number, np.bool_, Number)):
v = np.asanyarray(v)
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
if isinstance(v, torch.Tensor):
elif isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype:
must_update_tensor = True
elif v.device.type != device.type:

View File

@ -1,5 +1,6 @@
import torch
import numpy as np
from numbers import Number
from typing import Union, Optional
from tianshou.data import Batch
@ -24,10 +25,6 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
device: Union[str, int, torch.device] = 'cpu'
) -> Union[dict, Batch, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
if isinstance(x, torch.Tensor):
if dtype is not None:
x = x.type(dtype)
@ -37,6 +34,16 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
x[k] = to_torch(v, dtype, device)
elif isinstance(x, Batch):
x.to_torch(dtype, device)
elif isinstance(x, (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, list) and len(x) > 0 and \
isinstance(x[0], (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, np.ndarray) and \
isinstance(x.item(0), (np.number, np.bool_, Number)):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
return x