Allow two (same/different) Batch objs to be tested for equality (#1098)
Closes: https://github.com/thu-ml/tianshou/issues/1086 ### Api Extensions - Batch received new method: `to_numpy_`. #1098 - `to_dict` in Batch supports also non-recursive conversion. #1098 - Batch `__eq__` now implemented, semantic equality check of batches is now possible. #1098 ### Breaking Changes - The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
This commit is contained in:
		
							parent
							
								
									049907d9ab
								
							
						
					
					
						commit
						ca4f74f40e
					
				@ -8,6 +8,9 @@
 | 
			
		||||
- Trainers can control whether collectors should be reset prior to training. #1063
 | 
			
		||||
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
 | 
			
		||||
- `SamplingConfig` supports `batch_size=None`. #1077
 | 
			
		||||
- Batch received new method: `to_numpy_`. #1098
 | 
			
		||||
- `to_dict` in Batch supports also non-recursive conversion. #1098
 | 
			
		||||
- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098
 | 
			
		||||
 | 
			
		||||
### Internal Improvements
 | 
			
		||||
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
 | 
			
		||||
@ -34,6 +37,7 @@ expicitly or pass `reset_before_collect=True` . #1063
 | 
			
		||||
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
 | 
			
		||||
continuous and discrete cases. #1032
 | 
			
		||||
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
 | 
			
		||||
- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098
 | 
			
		||||
 | 
			
		||||
### Tests
 | 
			
		||||
- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081
 | 
			
		||||
 | 
			
		||||
@ -485,8 +485,8 @@ Miscellaneous Notes
 | 
			
		||||
    tensor([[0., 0., 0., 0.],
 | 
			
		||||
            [0., 0., 0., 0.],
 | 
			
		||||
            [0., 0., 0., 0.]])
 | 
			
		||||
    >>> # data.to_numpy is also available
 | 
			
		||||
    >>> data.to_numpy()
 | 
			
		||||
    >>> # data.to_numpy_ is also available
 | 
			
		||||
    >>> data.to_numpy_()
 | 
			
		||||
 | 
			
		||||
.. raw:: html
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -331,7 +331,7 @@
 | 
			
		||||
   },
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "batch_cat.to_numpy()\n",
 | 
			
		||||
    "batch_cat.to_numpy_()\n",
 | 
			
		||||
    "print(batch_cat)\n",
 | 
			
		||||
    "batch_cat.to_torch()\n",
 | 
			
		||||
    "print(batch_cat)"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										35
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										35
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							@ -903,6 +903,24 @@ files = [
 | 
			
		||||
    {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "deepdiff"
 | 
			
		||||
version = "7.0.1"
 | 
			
		||||
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
 | 
			
		||||
optional = false
 | 
			
		||||
python-versions = ">=3.8"
 | 
			
		||||
files = [
 | 
			
		||||
    {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
 | 
			
		||||
    {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[package.dependencies]
 | 
			
		||||
ordered-set = ">=4.1.0,<4.2.0"
 | 
			
		||||
 | 
			
		||||
[package.extras]
 | 
			
		||||
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
 | 
			
		||||
optimize = ["orjson"]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "defusedxml"
 | 
			
		||||
version = "0.7.1"
 | 
			
		||||
@ -3239,7 +3257,6 @@ optional = false
 | 
			
		||||
python-versions = ">=3"
 | 
			
		||||
files = [
 | 
			
		||||
    {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
 | 
			
		||||
    {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"},
 | 
			
		||||
    {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -3349,6 +3366,20 @@ numpy = ["numpy"]
 | 
			
		||||
test = ["pytest", "pytest-cov", "pytest-xdist"]
 | 
			
		||||
torch = ["torch"]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "ordered-set"
 | 
			
		||||
version = "4.1.0"
 | 
			
		||||
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
 | 
			
		||||
optional = false
 | 
			
		||||
python-versions = ">=3.7"
 | 
			
		||||
files = [
 | 
			
		||||
    {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
 | 
			
		||||
    {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[package.extras]
 | 
			
		||||
dev = ["black", "mypy", "pytest"]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "overrides"
 | 
			
		||||
version = "7.4.0"
 | 
			
		||||
@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"]
 | 
			
		||||
[metadata]
 | 
			
		||||
lock-version = "2.0"
 | 
			
		||||
python-versions = "^3.11"
 | 
			
		||||
content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b"
 | 
			
		||||
content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9"
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]
 | 
			
		||||
 | 
			
		||||
[tool.poetry.dependencies]
 | 
			
		||||
python = "^3.11"
 | 
			
		||||
deepdiff = "^7.0.1"
 | 
			
		||||
gymnasium = "^0.28.0"
 | 
			
		||||
h5py = "^3.9.0"
 | 
			
		||||
numba = "^0.57.1"
 | 
			
		||||
 | 
			
		||||
@ -2,12 +2,13 @@ import copy
 | 
			
		||||
import pickle
 | 
			
		||||
import sys
 | 
			
		||||
from itertools import starmap
 | 
			
		||||
from typing import cast
 | 
			
		||||
from typing import Any, cast
 | 
			
		||||
 | 
			
		||||
import networkx as nx
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
from deepdiff import DeepDiff
 | 
			
		||||
 | 
			
		||||
from tianshou.data import Batch, to_numpy, to_torch
 | 
			
		||||
 | 
			
		||||
@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
 | 
			
		||||
    a_mem_addr_orig = batch.a.__array_interface__["data"][0]
 | 
			
		||||
    c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
 | 
			
		||||
    batch.to_torch()
 | 
			
		||||
    batch.to_numpy()
 | 
			
		||||
    batch.to_numpy_()
 | 
			
		||||
    a_mem_addr_new = batch.a.__array_interface__["data"][0]
 | 
			
		||||
    c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
 | 
			
		||||
    assert a_mem_addr_new == a_mem_addr_orig
 | 
			
		||||
@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
 | 
			
		||||
        Batch()[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBatchEquality:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_keys_different() -> None:
 | 
			
		||||
        batch1 = Batch(a=[1, 2], b=[100, 50])
 | 
			
		||||
        batch2 = Batch(b=[1, 2], c=[100, 50])
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_keys_missing() -> None:
 | 
			
		||||
        batch1 = Batch(a=[1, 2], b=[2, 3, 4])
 | 
			
		||||
        batch2 = Batch(a=[1, 2], b=[2, 3, 4])
 | 
			
		||||
        batch2.pop("b")
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_types_keys_different() -> None:
 | 
			
		||||
        batch1 = Batch(a=[1, 2, 3], b=[4, 5])
 | 
			
		||||
        batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_array_types_different() -> None:
 | 
			
		||||
        batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
 | 
			
		||||
        batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_nested_values_different() -> None:
 | 
			
		||||
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
 | 
			
		||||
        batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_nested_shapes_different() -> None:
 | 
			
		||||
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
 | 
			
		||||
        batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
 | 
			
		||||
        assert batch1 != batch2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_slice_equal() -> None:
 | 
			
		||||
        batch1 = Batch(a=[1, 2, 3])
 | 
			
		||||
        assert batch1[:2] == batch1[:2]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_slice_ellipsis_equal() -> None:
 | 
			
		||||
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
 | 
			
		||||
        assert batch1[..., 1:] == batch1[..., 1:]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_empty_batches() -> None:
 | 
			
		||||
        assert Batch() == Batch()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_different_order_keys() -> None:
 | 
			
		||||
        assert Batch(a=1, b=2) == Batch(b=2, a=1)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_tuple_and_list_types() -> None:
 | 
			
		||||
        assert Batch(a=(1, 2)) == Batch(a=[1, 2])
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_subbatch_dict_and_batch_types() -> None:
 | 
			
		||||
        assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBatchToDict:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_empty_batch_no_recurse() -> None:
 | 
			
		||||
        batch = Batch()
 | 
			
		||||
        expected: dict[Any, Any] = {}
 | 
			
		||||
        assert batch.to_dict() == expected
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_with_simple_values_recurse() -> None:
 | 
			
		||||
        batch = Batch(a=1, b="two", c=np.array([3, 4]))
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_simple() -> None:
 | 
			
		||||
        batch = Batch(a=1, b="two")
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": "two"}
 | 
			
		||||
        assert batch.to_dict() == expected
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_nested_batch_no_recurse() -> None:
 | 
			
		||||
        nested_batch = Batch(c=3)
 | 
			
		||||
        batch = Batch(a=1, b=nested_batch)
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": nested_batch}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=False), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_nested_batch_recurse() -> None:
 | 
			
		||||
        nested_batch = Batch(c=3)
 | 
			
		||||
        batch = Batch(a=1, b=nested_batch)
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_multiple_nested_batch_recurse() -> None:
 | 
			
		||||
        nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
 | 
			
		||||
        batch = Batch(a=1, b=nested_batch)
 | 
			
		||||
        expected = {
 | 
			
		||||
            "a": np.asanyarray(1),
 | 
			
		||||
            "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
 | 
			
		||||
        }
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_array() -> None:
 | 
			
		||||
        batch = Batch(a=np.array([1, 2, 3]))
 | 
			
		||||
        expected = {"a": np.array([1, 2, 3])}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_nested_batch_with_array() -> None:
 | 
			
		||||
        nested_batch = Batch(c=np.array([4, 5]))
 | 
			
		||||
        batch = Batch(a=1, b=nested_batch)
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_torch_tensor() -> None:
 | 
			
		||||
        t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
 | 
			
		||||
        batch = Batch(a=t1)
 | 
			
		||||
        t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
 | 
			
		||||
        expected = {"a": t2}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(), expected)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
 | 
			
		||||
        nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
 | 
			
		||||
        batch = Batch(a=1, b=nested_batch)
 | 
			
		||||
        expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
 | 
			
		||||
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestToNumpy:
 | 
			
		||||
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_numpy() -> None:
 | 
			
		||||
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
 | 
			
		||||
        new_batch: Batch = Batch.to_numpy(batch)
 | 
			
		||||
        assert id(batch) != id(new_batch)
 | 
			
		||||
        assert isinstance(batch.b, torch.Tensor)
 | 
			
		||||
        assert isinstance(batch.c.d, torch.Tensor)
 | 
			
		||||
 | 
			
		||||
        assert isinstance(new_batch.b, np.ndarray)
 | 
			
		||||
        assert isinstance(new_batch.c.d, np.ndarray)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def test_to_numpy_() -> None:
 | 
			
		||||
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
 | 
			
		||||
        id_batch = id(batch)
 | 
			
		||||
        batch.to_numpy_()
 | 
			
		||||
        assert id_batch == id(batch)
 | 
			
		||||
        assert isinstance(batch.b, np.ndarray)
 | 
			
		||||
        assert isinstance(batch.c.d, np.ndarray)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    test_batch()
 | 
			
		||||
    test_batch_over_batch()
 | 
			
		||||
 | 
			
		||||
@ -17,6 +17,7 @@ from typing import (
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from deepdiff import DeepDiff
 | 
			
		||||
 | 
			
		||||
_SingleIndexType = slice | int | EllipsisType
 | 
			
		||||
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
 | 
			
		||||
@ -268,7 +269,15 @@ class BatchProtocol(Protocol):
 | 
			
		||||
    def __iter__(self) -> Iterator[Self]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def to_numpy(self) -> None:
 | 
			
		||||
    def __eq__(self, other: Any) -> bool:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def to_numpy(batch: TBatch) -> TBatch:
 | 
			
		||||
        """Change all torch.Tensor to numpy.ndarray and return a new Batch."""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def to_numpy_(self) -> None:
 | 
			
		||||
        """Change all torch.Tensor to numpy.ndarray in-place."""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
@ -396,7 +405,7 @@ class BatchProtocol(Protocol):
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> dict[str, Any]:
 | 
			
		||||
    def to_dict(self, recurse: bool = True) -> dict[str, Any]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def to_list_of_dicts(self) -> list[dict[str, Any]]:
 | 
			
		||||
@ -433,11 +442,11 @@ class Batch(BatchProtocol):
 | 
			
		||||
            # Feels like kwargs could be just merged into batch_dict in the beginning
 | 
			
		||||
            self.__init__(kwargs, copy=copy)  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> dict[str, Any]:
 | 
			
		||||
    def to_dict(self, recursive: bool = True) -> dict[str, Any]:
 | 
			
		||||
        result = {}
 | 
			
		||||
        for k, v in self.__dict__.items():
 | 
			
		||||
            if isinstance(v, Batch):
 | 
			
		||||
                v = v.to_dict()
 | 
			
		||||
            if recursive and isinstance(v, Batch):
 | 
			
		||||
                v = v.to_dict(recursive=recursive)
 | 
			
		||||
            result[k] = v
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
@ -503,6 +512,17 @@ class Batch(BatchProtocol):
 | 
			
		||||
            return new_batch
 | 
			
		||||
        raise IndexError("Cannot access item from empty Batch object.")
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other: Any) -> bool:
 | 
			
		||||
        if not isinstance(other, self.__class__):
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
 | 
			
		||||
        other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
 | 
			
		||||
        this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
 | 
			
		||||
        other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
 | 
			
		||||
 | 
			
		||||
        return not DeepDiff(this_dict, other_dict)
 | 
			
		||||
 | 
			
		||||
    def __iter__(self) -> Iterator[Self]:
 | 
			
		||||
        # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
 | 
			
		||||
        if len(self.__dict__) == 0:
 | 
			
		||||
@ -602,12 +622,24 @@ class Batch(BatchProtocol):
 | 
			
		||||
            self_str = self.__class__.__name__ + "()"
 | 
			
		||||
        return self_str
 | 
			
		||||
 | 
			
		||||
    def to_numpy(self) -> None:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def to_numpy(batch: TBatch) -> TBatch:
 | 
			
		||||
        batch_dict = deepcopy(batch)
 | 
			
		||||
        for batch_key, obj in batch_dict.items():
 | 
			
		||||
            if isinstance(obj, torch.Tensor):
 | 
			
		||||
                batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy()
 | 
			
		||||
            elif isinstance(obj, Batch):
 | 
			
		||||
                obj = Batch.to_numpy(obj)
 | 
			
		||||
                batch_dict.__dict__[batch_key] = obj
 | 
			
		||||
 | 
			
		||||
        return batch_dict
 | 
			
		||||
 | 
			
		||||
    def to_numpy_(self) -> None:
 | 
			
		||||
        for batch_key, obj in self.items():
 | 
			
		||||
            if isinstance(obj, torch.Tensor):
 | 
			
		||||
                self.__dict__[batch_key] = obj.detach().cpu().numpy()
 | 
			
		||||
            elif isinstance(obj, Batch):
 | 
			
		||||
                obj.to_numpy()
 | 
			
		||||
                obj.to_numpy_()
 | 
			
		||||
 | 
			
		||||
    def to_torch(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
 | 
			
		||||
        return np.array(None, dtype=object)
 | 
			
		||||
    if isinstance(x, dict | Batch):
 | 
			
		||||
        x = Batch(x) if isinstance(x, dict) else deepcopy(x)
 | 
			
		||||
        x.to_numpy()
 | 
			
		||||
        x.to_numpy_()
 | 
			
		||||
        return x
 | 
			
		||||
    if isinstance(x, list | tuple):
 | 
			
		||||
        return to_numpy(_parse_value(x))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user