Fix to_torch converters (#111)
* Fix to_torch converters. * to_torch now convert any object Torch Tensor-compatible. * Fix linter. * Fix Batch to_torch to convert any Torch Tensor-compatible data. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
8913bf36b1
commit
69caf89908
@ -149,9 +149,9 @@ def test_batch_cat_and_stack():
|
||||
|
||||
def test_batch_over_batch_to_torch():
|
||||
batch = Batch(
|
||||
a=np.ones((1,), dtype=np.float64),
|
||||
a=np.float64(1.0),
|
||||
b=Batch(
|
||||
c=np.ones((1,), dtype=np.float64),
|
||||
c=np.ones((1,), dtype=np.float32),
|
||||
d=torch.ones((1,), dtype=torch.float64)
|
||||
)
|
||||
)
|
||||
@ -160,7 +160,7 @@ def test_batch_over_batch_to_torch():
|
||||
assert isinstance(batch.b.c, torch.Tensor)
|
||||
assert isinstance(batch.b.d, torch.Tensor)
|
||||
assert batch.a.dtype == torch.float64
|
||||
assert batch.b.c.dtype == torch.float64
|
||||
assert batch.b.c.dtype == torch.float32
|
||||
assert batch.b.d.dtype == torch.float64
|
||||
batch.to_torch(dtype=torch.float32)
|
||||
assert batch.a.dtype == torch.float32
|
||||
@ -170,9 +170,9 @@ def test_batch_over_batch_to_torch():
|
||||
|
||||
def test_utils_to_torch():
|
||||
batch = Batch(
|
||||
a=np.ones((1,), dtype=np.float64),
|
||||
a=np.float64(1.0),
|
||||
b=Batch(
|
||||
c=np.ones((1,), dtype=np.float64),
|
||||
c=np.ones((1,), dtype=np.float32),
|
||||
d=torch.ones((1,), dtype=torch.float64)
|
||||
)
|
||||
)
|
||||
@ -184,6 +184,8 @@ def test_utils_to_torch():
|
||||
assert batch_torch_float.a.dtype == torch.float32
|
||||
assert batch_torch_float.b.c.dtype == torch.float32
|
||||
assert batch_torch_float.b.d.dtype == torch.float32
|
||||
array_list = [float('nan'), 1.0]
|
||||
assert to_torch(array_list).dtype == torch.float64
|
||||
|
||||
|
||||
def test_batch_pickle():
|
||||
|
@ -446,12 +446,14 @@ class Batch:
|
||||
device = torch.device(device)
|
||||
|
||||
for k, v in self.items():
|
||||
if isinstance(v, (np.generic, np.ndarray)):
|
||||
if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
|
||||
if isinstance(v, (np.number, np.bool_, Number)):
|
||||
v = np.asanyarray(v)
|
||||
v = torch.from_numpy(v).to(device)
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v
|
||||
if isinstance(v, torch.Tensor):
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if dtype is not None and v.dtype != dtype:
|
||||
must_update_tensor = True
|
||||
elif v.device.type != device.type:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
@ -24,10 +25,6 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
||||
device: Union[str, int, torch.device] = 'cpu'
|
||||
) -> Union[dict, Batch, torch.Tensor]:
|
||||
"""Return an object without np.ndarray."""
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
if isinstance(x, torch.Tensor):
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
@ -37,6 +34,16 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
||||
x[k] = to_torch(v, dtype, device)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_torch(dtype, device)
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, list) and len(x) > 0 and \
|
||||
isinstance(x[0], (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, np.ndarray) and \
|
||||
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user