Fix to_torch converters (#111)
* Fix to_torch converters. * to_torch now convert any object Torch Tensor-compatible. * Fix linter. * Fix Batch to_torch to convert any Torch Tensor-compatible data. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
8913bf36b1
commit
69caf89908
@ -149,9 +149,9 @@ def test_batch_cat_and_stack():
|
|||||||
|
|
||||||
def test_batch_over_batch_to_torch():
|
def test_batch_over_batch_to_torch():
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
a=np.ones((1,), dtype=np.float64),
|
a=np.float64(1.0),
|
||||||
b=Batch(
|
b=Batch(
|
||||||
c=np.ones((1,), dtype=np.float64),
|
c=np.ones((1,), dtype=np.float32),
|
||||||
d=torch.ones((1,), dtype=torch.float64)
|
d=torch.ones((1,), dtype=torch.float64)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -160,7 +160,7 @@ def test_batch_over_batch_to_torch():
|
|||||||
assert isinstance(batch.b.c, torch.Tensor)
|
assert isinstance(batch.b.c, torch.Tensor)
|
||||||
assert isinstance(batch.b.d, torch.Tensor)
|
assert isinstance(batch.b.d, torch.Tensor)
|
||||||
assert batch.a.dtype == torch.float64
|
assert batch.a.dtype == torch.float64
|
||||||
assert batch.b.c.dtype == torch.float64
|
assert batch.b.c.dtype == torch.float32
|
||||||
assert batch.b.d.dtype == torch.float64
|
assert batch.b.d.dtype == torch.float64
|
||||||
batch.to_torch(dtype=torch.float32)
|
batch.to_torch(dtype=torch.float32)
|
||||||
assert batch.a.dtype == torch.float32
|
assert batch.a.dtype == torch.float32
|
||||||
@ -170,9 +170,9 @@ def test_batch_over_batch_to_torch():
|
|||||||
|
|
||||||
def test_utils_to_torch():
|
def test_utils_to_torch():
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
a=np.ones((1,), dtype=np.float64),
|
a=np.float64(1.0),
|
||||||
b=Batch(
|
b=Batch(
|
||||||
c=np.ones((1,), dtype=np.float64),
|
c=np.ones((1,), dtype=np.float32),
|
||||||
d=torch.ones((1,), dtype=torch.float64)
|
d=torch.ones((1,), dtype=torch.float64)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -184,6 +184,8 @@ def test_utils_to_torch():
|
|||||||
assert batch_torch_float.a.dtype == torch.float32
|
assert batch_torch_float.a.dtype == torch.float32
|
||||||
assert batch_torch_float.b.c.dtype == torch.float32
|
assert batch_torch_float.b.c.dtype == torch.float32
|
||||||
assert batch_torch_float.b.d.dtype == torch.float32
|
assert batch_torch_float.b.d.dtype == torch.float32
|
||||||
|
array_list = [float('nan'), 1.0]
|
||||||
|
assert to_torch(array_list).dtype == torch.float64
|
||||||
|
|
||||||
|
|
||||||
def test_batch_pickle():
|
def test_batch_pickle():
|
||||||
|
@ -446,12 +446,14 @@ class Batch:
|
|||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
if isinstance(v, (np.generic, np.ndarray)):
|
if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
|
||||||
|
if isinstance(v, (np.number, np.bool_, Number)):
|
||||||
|
v = np.asanyarray(v)
|
||||||
v = torch.from_numpy(v).to(device)
|
v = torch.from_numpy(v).to(device)
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
v = v.type(dtype)
|
v = v.type(dtype)
|
||||||
self.__dict__[k] = v
|
self.__dict__[k] = v
|
||||||
if isinstance(v, torch.Tensor):
|
elif isinstance(v, torch.Tensor):
|
||||||
if dtype is not None and v.dtype != dtype:
|
if dtype is not None and v.dtype != dtype:
|
||||||
must_update_tensor = True
|
must_update_tensor = True
|
||||||
elif v.device.type != device.type:
|
elif v.device.type != device.type:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numbers import Number
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
|
|
||||||
from tianshou.data import Batch
|
from tianshou.data import Batch
|
||||||
@ -24,10 +25,6 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
|||||||
device: Union[str, int, torch.device] = 'cpu'
|
device: Union[str, int, torch.device] = 'cpu'
|
||||||
) -> Union[dict, Batch, torch.Tensor]:
|
) -> Union[dict, Batch, torch.Tensor]:
|
||||||
"""Return an object without np.ndarray."""
|
"""Return an object without np.ndarray."""
|
||||||
if isinstance(x, np.ndarray):
|
|
||||||
x = torch.from_numpy(x).to(device)
|
|
||||||
if dtype is not None:
|
|
||||||
x = x.type(dtype)
|
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
x = x.type(dtype)
|
x = x.type(dtype)
|
||||||
@ -37,6 +34,16 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
|||||||
x[k] = to_torch(v, dtype, device)
|
x[k] = to_torch(v, dtype, device)
|
||||||
elif isinstance(x, Batch):
|
elif isinstance(x, Batch):
|
||||||
x.to_torch(dtype, device)
|
x.to_torch(dtype, device)
|
||||||
|
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||||
|
x = to_torch(np.asanyarray(x), dtype, device)
|
||||||
|
elif isinstance(x, list) and len(x) > 0 and \
|
||||||
|
isinstance(x[0], (np.number, np.bool_, Number)):
|
||||||
|
x = to_torch(np.asanyarray(x), dtype, device)
|
||||||
|
elif isinstance(x, np.ndarray) and \
|
||||||
|
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
||||||
|
x = torch.from_numpy(x).to(device)
|
||||||
|
if dtype is not None:
|
||||||
|
x = x.type(dtype)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user