Allow two (same/different) Batch objs to be tested for equality (#1098)
Closes: https://github.com/thu-ml/tianshou/issues/1086 ### Api Extensions - Batch received new method: `to_numpy_`. #1098 - `to_dict` in Batch supports also non-recursive conversion. #1098 - Batch `__eq__` now implemented, semantic equality check of batches is now possible. #1098 ### Breaking Changes - The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
This commit is contained in:
parent
049907d9ab
commit
ca4f74f40e
@ -8,6 +8,9 @@
|
|||||||
- Trainers can control whether collectors should be reset prior to training. #1063
|
- Trainers can control whether collectors should be reset prior to training. #1063
|
||||||
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
|
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
|
||||||
- `SamplingConfig` supports `batch_size=None`. #1077
|
- `SamplingConfig` supports `batch_size=None`. #1077
|
||||||
|
- Batch received new method: `to_numpy_`. #1098
|
||||||
|
- `to_dict` in Batch supports also non-recursive conversion. #1098
|
||||||
|
- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098
|
||||||
|
|
||||||
### Internal Improvements
|
### Internal Improvements
|
||||||
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
|
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
|
||||||
@ -34,6 +37,7 @@ expicitly or pass `reset_before_collect=True` . #1063
|
|||||||
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
|
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
|
||||||
continuous and discrete cases. #1032
|
continuous and discrete cases. #1032
|
||||||
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
|
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
|
||||||
|
- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
|
||||||
|
|
||||||
### Tests
|
### Tests
|
||||||
- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081
|
- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081
|
||||||
|
@ -485,8 +485,8 @@ Miscellaneous Notes
|
|||||||
tensor([[0., 0., 0., 0.],
|
tensor([[0., 0., 0., 0.],
|
||||||
[0., 0., 0., 0.],
|
[0., 0., 0., 0.],
|
||||||
[0., 0., 0., 0.]])
|
[0., 0., 0., 0.]])
|
||||||
>>> # data.to_numpy is also available
|
>>> # data.to_numpy_ is also available
|
||||||
>>> data.to_numpy()
|
>>> data.to_numpy_()
|
||||||
|
|
||||||
.. raw:: html
|
.. raw:: html
|
||||||
|
|
||||||
|
@ -331,7 +331,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"batch_cat.to_numpy()\n",
|
"batch_cat.to_numpy_()\n",
|
||||||
"print(batch_cat)\n",
|
"print(batch_cat)\n",
|
||||||
"batch_cat.to_torch()\n",
|
"batch_cat.to_torch()\n",
|
||||||
"print(batch_cat)"
|
"print(batch_cat)"
|
||||||
|
35
poetry.lock
generated
35
poetry.lock
generated
@ -903,6 +903,24 @@ files = [
|
|||||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "deepdiff"
|
||||||
|
version = "7.0.1"
|
||||||
|
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
|
||||||
|
{file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
ordered-set = ">=4.1.0,<4.2.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
|
||||||
|
optimize = ["orjson"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "defusedxml"
|
name = "defusedxml"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
@ -3239,7 +3257,6 @@ optional = false
|
|||||||
python-versions = ">=3"
|
python-versions = ">=3"
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
|
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
|
||||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"},
|
|
||||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
|
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -3349,6 +3366,20 @@ numpy = ["numpy"]
|
|||||||
test = ["pytest", "pytest-cov", "pytest-xdist"]
|
test = ["pytest", "pytest-cov", "pytest-xdist"]
|
||||||
torch = ["torch"]
|
torch = ["torch"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ordered-set"
|
||||||
|
version = "4.1.0"
|
||||||
|
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
|
||||||
|
{file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["black", "mypy", "pytest"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "overrides"
|
name = "overrides"
|
||||||
version = "7.4.0"
|
version = "7.4.0"
|
||||||
@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b"
|
content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9"
|
||||||
|
@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.11"
|
python = "^3.11"
|
||||||
|
deepdiff = "^7.0.1"
|
||||||
gymnasium = "^0.28.0"
|
gymnasium = "^0.28.0"
|
||||||
h5py = "^3.9.0"
|
h5py = "^3.9.0"
|
||||||
numba = "^0.57.1"
|
numba = "^0.57.1"
|
||||||
|
@ -2,12 +2,13 @@ import copy
|
|||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
from itertools import starmap
|
from itertools import starmap
|
||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
from tianshou.data import Batch, to_numpy, to_torch
|
from tianshou.data import Batch, to_numpy, to_torch
|
||||||
|
|
||||||
@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
|
|||||||
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
||||||
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
||||||
batch.to_torch()
|
batch.to_torch()
|
||||||
batch.to_numpy()
|
batch.to_numpy_()
|
||||||
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
||||||
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
||||||
assert a_mem_addr_new == a_mem_addr_orig
|
assert a_mem_addr_new == a_mem_addr_orig
|
||||||
@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
|
|||||||
Batch()[0]
|
Batch()[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchEquality:
|
||||||
|
@staticmethod
|
||||||
|
def test_keys_different() -> None:
|
||||||
|
batch1 = Batch(a=[1, 2], b=[100, 50])
|
||||||
|
batch2 = Batch(b=[1, 2], c=[100, 50])
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_keys_missing() -> None:
|
||||||
|
batch1 = Batch(a=[1, 2], b=[2, 3, 4])
|
||||||
|
batch2 = Batch(a=[1, 2], b=[2, 3, 4])
|
||||||
|
batch2.pop("b")
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_types_keys_different() -> None:
|
||||||
|
batch1 = Batch(a=[1, 2, 3], b=[4, 5])
|
||||||
|
batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_array_types_different() -> None:
|
||||||
|
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
|
||||||
|
batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_nested_values_different() -> None:
|
||||||
|
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
|
||||||
|
batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_nested_shapes_different() -> None:
|
||||||
|
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
|
||||||
|
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
|
||||||
|
assert batch1 != batch2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_slice_equal() -> None:
|
||||||
|
batch1 = Batch(a=[1, 2, 3])
|
||||||
|
assert batch1[:2] == batch1[:2]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_slice_ellipsis_equal() -> None:
|
||||||
|
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
|
||||||
|
assert batch1[..., 1:] == batch1[..., 1:]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_empty_batches() -> None:
|
||||||
|
assert Batch() == Batch()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_different_order_keys() -> None:
|
||||||
|
assert Batch(a=1, b=2) == Batch(b=2, a=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_tuple_and_list_types() -> None:
|
||||||
|
assert Batch(a=(1, 2)) == Batch(a=[1, 2])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_subbatch_dict_and_batch_types() -> None:
|
||||||
|
assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchToDict:
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_empty_batch_no_recurse() -> None:
|
||||||
|
batch = Batch()
|
||||||
|
expected: dict[Any, Any] = {}
|
||||||
|
assert batch.to_dict() == expected
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_with_simple_values_recurse() -> None:
|
||||||
|
batch = Batch(a=1, b="two", c=np.array([3, 4]))
|
||||||
|
expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_simple() -> None:
|
||||||
|
batch = Batch(a=1, b="two")
|
||||||
|
expected = {"a": np.asanyarray(1), "b": "two"}
|
||||||
|
assert batch.to_dict() == expected
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_nested_batch_no_recurse() -> None:
|
||||||
|
nested_batch = Batch(c=3)
|
||||||
|
batch = Batch(a=1, b=nested_batch)
|
||||||
|
expected = {"a": np.asanyarray(1), "b": nested_batch}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=False), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_nested_batch_recurse() -> None:
|
||||||
|
nested_batch = Batch(c=3)
|
||||||
|
batch = Batch(a=1, b=nested_batch)
|
||||||
|
expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_multiple_nested_batch_recurse() -> None:
|
||||||
|
nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
|
||||||
|
batch = Batch(a=1, b=nested_batch)
|
||||||
|
expected = {
|
||||||
|
"a": np.asanyarray(1),
|
||||||
|
"b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
|
||||||
|
}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_array() -> None:
|
||||||
|
batch = Batch(a=np.array([1, 2, 3]))
|
||||||
|
expected = {"a": np.array([1, 2, 3])}
|
||||||
|
assert not DeepDiff(batch.to_dict(), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_nested_batch_with_array() -> None:
|
||||||
|
nested_batch = Batch(c=np.array([4, 5]))
|
||||||
|
batch = Batch(a=1, b=nested_batch)
|
||||||
|
expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_torch_tensor() -> None:
|
||||||
|
t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
|
||||||
|
batch = Batch(a=t1)
|
||||||
|
t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
|
||||||
|
expected = {"a": t2}
|
||||||
|
assert not DeepDiff(batch.to_dict(), expected)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_dict_nested_batch_with_torch_tensor() -> None:
|
||||||
|
nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
|
||||||
|
batch = Batch(a=1, b=nested_batch)
|
||||||
|
expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
|
||||||
|
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToNumpy:
|
||||||
|
"""Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_numpy() -> None:
|
||||||
|
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
|
||||||
|
new_batch: Batch = Batch.to_numpy(batch)
|
||||||
|
assert id(batch) != id(new_batch)
|
||||||
|
assert isinstance(batch.b, torch.Tensor)
|
||||||
|
assert isinstance(batch.c.d, torch.Tensor)
|
||||||
|
|
||||||
|
assert isinstance(new_batch.b, np.ndarray)
|
||||||
|
assert isinstance(new_batch.c.d, np.ndarray)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_to_numpy_() -> None:
|
||||||
|
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
|
||||||
|
id_batch = id(batch)
|
||||||
|
batch.to_numpy_()
|
||||||
|
assert id_batch == id(batch)
|
||||||
|
assert isinstance(batch.b, np.ndarray)
|
||||||
|
assert isinstance(batch.c.d, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_batch()
|
test_batch()
|
||||||
test_batch_over_batch()
|
test_batch_over_batch()
|
||||||
|
@ -17,6 +17,7 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
_SingleIndexType = slice | int | EllipsisType
|
_SingleIndexType = slice | int | EllipsisType
|
||||||
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
|
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
|
||||||
@ -268,7 +269,15 @@ class BatchProtocol(Protocol):
|
|||||||
def __iter__(self) -> Iterator[Self]:
|
def __iter__(self) -> Iterator[Self]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def to_numpy(self) -> None:
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_numpy(batch: TBatch) -> TBatch:
|
||||||
|
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def to_numpy_(self) -> None:
|
||||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -396,7 +405,7 @@ class BatchProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self, recurse: bool = True) -> dict[str, Any]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
||||||
@ -433,11 +442,11 @@ class Batch(BatchProtocol):
|
|||||||
# Feels like kwargs could be just merged into batch_dict in the beginning
|
# Feels like kwargs could be just merged into batch_dict in the beginning
|
||||||
self.__init__(kwargs, copy=copy) # type: ignore
|
self.__init__(kwargs, copy=copy) # type: ignore
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if isinstance(v, Batch):
|
if recursive and isinstance(v, Batch):
|
||||||
v = v.to_dict()
|
v = v.to_dict(recursive=recursive)
|
||||||
result[k] = v
|
result[k] = v
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -503,6 +512,17 @@ class Batch(BatchProtocol):
|
|||||||
return new_batch
|
return new_batch
|
||||||
raise IndexError("Cannot access item from empty Batch object.")
|
raise IndexError("Cannot access item from empty Batch object.")
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if not isinstance(other, self.__class__):
|
||||||
|
return False
|
||||||
|
|
||||||
|
this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
|
||||||
|
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
|
||||||
|
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
|
||||||
|
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
|
||||||
|
|
||||||
|
return not DeepDiff(this_dict, other_dict)
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Self]:
|
def __iter__(self) -> Iterator[Self]:
|
||||||
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
|
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
|
||||||
if len(self.__dict__) == 0:
|
if len(self.__dict__) == 0:
|
||||||
@ -602,12 +622,24 @@ class Batch(BatchProtocol):
|
|||||||
self_str = self.__class__.__name__ + "()"
|
self_str = self.__class__.__name__ + "()"
|
||||||
return self_str
|
return self_str
|
||||||
|
|
||||||
def to_numpy(self) -> None:
|
@staticmethod
|
||||||
|
def to_numpy(batch: TBatch) -> TBatch:
|
||||||
|
batch_dict = deepcopy(batch)
|
||||||
|
for batch_key, obj in batch_dict.items():
|
||||||
|
if isinstance(obj, torch.Tensor):
|
||||||
|
batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
|
||||||
|
elif isinstance(obj, Batch):
|
||||||
|
obj = Batch.to_numpy(obj)
|
||||||
|
batch_dict.__dict__[batch_key] = obj
|
||||||
|
|
||||||
|
return batch_dict
|
||||||
|
|
||||||
|
def to_numpy_(self) -> None:
|
||||||
for batch_key, obj in self.items():
|
for batch_key, obj in self.items():
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
self.__dict__[batch_key] = obj.detach().cpu().numpy()
|
self.__dict__[batch_key] = obj.detach().cpu().numpy()
|
||||||
elif isinstance(obj, Batch):
|
elif isinstance(obj, Batch):
|
||||||
obj.to_numpy()
|
obj.to_numpy_()
|
||||||
|
|
||||||
def to_torch(
|
def to_torch(
|
||||||
self,
|
self,
|
||||||
|
@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
|
|||||||
return np.array(None, dtype=object)
|
return np.array(None, dtype=object)
|
||||||
if isinstance(x, dict | Batch):
|
if isinstance(x, dict | Batch):
|
||||||
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
|
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
|
||||||
x.to_numpy()
|
x.to_numpy_()
|
||||||
return x
|
return x
|
||||||
if isinstance(x, list | tuple):
|
if isinstance(x, list | tuple):
|
||||||
return to_numpy(_parse_value(x))
|
return to_numpy(_parse_value(x))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user