diff --git a/CHANGELOG.md b/CHANGELOG.md index 126f81a..24c72ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 - `SamplingConfig` supports `batch_size=None`. #1077 +- Batch received new method: `to_numpy_`. #1098 +- `to_dict` in Batch supports also non-recursive conversion. #1098 +- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -34,6 +37,7 @@ expicitly or pass `reset_before_collect=True` . #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 +- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst index 71f82f8..46fa86b 100644 --- a/docs/01_tutorials/03_batch.rst +++ b/docs/01_tutorials/03_batch.rst @@ -485,8 +485,8 @@ Miscellaneous Notes tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) - >>> # data.to_numpy is also available - >>> data.to_numpy() + >>> # data.to_numpy_ is also available + >>> data.to_numpy_() .. raw:: html diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb index 9c80349..54008ee 100644 --- a/docs/02_notebooks/L1_Batch.ipynb +++ b/docs/02_notebooks/L1_Batch.ipynb @@ -331,7 +331,7 @@ }, "outputs": [], "source": [ - "batch_cat.to_numpy()\n", + "batch_cat.to_numpy_()\n", "print(batch_cat)\n", "batch_cat.to_torch()\n", "print(batch_cat)" diff --git a/poetry.lock b/poetry.lock index a6fbf32..56230ab 100644 --- a/poetry.lock +++ b/poetry.lock @@ -903,6 +903,24 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepdiff" +version = "7.0.1" +description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." +optional = false +python-versions = ">=3.8" +files = [ + {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, + {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, +] + +[package.dependencies] +ordered-set = ">=4.1.0,<4.2.0" + +[package.extras] +cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"] +optimize = ["orjson"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -3239,7 +3257,6 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] @@ -3349,6 +3366,20 @@ numpy = ["numpy"] test = ["pytest", "pytest-cov", "pytest-xdist"] torch = ["torch"] +[[package]] +name = "ordered-set" +version = "4.1.0" +description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, + {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, +] + +[package.extras] +dev = ["black", "mypy", "pytest"] + [[package]] name = "overrides" version = "7.4.0" @@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b" +content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9" diff --git a/pyproject.toml b/pyproject.toml index 813cbd3..d9ea48b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"] [tool.poetry.dependencies] python = "^3.11" +deepdiff = "^7.0.1" gymnasium = "^0.28.0" h5py = "^3.9.0" numba = "^0.57.1" diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 801ab44..82ff4a3 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -2,12 +2,13 @@ import copy import pickle import sys from itertools import starmap -from typing import cast +from typing import Any, cast import networkx as nx import numpy as np import pytest import torch +from deepdiff import DeepDiff from tianshou.data import Batch, to_numpy, to_torch @@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None: a_mem_addr_orig = batch.a.__array_interface__["data"][0] c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] batch.to_torch() - batch.to_numpy() + batch.to_numpy_() a_mem_addr_new = batch.a.__array_interface__["data"][0] c_mem_addr_new = batch.b.c.__array_interface__["data"][0] assert a_mem_addr_new == a_mem_addr_orig @@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None: Batch()[0] +class TestBatchEquality: + @staticmethod + def test_keys_different() -> None: + batch1 = Batch(a=[1, 2], b=[100, 50]) + batch2 = Batch(b=[1, 2], c=[100, 50]) + assert batch1 != batch2 + + @staticmethod + def test_keys_missing() -> None: + batch1 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2.pop("b") + assert batch1 != batch2 + + @staticmethod + def test_types_keys_different() -> None: + batch1 = Batch(a=[1, 2, 3], b=[4, 5]) + batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5])) + assert batch1 != batch2 + + @staticmethod + def test_array_types_different() -> None: + batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5])) + batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5])) + assert batch1 != batch2 + + @staticmethod + def test_nested_values_different() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5]) + assert batch1 != batch2 + + @staticmethod + def test_nested_shapes_different() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) + assert batch1 != batch2 + + @staticmethod + def test_slice_equal() -> None: + batch1 = Batch(a=[1, 2, 3]) + assert batch1[:2] == batch1[:2] + + @staticmethod + def test_slice_ellipsis_equal() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) + assert batch1[..., 1:] == batch1[..., 1:] + + @staticmethod + def test_empty_batches() -> None: + assert Batch() == Batch() + + @staticmethod + def test_different_order_keys() -> None: + assert Batch(a=1, b=2) == Batch(b=2, a=1) + + @staticmethod + def test_tuple_and_list_types() -> None: + assert Batch(a=(1, 2)) == Batch(a=[1, 2]) + + @staticmethod + def test_subbatch_dict_and_batch_types() -> None: + assert Batch(a={"x": 1}) == Batch(a=Batch(x=1)) + + +class TestBatchToDict: + @staticmethod + def test_to_dict_empty_batch_no_recurse() -> None: + batch = Batch() + expected: dict[Any, Any] = {} + assert batch.to_dict() == expected + + @staticmethod + def test_to_dict_with_simple_values_recurse() -> None: + batch = Batch(a=1, b="two", c=np.array([3, 4])) + expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])} + assert not DeepDiff(batch.to_dict(recursive=True), expected) + + @staticmethod + def test_to_dict_simple() -> None: + batch = Batch(a=1, b="two") + expected = {"a": np.asanyarray(1), "b": "two"} + assert batch.to_dict() == expected + + @staticmethod + def test_to_dict_nested_batch_no_recurse() -> None: + nested_batch = Batch(c=3) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": nested_batch} + assert not DeepDiff(batch.to_dict(recursive=False), expected) + + @staticmethod + def test_to_dict_nested_batch_recurse() -> None: + nested_batch = Batch(c=3) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}} + assert not DeepDiff(batch.to_dict(recursive=True), expected) + + @staticmethod + def test_to_dict_multiple_nested_batch_recurse() -> None: + nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300]) + batch = Batch(a=1, b=nested_batch) + expected = { + "a": np.asanyarray(1), + "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])}, + } + assert not DeepDiff(batch.to_dict(recursive=True), expected) + + @staticmethod + def test_to_dict_array() -> None: + batch = Batch(a=np.array([1, 2, 3])) + expected = {"a": np.array([1, 2, 3])} + assert not DeepDiff(batch.to_dict(), expected) + + @staticmethod + def test_to_dict_nested_batch_with_array() -> None: + nested_batch = Batch(c=np.array([4, 5])) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}} + assert not DeepDiff(batch.to_dict(recursive=True), expected) + + @staticmethod + def test_to_dict_torch_tensor() -> None: + t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() + batch = Batch(a=t1) + t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() + expected = {"a": t2} + assert not DeepDiff(batch.to_dict(), expected) + + @staticmethod + def test_to_dict_nested_batch_with_torch_tensor() -> None: + nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy()) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}} + assert not DeepDiff(batch.to_dict(recursive=True), expected) + + +class TestToNumpy: + """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` .""" + + @staticmethod + def test_to_numpy() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + new_batch: Batch = Batch.to_numpy(batch) + assert id(batch) != id(new_batch) + assert isinstance(batch.b, torch.Tensor) + assert isinstance(batch.c.d, torch.Tensor) + + assert isinstance(new_batch.b, np.ndarray) + assert isinstance(new_batch.c.d, np.ndarray) + + @staticmethod + def test_to_numpy_() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + id_batch = id(batch) + batch.to_numpy_() + assert id_batch == id(batch) + assert isinstance(batch.b, np.ndarray) + assert isinstance(batch.c.d, np.ndarray) + + if __name__ == "__main__": test_batch() test_batch_over_batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1136beb..d911788 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -17,6 +17,7 @@ from typing import ( import numpy as np import torch +from deepdiff import DeepDiff _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] @@ -268,7 +269,15 @@ class BatchProtocol(Protocol): def __iter__(self) -> Iterator[Self]: ... - def to_numpy(self) -> None: + def __eq__(self, other: Any) -> bool: + ... + + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" + ... + + def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -396,7 +405,7 @@ class BatchProtocol(Protocol): """ ... - def to_dict(self) -> dict[str, Any]: + def to_dict(self, recurse: bool = True) -> dict[str, Any]: ... def to_list_of_dicts(self) -> list[dict[str, Any]]: @@ -433,11 +442,11 @@ class Batch(BatchProtocol): # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore - def to_dict(self) -> dict[str, Any]: + def to_dict(self, recursive: bool = True) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): - if isinstance(v, Batch): - v = v.to_dict() + if recursive and isinstance(v, Batch): + v = v.to_dict(recursive=recursive) result[k] = v return result @@ -503,6 +512,17 @@ class Batch(BatchProtocol): return new_batch raise IndexError("Cannot access item from empty Batch object.") + def __eq__(self, other: Any) -> bool: + if not isinstance(other, self.__class__): + return False + + this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) + other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) + this_dict = this_batch_no_torch_tensor.to_dict(recursive=True) + other_dict = other_batch_no_torch_tensor.to_dict(recursive=True) + + return not DeepDiff(this_dict, other_dict) + def __iter__(self) -> Iterator[Self]: # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea if len(self.__dict__) == 0: @@ -602,12 +622,24 @@ class Batch(BatchProtocol): self_str = self.__class__.__name__ + "()" return self_str - def to_numpy(self) -> None: + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + batch_dict = deepcopy(batch) + for batch_key, obj in batch_dict.items(): + if isinstance(obj, torch.Tensor): + batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj = Batch.to_numpy(obj) + batch_dict.__dict__[batch_key] = obj + + return batch_dict + + def to_numpy_(self) -> None: for batch_key, obj in self.items(): if isinstance(obj, torch.Tensor): self.__dict__[batch_key] = obj.detach().cpu().numpy() elif isinstance(obj, Batch): - obj.to_numpy() + obj.to_numpy_() def to_torch( self, diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 2df462d..7edf3ff 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray: return np.array(None, dtype=object) if isinstance(x, dict | Batch): x = Batch(x) if isinstance(x, dict) else deepcopy(x) - x.to_numpy() + x.to_numpy_() return x if isinstance(x, list | tuple): return to_numpy(_parse_value(x))