Allow two (same/different) Batch objs to be tested for equality (#1098)

Closes: https://github.com/thu-ml/tianshou/issues/1086

### Api Extensions

- Batch received new method: `to_numpy_`. #1098
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch `__eq__` now implemented, semantic equality check of batches is
now possible. #1098

### Breaking Changes

- The method `to_numpy` in `data.utils.batch.Batch` is not in-place
anymore. Instead, a new method `to_numpy_` does the conversion in-place.
#1098
This commit is contained in:
Daniel Plop 2024-04-16 18:12:48 +02:00 committed by GitHub
parent 049907d9ab
commit ca4f74f40e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 245 additions and 15 deletions

View File

@ -8,6 +8,9 @@
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077
- Batch received new method: `to_numpy_`. #1098
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098
### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@ -34,6 +37,7 @@ expicitly or pass `reset_before_collect=True` . #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
### Tests
- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081

View File

@ -485,8 +485,8 @@ Miscellaneous Notes
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # data.to_numpy is also available
>>> data.to_numpy()
>>> # data.to_numpy_ is also available
>>> data.to_numpy_()
.. raw:: html

View File

@ -331,7 +331,7 @@
},
"outputs": [],
"source": [
"batch_cat.to_numpy()\n",
"batch_cat.to_numpy_()\n",
"print(batch_cat)\n",
"batch_cat.to_torch()\n",
"print(batch_cat)"

35
poetry.lock generated
View File

@ -903,6 +903,24 @@ files = [
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "deepdiff"
version = "7.0.1"
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
optional = false
python-versions = ">=3.8"
files = [
{file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
{file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
]
[package.dependencies]
ordered-set = ">=4.1.0,<4.2.0"
[package.extras]
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
optimize = ["orjson"]
[[package]]
name = "defusedxml"
version = "0.7.1"
@ -3239,7 +3257,6 @@ optional = false
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"},
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
]
@ -3349,6 +3366,20 @@ numpy = ["numpy"]
test = ["pytest", "pytest-cov", "pytest-xdist"]
torch = ["torch"]
[[package]]
name = "ordered-set"
version = "4.1.0"
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
optional = false
python-versions = ">=3.7"
files = [
{file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
{file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
]
[package.extras]
dev = ["black", "mypy", "pytest"]
[[package]]
name = "overrides"
version = "7.4.0"
@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b"
content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9"

View File

@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]
[tool.poetry.dependencies]
python = "^3.11"
deepdiff = "^7.0.1"
gymnasium = "^0.28.0"
h5py = "^3.9.0"
numba = "^0.57.1"

View File

@ -2,12 +2,13 @@ import copy
import pickle
import sys
from itertools import starmap
from typing import cast
from typing import Any, cast
import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff
from tianshou.data import Batch, to_numpy, to_torch
@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
batch.to_torch()
batch.to_numpy()
batch.to_numpy_()
a_mem_addr_new = batch.a.__array_interface__["data"][0]
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
assert a_mem_addr_new == a_mem_addr_orig
@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
Batch()[0]
class TestBatchEquality:
@staticmethod
def test_keys_different() -> None:
batch1 = Batch(a=[1, 2], b=[100, 50])
batch2 = Batch(b=[1, 2], c=[100, 50])
assert batch1 != batch2
@staticmethod
def test_keys_missing() -> None:
batch1 = Batch(a=[1, 2], b=[2, 3, 4])
batch2 = Batch(a=[1, 2], b=[2, 3, 4])
batch2.pop("b")
assert batch1 != batch2
@staticmethod
def test_types_keys_different() -> None:
batch1 = Batch(a=[1, 2, 3], b=[4, 5])
batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
assert batch1 != batch2
@staticmethod
def test_array_types_different() -> None:
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
assert batch1 != batch2
@staticmethod
def test_nested_values_different() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
assert batch1 != batch2
@staticmethod
def test_nested_shapes_different() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
assert batch1 != batch2
@staticmethod
def test_slice_equal() -> None:
batch1 = Batch(a=[1, 2, 3])
assert batch1[:2] == batch1[:2]
@staticmethod
def test_slice_ellipsis_equal() -> None:
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
assert batch1[..., 1:] == batch1[..., 1:]
@staticmethod
def test_empty_batches() -> None:
assert Batch() == Batch()
@staticmethod
def test_different_order_keys() -> None:
assert Batch(a=1, b=2) == Batch(b=2, a=1)
@staticmethod
def test_tuple_and_list_types() -> None:
assert Batch(a=(1, 2)) == Batch(a=[1, 2])
@staticmethod
def test_subbatch_dict_and_batch_types() -> None:
assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))
class TestBatchToDict:
@staticmethod
def test_to_dict_empty_batch_no_recurse() -> None:
batch = Batch()
expected: dict[Any, Any] = {}
assert batch.to_dict() == expected
@staticmethod
def test_to_dict_with_simple_values_recurse() -> None:
batch = Batch(a=1, b="two", c=np.array([3, 4]))
expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
assert not DeepDiff(batch.to_dict(recursive=True), expected)
@staticmethod
def test_to_dict_simple() -> None:
batch = Batch(a=1, b="two")
expected = {"a": np.asanyarray(1), "b": "two"}
assert batch.to_dict() == expected
@staticmethod
def test_to_dict_nested_batch_no_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": nested_batch}
assert not DeepDiff(batch.to_dict(recursive=False), expected)
@staticmethod
def test_to_dict_nested_batch_recurse() -> None:
nested_batch = Batch(c=3)
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)
@staticmethod
def test_to_dict_multiple_nested_batch_recurse() -> None:
nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
batch = Batch(a=1, b=nested_batch)
expected = {
"a": np.asanyarray(1),
"b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
}
assert not DeepDiff(batch.to_dict(recursive=True), expected)
@staticmethod
def test_to_dict_array() -> None:
batch = Batch(a=np.array([1, 2, 3]))
expected = {"a": np.array([1, 2, 3])}
assert not DeepDiff(batch.to_dict(), expected)
@staticmethod
def test_to_dict_nested_batch_with_array() -> None:
nested_batch = Batch(c=np.array([4, 5]))
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)
@staticmethod
def test_to_dict_torch_tensor() -> None:
t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
batch = Batch(a=t1)
t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
expected = {"a": t2}
assert not DeepDiff(batch.to_dict(), expected)
@staticmethod
def test_to_dict_nested_batch_with_torch_tensor() -> None:
nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
batch = Batch(a=1, b=nested_batch)
expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
assert not DeepDiff(batch.to_dict(recursive=True), expected)
class TestToNumpy:
"""Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""
@staticmethod
def test_to_numpy() -> None:
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
new_batch: Batch = Batch.to_numpy(batch)
assert id(batch) != id(new_batch)
assert isinstance(batch.b, torch.Tensor)
assert isinstance(batch.c.d, torch.Tensor)
assert isinstance(new_batch.b, np.ndarray)
assert isinstance(new_batch.c.d, np.ndarray)
@staticmethod
def test_to_numpy_() -> None:
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
id_batch = id(batch)
batch.to_numpy_()
assert id_batch == id(batch)
assert isinstance(batch.b, np.ndarray)
assert isinstance(batch.c.d, np.ndarray)
if __name__ == "__main__":
test_batch()
test_batch_over_batch()

View File

@ -17,6 +17,7 @@ from typing import (
import numpy as np
import torch
from deepdiff import DeepDiff
_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
@ -268,7 +269,15 @@ class BatchProtocol(Protocol):
def __iter__(self) -> Iterator[Self]:
...
def to_numpy(self) -> None:
def __eq__(self, other: Any) -> bool:
...
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
...
def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...
@ -396,7 +405,7 @@ class BatchProtocol(Protocol):
"""
...
def to_dict(self) -> dict[str, Any]:
def to_dict(self, recurse: bool = True) -> dict[str, Any]:
...
def to_list_of_dicts(self) -> list[dict[str, Any]]:
@ -433,11 +442,11 @@ class Batch(BatchProtocol):
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore
def to_dict(self) -> dict[str, Any]:
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if isinstance(v, Batch):
v = v.to_dict()
if recursive and isinstance(v, Batch):
v = v.to_dict(recursive=recursive)
result[k] = v
return result
@ -503,6 +512,17 @@ class Batch(BatchProtocol):
return new_batch
raise IndexError("Cannot access item from empty Batch object.")
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
return not DeepDiff(this_dict, other_dict)
def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
@ -602,12 +622,24 @@ class Batch(BatchProtocol):
self_str = self.__class__.__name__ + "()"
return self_str
def to_numpy(self) -> None:
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
batch_dict = deepcopy(batch)
for batch_key, obj in batch_dict.items():
if isinstance(obj, torch.Tensor):
batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj = Batch.to_numpy(obj)
batch_dict.__dict__[batch_key] = obj
return batch_dict
def to_numpy_(self) -> None:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj.to_numpy()
obj.to_numpy_()
def to_torch(
self,

View File

@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
return np.array(None, dtype=object)
if isinstance(x, dict | Batch):
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
x.to_numpy()
x.to_numpy_()
return x
if isinstance(x, list | tuple):
return to_numpy(_parse_value(x))