Add non in-place version of Batch.to_torch
(#1117)
Closes: https://github.com/aai-institute/tianshou/issues/1116 ### API Extensions - Batch received new method: `to_torch_`. #1117 ### Breaking Changes - The method `to_torch` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_torch_` does the conversion in-place. #1117
This commit is contained in:
parent
ca4f74f40e
commit
6935a111d9
@ -475,12 +475,12 @@ Miscellaneous Notes
|
|||||||
.. raw:: html
|
.. raw:: html
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Batch.to_torch and Batch.to_numpy</summary>
|
<summary>Batch.to_torch_ and Batch.to_numpy_</summary>
|
||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
>>> data = Batch(a=np.zeros((3, 4)))
|
>>> data = Batch(a=np.zeros((3, 4)))
|
||||||
>>> data.to_torch(dtype=torch.float32, device='cpu')
|
>>> data.to_torch_(dtype=torch.float32, device='cpu')
|
||||||
>>> print(data.a)
|
>>> print(data.a)
|
||||||
tensor([[0., 0., 0., 0.],
|
tensor([[0., 0., 0., 0.],
|
||||||
[0., 0., 0., 0.],
|
[0., 0., 0., 0.],
|
||||||
|
@ -333,7 +333,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"batch_cat.to_numpy_()\n",
|
"batch_cat.to_numpy_()\n",
|
||||||
"print(batch_cat)\n",
|
"print(batch_cat)\n",
|
||||||
"batch_cat.to_torch()\n",
|
"batch_cat.to_torch_()\n",
|
||||||
"print(batch_cat)"
|
"print(batch_cat)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -379,7 +379,7 @@ def test_batch_over_batch_to_torch() -> None:
|
|||||||
b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
|
b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
|
||||||
)
|
)
|
||||||
batch.b.__dict__["e"] = 1 # bypass the check
|
batch.b.__dict__["e"] = 1 # bypass the check
|
||||||
batch.to_torch()
|
batch.to_torch_()
|
||||||
assert isinstance(batch.a, torch.Tensor)
|
assert isinstance(batch.a, torch.Tensor)
|
||||||
assert isinstance(batch.b.c, torch.Tensor)
|
assert isinstance(batch.b.c, torch.Tensor)
|
||||||
assert isinstance(batch.b.d, torch.Tensor)
|
assert isinstance(batch.b.d, torch.Tensor)
|
||||||
@ -391,7 +391,7 @@ def test_batch_over_batch_to_torch() -> None:
|
|||||||
assert batch.b.e.dtype == torch.int32
|
assert batch.b.e.dtype == torch.int32
|
||||||
else:
|
else:
|
||||||
assert batch.b.e.dtype == torch.int64
|
assert batch.b.e.dtype == torch.int64
|
||||||
batch.to_torch(dtype=torch.float32)
|
batch.to_torch_(dtype=torch.float32)
|
||||||
assert batch.a.dtype == torch.float32
|
assert batch.a.dtype == torch.float32
|
||||||
assert batch.b.c.dtype == torch.float32
|
assert batch.b.c.dtype == torch.float32
|
||||||
assert batch.b.d.dtype == torch.float32
|
assert batch.b.d.dtype == torch.float32
|
||||||
@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
|
|||||||
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
||||||
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
||||||
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
||||||
batch.to_torch()
|
batch.to_torch_()
|
||||||
batch.to_numpy_()
|
batch.to_numpy_()
|
||||||
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
||||||
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
||||||
@ -727,6 +727,30 @@ class TestToNumpy:
|
|||||||
assert isinstance(batch.c.d, np.ndarray)
|
assert isinstance(batch.c.d, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToTorch:
|
||||||
|
"""Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` ."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_torch() -> None:
|
||||||
|
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
|
||||||
|
new_batch: Batch = Batch.to_torch(batch)
|
||||||
|
assert id(batch) != id(new_batch)
|
||||||
|
assert isinstance(batch.b, np.ndarray)
|
||||||
|
assert isinstance(batch.c.d, np.ndarray)
|
||||||
|
|
||||||
|
assert isinstance(new_batch.b, torch.Tensor)
|
||||||
|
assert isinstance(new_batch.c.d, torch.Tensor)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_torch_() -> None:
|
||||||
|
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
|
||||||
|
id_batch = id(batch)
|
||||||
|
batch.to_torch_()
|
||||||
|
assert id_batch == id(batch)
|
||||||
|
assert isinstance(batch.b, torch.Tensor)
|
||||||
|
assert isinstance(batch.c.d, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_batch()
|
test_batch()
|
||||||
test_batch_over_batch()
|
test_batch_over_batch()
|
||||||
|
@ -281,7 +281,16 @@ class BatchProtocol(Protocol):
|
|||||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def to_torch(
|
def to_torch(
|
||||||
|
batch: TBatch,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
device: str | int | torch.device = "cpu",
|
||||||
|
) -> TBatch:
|
||||||
|
"""Change all numpy.ndarray to torch.Tensor and return a new Batch."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def to_torch_(
|
||||||
self,
|
self,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
device: str | int | torch.device = "cpu",
|
device: str | int | torch.device = "cpu",
|
||||||
@ -641,7 +650,18 @@ class Batch(BatchProtocol):
|
|||||||
elif isinstance(obj, Batch):
|
elif isinstance(obj, Batch):
|
||||||
obj.to_numpy_()
|
obj.to_numpy_()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def to_torch(
|
def to_torch(
|
||||||
|
batch: TBatch,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
device: str | int | torch.device = "cpu",
|
||||||
|
) -> TBatch:
|
||||||
|
new_batch = Batch(batch, copy=True)
|
||||||
|
new_batch.to_torch_(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return new_batch # type: ignore[return-value]
|
||||||
|
|
||||||
|
def to_torch_(
|
||||||
self,
|
self,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
device: str | int | torch.device = "cpu",
|
device: str | int | torch.device = "cpu",
|
||||||
@ -662,7 +682,7 @@ class Batch(BatchProtocol):
|
|||||||
else:
|
else:
|
||||||
self.__dict__[batch_key] = obj.to(device)
|
self.__dict__[batch_key] = obj.to(device)
|
||||||
elif isinstance(obj, Batch):
|
elif isinstance(obj, Batch):
|
||||||
obj.to_torch(dtype, device)
|
obj.to_torch_(dtype, device)
|
||||||
else:
|
else:
|
||||||
# ndarray or scalar
|
# ndarray or scalar
|
||||||
if not isinstance(obj, np.ndarray):
|
if not isinstance(obj, np.ndarray):
|
||||||
|
@ -57,7 +57,7 @@ def to_torch(
|
|||||||
return to_torch(np.asanyarray(x), dtype, device)
|
return to_torch(np.asanyarray(x), dtype, device)
|
||||||
if isinstance(x, dict | Batch):
|
if isinstance(x, dict | Batch):
|
||||||
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
|
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
|
||||||
x.to_torch(dtype, device)
|
x.to_torch_(dtype, device)
|
||||||
return x
|
return x
|
||||||
if isinstance(x, list | tuple):
|
if isinstance(x, list | tuple):
|
||||||
return to_torch(_parse_value(x), dtype, device)
|
return to_torch(_parse_value(x), dtype, device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user