Add non in-place version of Batch.to_torch
(#1117)
Closes: https://github.com/aai-institute/tianshou/issues/1116 ### API Extensions - Batch received new method: `to_torch_`. #1117 ### Breaking Changes - The method `to_torch` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_torch_` does the conversion in-place. #1117
This commit is contained in:
parent
ca4f74f40e
commit
6935a111d9
@ -475,12 +475,12 @@ Miscellaneous Notes
|
||||
.. raw:: html
|
||||
|
||||
<details>
|
||||
<summary>Batch.to_torch and Batch.to_numpy</summary>
|
||||
<summary>Batch.to_torch_ and Batch.to_numpy_</summary>
|
||||
|
||||
::
|
||||
|
||||
>>> data = Batch(a=np.zeros((3, 4)))
|
||||
>>> data.to_torch(dtype=torch.float32, device='cpu')
|
||||
>>> data.to_torch_(dtype=torch.float32, device='cpu')
|
||||
>>> print(data.a)
|
||||
tensor([[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
|
@ -333,7 +333,7 @@
|
||||
"source": [
|
||||
"batch_cat.to_numpy_()\n",
|
||||
"print(batch_cat)\n",
|
||||
"batch_cat.to_torch()\n",
|
||||
"batch_cat.to_torch_()\n",
|
||||
"print(batch_cat)"
|
||||
]
|
||||
},
|
||||
|
@ -379,7 +379,7 @@ def test_batch_over_batch_to_torch() -> None:
|
||||
b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
|
||||
)
|
||||
batch.b.__dict__["e"] = 1 # bypass the check
|
||||
batch.to_torch()
|
||||
batch.to_torch_()
|
||||
assert isinstance(batch.a, torch.Tensor)
|
||||
assert isinstance(batch.b.c, torch.Tensor)
|
||||
assert isinstance(batch.b.d, torch.Tensor)
|
||||
@ -391,7 +391,7 @@ def test_batch_over_batch_to_torch() -> None:
|
||||
assert batch.b.e.dtype == torch.int32
|
||||
else:
|
||||
assert batch.b.e.dtype == torch.int64
|
||||
batch.to_torch(dtype=torch.float32)
|
||||
batch.to_torch_(dtype=torch.float32)
|
||||
assert batch.a.dtype == torch.float32
|
||||
assert batch.b.c.dtype == torch.float32
|
||||
assert batch.b.d.dtype == torch.float32
|
||||
@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
|
||||
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
||||
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
||||
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
||||
batch.to_torch()
|
||||
batch.to_torch_()
|
||||
batch.to_numpy_()
|
||||
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
||||
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
||||
@ -727,6 +727,30 @@ class TestToNumpy:
|
||||
assert isinstance(batch.c.d, np.ndarray)
|
||||
|
||||
|
||||
class TestToTorch:
|
||||
"""Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` ."""
|
||||
|
||||
@staticmethod
|
||||
def test_to_torch() -> None:
|
||||
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
|
||||
new_batch: Batch = Batch.to_torch(batch)
|
||||
assert id(batch) != id(new_batch)
|
||||
assert isinstance(batch.b, np.ndarray)
|
||||
assert isinstance(batch.c.d, np.ndarray)
|
||||
|
||||
assert isinstance(new_batch.b, torch.Tensor)
|
||||
assert isinstance(new_batch.c.d, torch.Tensor)
|
||||
|
||||
@staticmethod
|
||||
def test_to_torch_() -> None:
|
||||
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
|
||||
id_batch = id(batch)
|
||||
batch.to_torch_()
|
||||
assert id_batch == id(batch)
|
||||
assert isinstance(batch.b, torch.Tensor)
|
||||
assert isinstance(batch.c.d, torch.Tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_batch()
|
||||
test_batch_over_batch()
|
||||
|
@ -281,7 +281,16 @@ class BatchProtocol(Protocol):
|
||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def to_torch(
|
||||
batch: TBatch,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: str | int | torch.device = "cpu",
|
||||
) -> TBatch:
|
||||
"""Change all numpy.ndarray to torch.Tensor and return a new Batch."""
|
||||
...
|
||||
|
||||
def to_torch_(
|
||||
self,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: str | int | torch.device = "cpu",
|
||||
@ -641,7 +650,18 @@ class Batch(BatchProtocol):
|
||||
elif isinstance(obj, Batch):
|
||||
obj.to_numpy_()
|
||||
|
||||
@staticmethod
|
||||
def to_torch(
|
||||
batch: TBatch,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: str | int | torch.device = "cpu",
|
||||
) -> TBatch:
|
||||
new_batch = Batch(batch, copy=True)
|
||||
new_batch.to_torch_(dtype=dtype, device=device)
|
||||
|
||||
return new_batch # type: ignore[return-value]
|
||||
|
||||
def to_torch_(
|
||||
self,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: str | int | torch.device = "cpu",
|
||||
@ -662,7 +682,7 @@ class Batch(BatchProtocol):
|
||||
else:
|
||||
self.__dict__[batch_key] = obj.to(device)
|
||||
elif isinstance(obj, Batch):
|
||||
obj.to_torch(dtype, device)
|
||||
obj.to_torch_(dtype, device)
|
||||
else:
|
||||
# ndarray or scalar
|
||||
if not isinstance(obj, np.ndarray):
|
||||
|
@ -57,7 +57,7 @@ def to_torch(
|
||||
return to_torch(np.asanyarray(x), dtype, device)
|
||||
if isinstance(x, dict | Batch):
|
||||
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
|
||||
x.to_torch(dtype, device)
|
||||
x.to_torch_(dtype, device)
|
||||
return x
|
||||
if isinstance(x, list | tuple):
|
||||
return to_torch(_parse_value(x), dtype, device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user