Fix to_torch converters (#111)

* Fix to_torch converters.

* to_torch now converts any Torch Tensor-compatible object.

* Fix linter.

* Fix Batch to_torch to convert any Torch Tensor-compatible data.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Alexis DUBURCQ 2020-07-07 12:40:55 +02:00 committed by GitHub
parent 8913bf36b1
commit 69caf89908
3 changed files with 22 additions and 11 deletions
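
For context (not part of the commit itself): torch.from_numpy only accepts np.ndarray, so a 0-d NumPy scalar such as np.float64(1.0) passed the old np.generic check in Batch.to_torch and then failed at conversion. A minimal sketch of the failure mode and of the np.asanyarray promotion this patch applies:

import numpy as np
import torch

x = np.float64(1.0)      # a 0-d NumPy scalar, not an np.ndarray
# torch.from_numpy(x)    # raises TypeError: expected np.ndarray

# The patch promotes number-like scalars to arrays before conversion:
t = torch.from_numpy(np.asanyarray(x))
print(t, t.dtype)        # tensor(1., dtype=torch.float64) torch.float64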


@@ -149,9 +149,9 @@ def test_batch_cat_and_stack():
 def test_batch_over_batch_to_torch():
     batch = Batch(
-        a=np.ones((1,), dtype=np.float64),
+        a=np.float64(1.0),
         b=Batch(
-            c=np.ones((1,), dtype=np.float64),
+            c=np.ones((1,), dtype=np.float32),
             d=torch.ones((1,), dtype=torch.float64)
         )
     )
@@ -160,7 +160,7 @@ def test_batch_over_batch_to_torch():
     assert isinstance(batch.b.c, torch.Tensor)
     assert isinstance(batch.b.d, torch.Tensor)
     assert batch.a.dtype == torch.float64
-    assert batch.b.c.dtype == torch.float64
+    assert batch.b.c.dtype == torch.float32
     assert batch.b.d.dtype == torch.float64
     batch.to_torch(dtype=torch.float32)
     assert batch.a.dtype == torch.float32
@@ -170,9 +170,9 @@ def test_batch_over_batch_to_torch():
 def test_utils_to_torch():
     batch = Batch(
-        a=np.ones((1,), dtype=np.float64),
+        a=np.float64(1.0),
         b=Batch(
-            c=np.ones((1,), dtype=np.float64),
+            c=np.ones((1,), dtype=np.float32),
             d=torch.ones((1,), dtype=torch.float64)
         )
     )
@@ -184,6 +184,8 @@ def test_utils_to_torch():
     assert batch_torch_float.a.dtype == torch.float32
     assert batch_torch_float.b.c.dtype == torch.float32
     assert batch_torch_float.b.d.dtype == torch.float32
+    array_list = [float('nan'), 1.0]
+    assert to_torch(array_list).dtype == torch.float64


 def test_batch_pickle():


@@ -446,12 +446,14 @@ class Batch:
             device = torch.device(device)

        for k, v in self.items():
-            if isinstance(v, (np.generic, np.ndarray)):
+            if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
+                if isinstance(v, (np.number, np.bool_, Number)):
+                    v = np.asanyarray(v)
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
                 self.__dict__[k] = v
-            if isinstance(v, torch.Tensor):
+            elif isinstance(v, torch.Tensor):
                 if dtype is not None and v.dtype != dtype:
                     must_update_tensor = True
                 elif v.device.type != device.type:
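
A minimal usage sketch of the patched Batch.to_torch, mirroring the updated test above (illustrative only; assumes tianshou at this commit):

import numpy as np
import torch
from tianshou.data import Batch

batch = Batch(a=np.float64(1.0), b=Batch(c=np.ones((1,), dtype=np.float32)))
batch.to_torch(dtype=torch.float32)   # scalar members are wrapped via np.asanyarray
assert isinstance(batch.a, torch.Tensor) and batch.a.dtype == torch.float32
assert isinstance(batch.b.c, torch.Tensor) and batch.b.c.dtype == torch.float32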


@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+from numbers import Number
 from typing import Union, Optional

 from tianshou.data import Batch
@@ -24,10 +25,6 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
              device: Union[str, int, torch.device] = 'cpu'
              ) -> Union[dict, Batch, torch.Tensor]:
     """Return an object without np.ndarray."""
-    if isinstance(x, np.ndarray):
-        x = torch.from_numpy(x).to(device)
-        if dtype is not None:
-            x = x.type(dtype)
     if isinstance(x, torch.Tensor):
         if dtype is not None:
             x = x.type(dtype)
@@ -37,6 +34,16 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
             x[k] = to_torch(v, dtype, device)
     elif isinstance(x, Batch):
         x.to_torch(dtype, device)
+    elif isinstance(x, (np.number, np.bool_, Number)):
+        x = to_torch(np.asanyarray(x), dtype, device)
+    elif isinstance(x, list) and len(x) > 0 and \
+            isinstance(x[0], (np.number, np.bool_, Number)):
+        x = to_torch(np.asanyarray(x), dtype, device)
+    elif isinstance(x, np.ndarray) and \
+            isinstance(x.item(0), (np.number, np.bool_, Number)):
+        x = torch.from_numpy(x).to(device)
+        if dtype is not None:
+            x = x.type(dtype)
     return x
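
A short sketch of what the reworked helper now accepts. Before this change, scalars and plain lists fell through every branch and were returned unchanged; they are now routed through np.asanyarray and torch.from_numpy. The import path below is an assumption based on this diff; adjust it to wherever to_torch is exported in your checkout.

import numpy as np
import torch
from tianshou.data.utils import to_torch   # path assumed from this diff

assert to_torch(np.float64(1.0)).dtype == torch.float64
assert to_torch([float('nan'), 1.0]).dtype == torch.float64   # same case as the new test
assert to_torch(5, dtype=torch.int64).dtype == torch.int64
assert isinstance(to_torch(np.ones((2, 2)), device='cpu'), torch.Tensor)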