Add non in-place version of Batch.to_torch (#1117)

Closes: https://github.com/aai-institute/tianshou/issues/1116

### API Extensions

- Batch received new method: `to_torch_`. #1117

### Breaking Changes

- The method `to_torch` in `data.utils.batch.Batch` is not in-place
anymore. Instead, a new method `to_torch_` does the conversion in-place.
#1117
This commit is contained in:
Daniel Plop 2024-04-17 22:07:24 +02:00 committed by GitHub
parent ca4f74f40e
commit 6935a111d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 8 deletions

View File

@ -475,12 +475,12 @@ Miscellaneous Notes
.. raw:: html .. raw:: html
<details> <details>
<summary>Batch.to_torch and Batch.to_numpy</summary> <summary>Batch.to_torch_ and Batch.to_numpy_</summary>
:: ::
>>> data = Batch(a=np.zeros((3, 4))) >>> data = Batch(a=np.zeros((3, 4)))
>>> data.to_torch(dtype=torch.float32, device='cpu') >>> data.to_torch_(dtype=torch.float32, device='cpu')
>>> print(data.a) >>> print(data.a)
tensor([[0., 0., 0., 0.], tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.], [0., 0., 0., 0.],

View File

@ -333,7 +333,7 @@
"source": [ "source": [
"batch_cat.to_numpy_()\n", "batch_cat.to_numpy_()\n",
"print(batch_cat)\n", "print(batch_cat)\n",
"batch_cat.to_torch()\n", "batch_cat.to_torch_()\n",
"print(batch_cat)" "print(batch_cat)"
] ]
}, },

View File

@ -379,7 +379,7 @@ def test_batch_over_batch_to_torch() -> None:
b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
) )
batch.b.__dict__["e"] = 1 # bypass the check batch.b.__dict__["e"] = 1 # bypass the check
batch.to_torch() batch.to_torch_()
assert isinstance(batch.a, torch.Tensor) assert isinstance(batch.a, torch.Tensor)
assert isinstance(batch.b.c, torch.Tensor) assert isinstance(batch.b.c, torch.Tensor)
assert isinstance(batch.b.d, torch.Tensor) assert isinstance(batch.b.d, torch.Tensor)
@ -391,7 +391,7 @@ def test_batch_over_batch_to_torch() -> None:
assert batch.b.e.dtype == torch.int32 assert batch.b.e.dtype == torch.int32
else: else:
assert batch.b.e.dtype == torch.int64 assert batch.b.e.dtype == torch.int64
batch.to_torch(dtype=torch.float32) batch.to_torch_(dtype=torch.float32)
assert batch.a.dtype == torch.float32 assert batch.a.dtype == torch.float32
assert batch.b.c.dtype == torch.float32 assert batch.b.c.dtype == torch.float32
assert batch.b.d.dtype == torch.float32 assert batch.b.d.dtype == torch.float32
@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
a_mem_addr_orig = batch.a.__array_interface__["data"][0] a_mem_addr_orig = batch.a.__array_interface__["data"][0]
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
batch.to_torch() batch.to_torch_()
batch.to_numpy_() batch.to_numpy_()
a_mem_addr_new = batch.a.__array_interface__["data"][0] a_mem_addr_new = batch.a.__array_interface__["data"][0]
c_mem_addr_new = batch.b.c.__array_interface__["data"][0] c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
@ -727,6 +727,30 @@ class TestToNumpy:
assert isinstance(batch.c.d, np.ndarray) assert isinstance(batch.c.d, np.ndarray)
class TestToTorch:
"""Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` ."""
@staticmethod
def test_to_torch() -> None:
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
new_batch: Batch = Batch.to_torch(batch)
assert id(batch) != id(new_batch)
assert isinstance(batch.b, np.ndarray)
assert isinstance(batch.c.d, np.ndarray)
assert isinstance(new_batch.b, torch.Tensor)
assert isinstance(new_batch.c.d, torch.Tensor)
@staticmethod
def test_to_torch_() -> None:
batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
id_batch = id(batch)
batch.to_torch_()
assert id_batch == id(batch)
assert isinstance(batch.b, torch.Tensor)
assert isinstance(batch.c.d, torch.Tensor)
if __name__ == "__main__": if __name__ == "__main__":
test_batch() test_batch()
test_batch_over_batch() test_batch_over_batch()

View File

@ -281,7 +281,16 @@ class BatchProtocol(Protocol):
"""Change all torch.Tensor to numpy.ndarray in-place.""" """Change all torch.Tensor to numpy.ndarray in-place."""
... ...
@staticmethod
def to_torch( def to_torch(
batch: TBatch,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> TBatch:
"""Change all numpy.ndarray to torch.Tensor and return a new Batch."""
...
def to_torch_(
self, self,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu", device: str | int | torch.device = "cpu",
@ -641,7 +650,18 @@ class Batch(BatchProtocol):
elif isinstance(obj, Batch): elif isinstance(obj, Batch):
obj.to_numpy_() obj.to_numpy_()
@staticmethod
def to_torch( def to_torch(
batch: TBatch,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> TBatch:
new_batch = Batch(batch, copy=True)
new_batch.to_torch_(dtype=dtype, device=device)
return new_batch # type: ignore[return-value]
def to_torch_(
self, self,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu", device: str | int | torch.device = "cpu",
@ -662,7 +682,7 @@ class Batch(BatchProtocol):
else: else:
self.__dict__[batch_key] = obj.to(device) self.__dict__[batch_key] = obj.to(device)
elif isinstance(obj, Batch): elif isinstance(obj, Batch):
obj.to_torch(dtype, device) obj.to_torch_(dtype, device)
else: else:
# ndarray or scalar # ndarray or scalar
if not isinstance(obj, np.ndarray): if not isinstance(obj, np.ndarray):

View File

@ -57,7 +57,7 @@ def to_torch(
return to_torch(np.asanyarray(x), dtype, device) return to_torch(np.asanyarray(x), dtype, device)
if isinstance(x, dict | Batch): if isinstance(x, dict | Batch):
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x) x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
x.to_torch(dtype, device) x.to_torch_(dtype, device)
return x return x
if isinstance(x, list | tuple): if isinstance(x, list | tuple):
return to_torch(_parse_value(x), dtype, device) return to_torch(_parse_value(x), dtype, device)