Allow two (same/different) Batch objs to be tested for equality (#1098)
Closes: https://github.com/thu-ml/tianshou/issues/1086 ### Api Extensions - Batch received new method: `to_numpy_`. #1098 - `to_dict` in Batch supports also non-recursive conversion. #1098 - Batch `__eq__` now implemented, semantic equality check of batches is now possible. #1098 ### Breaking Changes - The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
This commit is contained in:
parent
049907d9ab
commit
ca4f74f40e
@ -8,6 +8,9 @@
|
||||
- Trainers can control whether collectors should be reset prior to training. #1063
|
||||
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
|
||||
- `SamplingConfig` supports `batch_size=None`. #1077
|
||||
- Batch received new method: `to_numpy_`. #1098
|
||||
- `to_dict` in Batch supports also non-recursive conversion. #1098
|
||||
- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098
|
||||
|
||||
### Internal Improvements
|
||||
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
|
||||
@ -34,6 +37,7 @@ expicitly or pass `reset_before_collect=True` . #1063
|
||||
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
|
||||
continuous and discrete cases. #1032
|
||||
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
|
||||
- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
|
||||
|
||||
### Tests
|
||||
- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081
|
||||
|
@ -485,8 +485,8 @@ Miscellaneous Notes
|
||||
tensor([[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.]])
|
||||
>>> # data.to_numpy is also available
|
||||
>>> data.to_numpy()
|
||||
>>> # data.to_numpy_ is also available
|
||||
>>> data.to_numpy_()
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
@ -331,7 +331,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_cat.to_numpy()\n",
|
||||
"batch_cat.to_numpy_()\n",
|
||||
"print(batch_cat)\n",
|
||||
"batch_cat.to_torch()\n",
|
||||
"print(batch_cat)"
|
||||
|
35
poetry.lock
generated
35
poetry.lock
generated
@ -903,6 +903,24 @@ files = [
|
||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deepdiff"
|
||||
version = "7.0.1"
|
||||
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
|
||||
{file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
ordered-set = ">=4.1.0,<4.2.0"
|
||||
|
||||
[package.extras]
|
||||
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
|
||||
optimize = ["orjson"]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
@ -3239,7 +3257,6 @@ optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
|
||||
]
|
||||
|
||||
@ -3349,6 +3366,20 @@ numpy = ["numpy"]
|
||||
test = ["pytest", "pytest-cov", "pytest-xdist"]
|
||||
torch = ["torch"]
|
||||
|
||||
[[package]]
|
||||
name = "ordered-set"
|
||||
version = "4.1.0"
|
||||
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
|
||||
{file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "mypy", "pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "overrides"
|
||||
version = "7.4.0"
|
||||
@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b"
|
||||
content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9"
|
||||
|
@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
deepdiff = "^7.0.1"
|
||||
gymnasium = "^0.28.0"
|
||||
h5py = "^3.9.0"
|
||||
numba = "^0.57.1"
|
||||
|
@ -2,12 +2,13 @@ import copy
|
||||
import pickle
|
||||
import sys
|
||||
from itertools import starmap
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from tianshou.data import Batch, to_numpy, to_torch
|
||||
|
||||
@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
|
||||
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
|
||||
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
|
||||
batch.to_torch()
|
||||
batch.to_numpy()
|
||||
batch.to_numpy_()
|
||||
a_mem_addr_new = batch.a.__array_interface__["data"][0]
|
||||
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
|
||||
assert a_mem_addr_new == a_mem_addr_orig
|
||||
@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
|
||||
Batch()[0]
|
||||
|
||||
|
||||
class TestBatchEquality:
|
||||
@staticmethod
|
||||
def test_keys_different() -> None:
|
||||
batch1 = Batch(a=[1, 2], b=[100, 50])
|
||||
batch2 = Batch(b=[1, 2], c=[100, 50])
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_keys_missing() -> None:
|
||||
batch1 = Batch(a=[1, 2], b=[2, 3, 4])
|
||||
batch2 = Batch(a=[1, 2], b=[2, 3, 4])
|
||||
batch2.pop("b")
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_types_keys_different() -> None:
|
||||
batch1 = Batch(a=[1, 2, 3], b=[4, 5])
|
||||
batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_array_types_different() -> None:
|
||||
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
|
||||
batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_nested_values_different() -> None:
|
||||
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
|
||||
batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_nested_shapes_different() -> None:
|
||||
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
|
||||
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
|
||||
assert batch1 != batch2
|
||||
|
||||
@staticmethod
|
||||
def test_slice_equal() -> None:
|
||||
batch1 = Batch(a=[1, 2, 3])
|
||||
assert batch1[:2] == batch1[:2]
|
||||
|
||||
@staticmethod
|
||||
def test_slice_ellipsis_equal() -> None:
|
||||
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
|
||||
assert batch1[..., 1:] == batch1[..., 1:]
|
||||
|
||||
@staticmethod
|
||||
def test_empty_batches() -> None:
|
||||
assert Batch() == Batch()
|
||||
|
||||
@staticmethod
|
||||
def test_different_order_keys() -> None:
|
||||
assert Batch(a=1, b=2) == Batch(b=2, a=1)
|
||||
|
||||
@staticmethod
|
||||
def test_tuple_and_list_types() -> None:
|
||||
assert Batch(a=(1, 2)) == Batch(a=[1, 2])
|
||||
|
||||
@staticmethod
|
||||
def test_subbatch_dict_and_batch_types() -> None:
|
||||
assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))
|
||||
|
||||
|
||||
class TestBatchToDict:
|
||||
@staticmethod
|
||||
def test_to_dict_empty_batch_no_recurse() -> None:
|
||||
batch = Batch()
|
||||
expected: dict[Any, Any] = {}
|
||||
assert batch.to_dict() == expected
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_with_simple_values_recurse() -> None:
|
||||
batch = Batch(a=1, b="two", c=np.array([3, 4]))
|
||||
expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
|
||||
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_simple() -> None:
|
||||
batch = Batch(a=1, b="two")
|
||||
expected = {"a": np.asanyarray(1), "b": "two"}
|
||||
assert batch.to_dict() == expected
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_nested_batch_no_recurse() -> None:
|
||||
nested_batch = Batch(c=3)
|
||||
batch = Batch(a=1, b=nested_batch)
|
||||
expected = {"a": np.asanyarray(1), "b": nested_batch}
|
||||
assert not DeepDiff(batch.to_dict(recursive=False), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_nested_batch_recurse() -> None:
|
||||
nested_batch = Batch(c=3)
|
||||
batch = Batch(a=1, b=nested_batch)
|
||||
expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
|
||||
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_multiple_nested_batch_recurse() -> None:
|
||||
nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
|
||||
batch = Batch(a=1, b=nested_batch)
|
||||
expected = {
|
||||
"a": np.asanyarray(1),
|
||||
"b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
|
||||
}
|
||||
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_array() -> None:
|
||||
batch = Batch(a=np.array([1, 2, 3]))
|
||||
expected = {"a": np.array([1, 2, 3])}
|
||||
assert not DeepDiff(batch.to_dict(), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_nested_batch_with_array() -> None:
|
||||
nested_batch = Batch(c=np.array([4, 5]))
|
||||
batch = Batch(a=1, b=nested_batch)
|
||||
expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
|
||||
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_torch_tensor() -> None:
|
||||
t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
|
||||
batch = Batch(a=t1)
|
||||
t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
|
||||
expected = {"a": t2}
|
||||
assert not DeepDiff(batch.to_dict(), expected)
|
||||
|
||||
@staticmethod
|
||||
def test_to_dict_nested_batch_with_torch_tensor() -> None:
|
||||
nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
|
||||
batch = Batch(a=1, b=nested_batch)
|
||||
expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
|
||||
assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
||||
|
||||
|
||||
class TestToNumpy:
|
||||
"""Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""
|
||||
|
||||
@staticmethod
|
||||
def test_to_numpy() -> None:
|
||||
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
|
||||
new_batch: Batch = Batch.to_numpy(batch)
|
||||
assert id(batch) != id(new_batch)
|
||||
assert isinstance(batch.b, torch.Tensor)
|
||||
assert isinstance(batch.c.d, torch.Tensor)
|
||||
|
||||
assert isinstance(new_batch.b, np.ndarray)
|
||||
assert isinstance(new_batch.c.d, np.ndarray)
|
||||
|
||||
@staticmethod
|
||||
def test_to_numpy_() -> None:
|
||||
batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
|
||||
id_batch = id(batch)
|
||||
batch.to_numpy_()
|
||||
assert id_batch == id(batch)
|
||||
assert isinstance(batch.b, np.ndarray)
|
||||
assert isinstance(batch.c.d, np.ndarray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_batch()
|
||||
test_batch_over_batch()
|
||||
|
@ -17,6 +17,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
_SingleIndexType = slice | int | EllipsisType
|
||||
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
|
||||
@ -268,7 +269,15 @@ class BatchProtocol(Protocol):
|
||||
def __iter__(self) -> Iterator[Self]:
|
||||
...
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def to_numpy(batch: TBatch) -> TBatch:
|
||||
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
|
||||
...
|
||||
|
||||
def to_numpy_(self) -> None:
|
||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||
...
|
||||
|
||||
@ -396,7 +405,7 @@ class BatchProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self, recurse: bool = True) -> dict[str, Any]:
|
||||
...
|
||||
|
||||
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
||||
@ -433,11 +442,11 @@ class Batch(BatchProtocol):
|
||||
# Feels like kwargs could be just merged into batch_dict in the beginning
|
||||
self.__init__(kwargs, copy=copy) # type: ignore
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, Batch):
|
||||
v = v.to_dict()
|
||||
if recursive and isinstance(v, Batch):
|
||||
v = v.to_dict(recursive=recursive)
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@ -503,6 +512,17 @@ class Batch(BatchProtocol):
|
||||
return new_batch
|
||||
raise IndexError("Cannot access item from empty Batch object.")
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
|
||||
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
|
||||
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
|
||||
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
|
||||
|
||||
return not DeepDiff(this_dict, other_dict)
|
||||
|
||||
def __iter__(self) -> Iterator[Self]:
|
||||
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
|
||||
if len(self.__dict__) == 0:
|
||||
@ -602,12 +622,24 @@ class Batch(BatchProtocol):
|
||||
self_str = self.__class__.__name__ + "()"
|
||||
return self_str
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
@staticmethod
|
||||
def to_numpy(batch: TBatch) -> TBatch:
|
||||
batch_dict = deepcopy(batch)
|
||||
for batch_key, obj in batch_dict.items():
|
||||
if isinstance(obj, torch.Tensor):
|
||||
batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
|
||||
elif isinstance(obj, Batch):
|
||||
obj = Batch.to_numpy(obj)
|
||||
batch_dict.__dict__[batch_key] = obj
|
||||
|
||||
return batch_dict
|
||||
|
||||
def to_numpy_(self) -> None:
|
||||
for batch_key, obj in self.items():
|
||||
if isinstance(obj, torch.Tensor):
|
||||
self.__dict__[batch_key] = obj.detach().cpu().numpy()
|
||||
elif isinstance(obj, Batch):
|
||||
obj.to_numpy()
|
||||
obj.to_numpy_()
|
||||
|
||||
def to_torch(
|
||||
self,
|
||||
|
@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
|
||||
return np.array(None, dtype=object)
|
||||
if isinstance(x, dict | Batch):
|
||||
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
|
||||
x.to_numpy()
|
||||
x.to_numpy_()
|
||||
return x
|
||||
if isinstance(x, list | tuple):
|
||||
return to_numpy(_parse_value(x))
|
||||
|
Loading…
x
Reference in New Issue
Block a user