import copy
import pickle
import sys
from itertools import starmap
from typing import Any, cast

import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff

from tianshou.data import Batch, to_numpy, to_torch


def test_batch() -> None:
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={"c": {}}).is_empty()
    assert Batch(b={"c": {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object
    assert b.f == Batch
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({"a": 2}, a=3)
    assert "a" in b
    assert b.a == 3
    assert b.pop("a") == 3
    assert "a" not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs  # noqa: B018
            with pytest.raises(AttributeError):
                b.obs  # noqa: B018
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {"b": np.array([1.0]), "c": 2.0, "d": torch.Tensor([3.0])}
    batch_item = Batch({"a": [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict["b"]
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict["c"]
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict["d"]
    batch2 = Batch(a=[{"b": np.float64(1.0), "c": np.zeros(1), "d": Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a={1, 2}).shape == []
    assert batch2.shape[0] == 1
    assert "a" in batch2
    assert all(i in batch2.a for i in "bcd")
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch(list(batch2))
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2  # type: ignore  # __add__ supports Number as input type
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]  # type: ignore  # error is raised explicitly
    batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {"e": 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(["a", "b"]))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(["c", "d"]))
    assert list(batch4.a) == ["c", "d"]
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{"index": 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    # We use setattr b/c the setattr of Batch will actually change the type of the field that is being set!
    # However, mypy would not understand this, and rightly expect that batch.b = some_array would lead to
    # batch.b being an array (which it is not, it's turned into a Batch instead)
    batch5.b = np.array([{"index": 1}])
    batch5.b = cast(Batch, batch5.b)
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None
    assert a.a[1] is None
    assert a.b[0] is None
    assert a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object


def test_batch_over_batch() -> None:
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({"c": [6, 7, 8], "b": batch})
    batch2.b.b[-1] = 0
    print(batch2)
    for k, v in batch2.items():
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    batch2.update(batch2.b, six=[6, 6, 6])
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
    assert np.allclose(batch2.six, [6, 6, 6])
    d = {"a": [3, 4, 5], "b": [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    batch4 = Batch(({"a": {"b": np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0
    # advanced slicing
    batch5 = Batch(a=[[1, 2]], b={"c": np.zeros([3, 2, 1])})
    assert batch5.shape == [1, 2]
    with pytest.raises(IndexError):
        batch5[2]
    with pytest.raises(IndexError):
        batch5[:, 3]
    with pytest.raises(IndexError):
        batch5[:, :, -1]
    batch5[:, -1] += np.int_(1)
    assert np.allclose(batch5.a, [1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
    with pytest.raises(ValueError):
        batch5[:, -1] = 1
    batch5[:, 0] = {"a": -1}
    assert np.allclose(batch5.a, [-1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)


def test_batch_cat_and_stack() -> None:
    # test cat with compatible keys
    b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a,
    )

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    b1.stack_([b2])
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=Batch(),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d, strict=True))))
    b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False)
    assert b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e["b"]["c"] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]["b"]["d"]
    assert b5.b.d[1] == 0.0

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(
        a=np.stack([b1.a, np.zeros((4, 4))]),
        b=torch.stack([torch.zeros(4, 6), b2.b]),
        common=Batch(c=np.stack([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])  # type: ignore  # cat() tested with invalid inp
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])  # type: ignore # stack() tested with invalid inp

    # exceptions
    batch_cat: Batch = Batch.cat([])
    assert batch_cat.is_empty()
    batch_stack: Batch = Batch.stack([])
    assert batch_stack.is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)


def test_batch_over_batch_to_torch() -> None:
    batch = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    batch.b.__dict__["e"] = 1  # bypass the check
    batch.to_torch()
    assert isinstance(batch.a, torch.Tensor)
    assert isinstance(batch.b.c, torch.Tensor)
    assert isinstance(batch.b.d, torch.Tensor)
    assert isinstance(batch.b.e, torch.Tensor)
    assert batch.a.dtype == torch.float64
    assert batch.b.c.dtype == torch.float32
    assert batch.b.d.dtype == torch.float64
    if sys.platform in ["win32", "cygwin"]:  # windows
        assert batch.b.e.dtype == torch.int32
    else:
        assert batch.b.e.dtype == torch.int64
    batch.to_torch(dtype=torch.float32)
    assert batch.a.dtype == torch.float32
    assert batch.b.c.dtype == torch.float32
    assert batch.b.d.dtype == torch.float32
    assert batch.b.e.dtype == torch.float32


def test_utils_to_torch_numpy() -> None:
    batch = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    a_torch_float = to_torch(batch.a, dtype=torch.float32)
    assert a_torch_float.dtype == torch.float32
    a_torch_double = to_torch(batch.a, dtype=torch.float64)
    assert a_torch_double.dtype == torch.float64
    batch_torch_float = to_torch(batch, dtype=torch.float32)
    assert batch_torch_float.a.dtype == torch.float32
    assert batch_torch_float.b.c.dtype == torch.float32
    assert batch_torch_float.b.d.dtype == torch.float32
    data_list = [float("nan"), 1]
    data_list_torch = to_torch(data_list)
    assert data_list_torch.dtype == torch.float64
    data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
    data_list_2_torch = to_torch(data_list_2)
    assert data_list_2_torch.shape == (2, 3, 3)
    assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
    data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
    data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))]
    with pytest.raises(TypeError):
        to_torch(data_list_3)
    with pytest.raises(TypeError):
        to_numpy(data_list_3_torch)
    data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
    data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))]
    with pytest.raises(TypeError):
        to_torch(data_list_4)
    with pytest.raises(TypeError):
        to_numpy(data_list_4_torch)
    data_list_5 = [np.zeros(2), np.zeros((3, 3))]
    data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))]
    with pytest.raises(TypeError):
        to_torch(data_list_5)
    with pytest.raises(TypeError):
        to_numpy(data_list_5_torch)
    data_array = np.random.rand(3, 2, 2)
    data_empty_tensor = to_torch(data_array[[]])
    assert isinstance(data_empty_tensor, torch.Tensor)
    assert data_empty_tensor.shape == (0, 2, 2)
    data_empty_array = to_numpy(data_empty_tensor)
    assert isinstance(data_empty_array, np.ndarray)
    assert data_empty_array.shape == (0, 2, 2)
    assert np.allclose(to_numpy(to_torch(data_array)), data_array)
    # additional test for to_numpy, for code-coverage
    assert isinstance(to_numpy(1), np.ndarray)
    assert isinstance(to_numpy(1.0), np.ndarray)
    assert isinstance(to_numpy({"a": torch.tensor(1)})["a"], np.ndarray)
    assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray)
    assert to_numpy(None).item() is None
    assert to_numpy(to_numpy).item() == to_numpy
    # additional test for to_torch, for code-coverage
    assert isinstance(to_torch(1), torch.Tensor)
    if sys.platform in ["win32", "cygwin"]:  # windows
        assert to_torch(1).dtype == torch.int32
    else:
        assert to_torch(1).dtype == torch.int64
    assert to_torch(1.0).dtype == torch.float64
    assert isinstance(to_torch({"a": [1]})["a"], torch.Tensor)
    with pytest.raises(TypeError):
        to_torch(None)
    with pytest.raises(TypeError):
        to_torch(np.array([{}, "2"]))


def test_batch_pickle() -> None:
    batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
    batch_pk = pickle.loads(pickle.dumps(batch))
    assert batch.obs.a == batch_pk.obs.a
    assert torch.all(batch.obs.c == batch_pk.obs.c)
    assert np.all(batch.np == batch_pk.np)


def test_batch_from_to_numpy_without_copy() -> None:
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_mem_addr_orig = batch.a.__array_interface__["data"][0]
    c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
    batch.to_torch()
    batch.to_numpy_()
    a_mem_addr_new = batch.a.__array_interface__["data"][0]
    c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
    assert a_mem_addr_new == a_mem_addr_orig
    assert c_mem_addr_new == c_mem_addr_orig


def test_batch_copy() -> None:
    batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    batch2 = Batch({"c": np.array([6, 7, 8]), "b": batch})
    orig_c_addr = batch2.c.__array_interface__["data"][0]
    orig_b_a_addr = batch2.b.a.__array_interface__["data"][0]
    orig_b_b_addr = batch2.b.b.__array_interface__["data"][0]
    # test with copy=False
    batch3 = Batch(copy=False, **batch2)
    curr_c_addr = batch3.c.__array_interface__["data"][0]
    curr_b_a_addr = batch3.b.a.__array_interface__["data"][0]
    curr_b_b_addr = batch3.b.b.__array_interface__["data"][0]
    assert batch2.c is batch3.c
    assert batch2.b is batch3.b
    assert batch2.b.a is batch3.b.a
    assert batch2.b.b is batch3.b.b
    assert orig_c_addr == curr_c_addr
    assert orig_b_a_addr == curr_b_a_addr
    assert orig_b_b_addr == curr_b_b_addr
    # test with copy=True
    batch3 = Batch(copy=True, **batch2)
    curr_c_addr = batch3.c.__array_interface__["data"][0]
    curr_b_a_addr = batch3.b.a.__array_interface__["data"][0]
    curr_b_b_addr = batch3.b.b.__array_interface__["data"][0]
    assert batch2.c is not batch3.c
    assert batch2.b is not batch3.b
    assert batch2.b.a is not batch3.b.a
    assert batch2.b.b is not batch3.b.b
    assert orig_c_addr != curr_c_addr
    assert orig_b_a_addr != curr_b_a_addr
    assert orig_b_b_addr != curr_b_b_addr


def test_batch_empty() -> None:
    b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    b5 = Batch(b5_dict)
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    data = Batch(
        a=[False, True],
        b={
            "c": np.array([2.0, "st"], dtype=object),
            "d": [1, None],
            "e": [2.0, float("nan")],
        },
        c=np.array([1, 3, 4], dtype=int),
        t=torch.tensor([4, 5, 6, 7.0]),
    )
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.0]))
    data[0].empty_()  # which will fail in a, b.c, b.d, b.e, c
    assert torch.allclose(data.t, torch.tensor([0.0, 5, 6, 0]))
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []


def test_batch_standard_compatibility() -> None:
    batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
    batch_mean = np.mean(batch)
    assert isinstance(batch_mean, Batch)  # type: ignore  # mypy doesn't know but it works, cf. `batch.rst`
    assert sorted(batch_mean.get_keys()) == ["a", "b", "c"]  # type: ignore
    with pytest.raises(TypeError):
        len(batch_mean)
    assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
    assert batch_mean.c == np.mean(batch.c, axis=0)
    with pytest.raises(IndexError):
        Batch()[0]


class TestBatchEquality:
    @staticmethod
    def test_keys_different() -> None:
        batch1 = Batch(a=[1, 2], b=[100, 50])
        batch2 = Batch(b=[1, 2], c=[100, 50])
        assert batch1 != batch2

    @staticmethod
    def test_keys_missing() -> None:
        batch1 = Batch(a=[1, 2], b=[2, 3, 4])
        batch2 = Batch(a=[1, 2], b=[2, 3, 4])
        batch2.pop("b")
        assert batch1 != batch2

    @staticmethod
    def test_types_keys_different() -> None:
        batch1 = Batch(a=[1, 2, 3], b=[4, 5])
        batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
        assert batch1 != batch2

    @staticmethod
    def test_array_types_different() -> None:
        batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
        batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
        assert batch1 != batch2

    @staticmethod
    def test_nested_values_different() -> None:
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
        assert batch1 != batch2

    @staticmethod
    def test_nested_shapes_different() -> None:
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
        assert batch1 != batch2

    @staticmethod
    def test_slice_equal() -> None:
        batch1 = Batch(a=[1, 2, 3])
        assert batch1[:2] == batch1[:2]

    @staticmethod
    def test_slice_ellipsis_equal() -> None:
        batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
        assert batch1[..., 1:] == batch1[..., 1:]

    @staticmethod
    def test_empty_batches() -> None:
        assert Batch() == Batch()

    @staticmethod
    def test_different_order_keys() -> None:
        assert Batch(a=1, b=2) == Batch(b=2, a=1)

    @staticmethod
    def test_tuple_and_list_types() -> None:
        assert Batch(a=(1, 2)) == Batch(a=[1, 2])

    @staticmethod
    def test_subbatch_dict_and_batch_types() -> None:
        assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))


class TestBatchToDict:
    @staticmethod
    def test_to_dict_empty_batch_no_recurse() -> None:
        batch = Batch()
        expected: dict[Any, Any] = {}
        assert batch.to_dict() == expected

    @staticmethod
    def test_to_dict_with_simple_values_recurse() -> None:
        batch = Batch(a=1, b="two", c=np.array([3, 4]))
        expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_simple() -> None:
        batch = Batch(a=1, b="two")
        expected = {"a": np.asanyarray(1), "b": "two"}
        assert batch.to_dict() == expected

    @staticmethod
    def test_to_dict_nested_batch_no_recurse() -> None:
        nested_batch = Batch(c=3)
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": nested_batch}
        assert not DeepDiff(batch.to_dict(recursive=False), expected)

    @staticmethod
    def test_to_dict_nested_batch_recurse() -> None:
        nested_batch = Batch(c=3)
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_multiple_nested_batch_recurse() -> None:
        nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
        batch = Batch(a=1, b=nested_batch)
        expected = {
            "a": np.asanyarray(1),
            "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
        }
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_array() -> None:
        batch = Batch(a=np.array([1, 2, 3]))
        expected = {"a": np.array([1, 2, 3])}
        assert not DeepDiff(batch.to_dict(), expected)

    @staticmethod
    def test_to_dict_nested_batch_with_array() -> None:
        nested_batch = Batch(c=np.array([4, 5]))
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_torch_tensor() -> None:
        t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        batch = Batch(a=t1)
        t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        expected = {"a": t2}
        assert not DeepDiff(batch.to_dict(), expected)

    @staticmethod
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
        nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)


class TestToNumpy:
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""

    @staticmethod
    def test_to_numpy() -> None:
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        new_batch: Batch = Batch.to_numpy(batch)
        assert id(batch) != id(new_batch)
        assert isinstance(batch.b, torch.Tensor)
        assert isinstance(batch.c.d, torch.Tensor)

        assert isinstance(new_batch.b, np.ndarray)
        assert isinstance(new_batch.c.d, np.ndarray)

    @staticmethod
    def test_to_numpy_() -> None:
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        id_batch = id(batch)
        batch.to_numpy_()
        assert id_batch == id(batch)
        assert isinstance(batch.b, np.ndarray)
        assert isinstance(batch.c.d, np.ndarray)


if __name__ == "__main__":
    test_batch()
    test_batch_over_batch()
    test_batch_over_batch_to_torch()
    test_utils_to_torch_numpy()
    test_batch_pickle()
    test_batch_from_to_numpy_without_copy()
    test_batch_standard_compatibility()
    test_batch_cat_and_stack()
    test_batch_copy()
    test_batch_empty()