Tianshou/test/base/test_batch.py
Michael Panchenko 12d4262f80 Tests: removed all instances of if __name__ == ... in tests
A test is not a script and should not be used as such

Also marked pistonball test as skipped since it doesn't actually test anything
2024-04-26 17:39:30 +02:00

752 lines
29 KiB
Python

import copy
import pickle
import sys
from itertools import starmap
from typing import Any, cast
import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff
from tianshou.data import Batch, to_numpy, to_torch
def test_batch() -> None:
    """Exercise the basic `Batch` API in one long scenario.

    The assertions cover emptiness checks, construction from lists, dicts and
    arrays, `update`/`pop`, `cat_`, `split`, indexing and slicing, arithmetic,
    item assignment, conversion to object dtype and storing `None` or
    `nx.Graph` objects.
    """
    # emptiness checks, with and without recursing into nested containers
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={"c": {}}).is_empty()
    assert Batch(b={"c": {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()

    # values that cannot be represented numerically end up with object dtype
    b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object
    assert b.f == Batch

    # update / pop
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({"a": 2}, a=3)
    assert "a" in b
    assert b.a == 3
    assert b.pop("a") == 3
    assert "a" not in b

    # non-string keys are rejected
    with pytest.raises(AssertionError):
        Batch({1: 2})

    # lists of arrays/tensors: which combinations are accepted and which raise
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])  # constructing from an empty list must not raise

    # attribute access, item access and in-place concatenation with itself
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            # NOTE(review): `obs` has 5 entries, so this branch looks unreachable
            # unless split() follows the 6 rows of `np` - confirm before removing.
            with pytest.raises(AttributeError):
                batch[i].obs  # noqa: B018
            with pytest.raises(AttributeError):
                b.obs  # noqa: B018
    print(batch)

    # split: a size of zero is rejected, otherwise compare against the expected chunks
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    # (size, merge_last, expected chunks)
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [chunk.a.tolist() for chunk in bs] == result

    # a list holding one dict: indexing element 0 gives back the original leaf types
    batch_dict = {"b": np.array([1.0]), "c": 2.0, "d": torch.Tensor([3.0])}
    batch_item = Batch({"a": [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict["b"]
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict["c"]
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict["d"]

    # shape, len and out-of-range indexing
    batch2 = Batch(a=[{"b": np.float64(1.0), "c": np.zeros(1), "d": Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a={1, 2}).shape == []
    assert batch2.shape[0] == 1
    assert "a" in batch2
    assert all(i in batch2.a for i in "bcd")
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)

    # rebuilding a Batch from its own elements gives back the same leaves
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch(list(batch2))
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e

    # slices that select the single stored element compare equal to the full batch
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e

    # arithmetic is applied to every leaf; an empty sub-batch stays empty
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2  # type: ignore  # __add__ supports Number as input type
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]  # type: ignore  # error is raised explicitly

    # item assignment: a subset of the existing keys is fine, unknown keys raise
    batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {"e": 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})

    # auto convert
    batch4 = Batch(a=np.array(["a", "b"]))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(["c", "d"]))
    assert list(batch4.a) == ["c", "d"]
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{"index": 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    # Assigning an attribute on a Batch changes the type of the value being set: the
    # array of dicts below is stored as a Batch. mypy cannot know this and (rightly)
    # expects batch5.b to be an array after the assignment, hence the explicit cast.
    batch5.b = np.array([{"index": 1}])
    batch5.b = cast(Batch, batch5.b)
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None
    assert a.a[1] is None
    assert a.b[0] is None
    assert a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
def test_batch_over_batch() -> None:
    """Check nested batches: concatenation, `update` from a sub-batch and multi-axis slicing."""
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch({"c": [6, 7, 8], "b": inner})
    outer.b.b[-1] = 0
    print(outer)
    # item access by key must agree with what items() yields
    assert all(np.all(outer[key] == value) for key, value in outer.items())
    assert outer[-1].b.b == 0

    # concatenation recurses into the nested batch
    outer.cat_(Batch(c=[6, 7, 8], b=inner))
    assert np.allclose(outer.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(outer.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(outer.b.b, [4, 5, 0, 4, 5, 0])

    # update() with the sub-batch puts its keys on the outer batch
    outer.update(outer.b, six=[6, 6, 6])
    assert np.allclose(outer.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(outer.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(outer.b, [4, 5, 0, 4, 5, 0])
    assert np.allclose(outer.six, [6, 6, 6])

    # a plain dict as a value behaves like a nested batch
    nested_dict = {"a": [3, 4, 5], "b": [4, 5, 6]}
    from_dict = Batch(c=[6, 7, 8], b=nested_dict)
    from_dict.cat_(Batch(c=[6, 7, 8], b=nested_dict))
    assert np.allclose(from_dict.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(from_dict.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(from_dict.b.b, [4, 5, 6, 4, 5, 6])

    # a one-element tuple of dicts adds a leading dimension to the leaf
    from_tuple = Batch(({"a": {"b": np.array([1.0])}},))
    assert from_tuple.a.b.ndim == 2
    assert from_tuple.a.b[0, 0] == 1.0

    # advanced slicing
    grid = Batch(a=[[1, 2]], b={"c": np.zeros([3, 2, 1])})
    assert grid.shape == [1, 2]
    with pytest.raises(IndexError):
        grid[2]
    with pytest.raises(IndexError):
        grid[:, 3]
    with pytest.raises(IndexError):
        grid[:, :, -1]
    grid[:, -1] += np.int_(1)
    assert np.allclose(grid.a, [1, 3])
    assert np.allclose(grid.b.c.squeeze(), [[0, 1]] * 3)
    with pytest.raises(ValueError):
        grid[:, -1] = 1
    grid[:, 0] = {"a": -1}
    assert np.allclose(grid.a, [-1, 3])
    assert np.allclose(grid.b.c.squeeze(), [[0, 1]] * 3)
def test_batch_cat_and_stack() -> None:
    """Check `Batch.cat`/`cat_` and `Batch.stack`/`stack_`.

    Scenarios: compatible keys, keys missing from some inputs (the expected
    values pad them with zeros), reserved keys whose value is an empty
    `Batch()`, empty inputs, and malformed inputs that must raise `ValueError`.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # in-place and out-of-place concatenation must agree on every leaf
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a,
    )

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # in-place stacking adds a dimension to the leaves
    b1.stack_([b2])
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=Batch(),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    # check the cat result itself; the hand-built expectation is empty by construction
    assert test.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d, strict=True))))
    b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False)
    assert b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e["b"]["c"] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]["b"]["d"]
    assert b5.b.d[1] == 0.0  # "d" is missing from the second dict and reads back as zero

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(
        a=np.stack([b1.a, np.zeros((4, 4))]),
        b=torch.stack([torch.zeros(4, 6), b2.b]),
        common=Batch(c=np.stack([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])  # type: ignore  # cat() tested with invalid inp
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])  # type: ignore  # stack() tested with invalid inp

    # exceptions
    batch_cat: Batch = Batch.cat([])
    assert batch_cat.is_empty()
    batch_stack: Batch = Batch.stack([])
    assert batch_stack.is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
def test_batch_over_batch_to_torch() -> None:
    """Check that `to_torch_` converts every leaf of a nested batch and honours `dtype`."""
    nested = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    nested.b.__dict__["e"] = 1  # bypass the check

    # without an explicit dtype each leaf keeps its own precision
    nested.to_torch_()
    for leaf in (nested.a, nested.b.c, nested.b.d, nested.b.e):
        assert isinstance(leaf, torch.Tensor)
    assert nested.a.dtype == torch.float64
    assert nested.b.c.dtype == torch.float32
    assert nested.b.d.dtype == torch.float64
    # a plain Python int is expected as int32 on Windows and int64 elsewhere
    on_windows = sys.platform in ["win32", "cygwin"]
    expected_int_dtype = torch.int32 if on_windows else torch.int64
    assert nested.b.e.dtype == expected_int_dtype

    # an explicit dtype is applied to all leaves
    nested.to_torch_(dtype=torch.float32)
    for leaf in (nested.a, nested.b.c, nested.b.d, nested.b.e):
        assert leaf.dtype == torch.float32
def test_utils_to_torch_numpy() -> None:
    """Check the free functions `to_torch` and `to_numpy` on scalars, lists, dicts and batches."""
    source = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )

    # the requested dtype is honoured for a single leaf ...
    assert to_torch(source.a, dtype=torch.float32).dtype == torch.float32
    assert to_torch(source.a, dtype=torch.float64).dtype == torch.float64
    # ... and for every leaf of a whole batch
    as_float = to_torch(source, dtype=torch.float32)
    assert as_float.a.dtype == torch.float32
    assert as_float.b.c.dtype == torch.float32
    assert as_float.b.d.dtype == torch.float32

    # plain Python lists
    with_nan = [float("nan"), 1]
    assert to_torch(with_nan).dtype == torch.float64
    same_shape_arrays = [np.random.rand(3, 3), np.random.rand(3, 3)]
    stacked = to_torch(same_shape_arrays)
    assert stacked.shape == (2, 3, 3)
    assert np.allclose(to_numpy(to_torch(same_shape_arrays)), same_shape_arrays)

    # lists whose elements disagree in shape are rejected in both directions
    mismatched_shapes = [
        ((3, 2), (3, 3)),
        ((2, 3), (3, 3)),
        ((2,), (3, 3)),
    ]
    for first_shape, second_shape in mismatched_shapes:
        ragged_arrays = [np.zeros(first_shape), np.zeros(second_shape)]
        ragged_tensors = [torch.zeros(first_shape), torch.zeros(second_shape)]
        with pytest.raises(TypeError):
            to_torch(ragged_arrays)
        with pytest.raises(TypeError):
            to_numpy(ragged_tensors)

    # an empty selection keeps its trailing dimensions through a round trip
    sample = np.random.rand(3, 2, 2)
    empty_tensor = to_torch(sample[[]])
    assert isinstance(empty_tensor, torch.Tensor)
    assert empty_tensor.shape == (0, 2, 2)
    empty_array = to_numpy(empty_tensor)
    assert isinstance(empty_array, np.ndarray)
    assert empty_array.shape == (0, 2, 2)
    assert np.allclose(to_numpy(to_torch(sample)), sample)

    # additional test for to_numpy, for code-coverage
    assert isinstance(to_numpy(1), np.ndarray)
    assert isinstance(to_numpy(1.0), np.ndarray)
    assert isinstance(to_numpy({"a": torch.tensor(1)})["a"], np.ndarray)
    assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray)
    assert to_numpy(None).item() is None
    assert to_numpy(to_numpy).item() == to_numpy

    # additional test for to_torch, for code-coverage
    assert isinstance(to_torch(1), torch.Tensor)
    on_windows = sys.platform in ["win32", "cygwin"]
    assert to_torch(1).dtype == (torch.int32 if on_windows else torch.int64)
    assert to_torch(1.0).dtype == torch.float64
    assert isinstance(to_torch({"a": [1]})["a"], torch.Tensor)
    with pytest.raises(TypeError):
        to_torch(None)
    with pytest.raises(TypeError):
        to_torch(np.array([{}, "2"]))
def test_batch_pickle() -> None:
    """Check that a nested batch comes back with equal values after a pickle round trip."""
    original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
    serialized = pickle.dumps(original)
    restored = pickle.loads(serialized)
    assert original.obs.a == restored.obs.a
    assert torch.all(original.obs.c == restored.obs.c)
    assert np.all(original.np == restored.np)
def test_batch_from_to_numpy_without_copy() -> None:
    """Check that `to_torch_` followed by `to_numpy_` leaves the arrays at the same address."""

    def data_address(array: np.ndarray) -> int:
        # first entry of the array interface's "data" tuple is the buffer address
        return array.__array_interface__["data"][0]

    subject = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    address_a_before = data_address(subject.a)
    address_c_before = data_address(subject.b.c)

    subject.to_torch_()
    subject.to_numpy_()

    assert data_address(subject.a) == address_a_before
    assert data_address(subject.b.c) == address_c_before
def test_batch_copy() -> None:
    """Check the `copy` flag of the `Batch` constructor.

    With `copy=False` the new batch shares objects and buffers with the source;
    with `copy=True` every array and nested batch is a distinct object backed
    by a different buffer.
    """

    def data_address(array: np.ndarray) -> int:
        # first entry of the array interface's "data" tuple is the buffer address
        return array.__array_interface__["data"][0]

    def leaf_addresses(holder: Batch) -> tuple[int, int, int]:
        return data_address(holder.c), data_address(holder.b.a), data_address(holder.b.b)

    inner = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    source = Batch({"c": np.array([6, 7, 8]), "b": inner})
    source_addresses = leaf_addresses(source)

    # test with copy=False
    shared = Batch(copy=False, **source)
    shared_addresses = leaf_addresses(shared)
    assert source.c is shared.c
    assert source.b is shared.b
    assert source.b.a is shared.b.a
    assert source.b.b is shared.b.b
    assert source_addresses == shared_addresses

    # test with copy=True
    copied = Batch(copy=True, **source)
    copied_addresses = leaf_addresses(copied)
    assert source.c is not copied.c
    assert source.b is not copied.b
    assert source.b.a is not copied.b.a
    assert source.b.b is not copied.b.b
    assert all(
        before != after
        for before, after in zip(source_addresses, copied_addresses, strict=True)
    )
def test_batch_empty() -> None:
    """Check `Batch.empty` and `empty_`: numeric leaves become 0/False, object leaves become None."""
    records = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    stacked = Batch(records)
    stacked[1] = Batch.empty(stacked[0])
    assert np.allclose(stacked.a, [False, False])
    assert np.allclose(stacked.b.c, [2, 0])
    assert np.allclose(stacked.b.d, [1, 0])

    mixed = Batch(
        a=[False, True],
        b={
            "c": np.array([2.0, "st"], dtype=object),
            "d": [1, None],
            "e": [2.0, float("nan")],
        },
        c=np.array([1, 3, 4], dtype=int),
        t=torch.tensor([4, 5, 6, 7.0]),
    )
    # overwrite the last position with an emptied copy of element 1
    mixed[-1] = Batch.empty(mixed[1])
    assert np.allclose(mixed.c, [1, 3, 0])
    assert np.allclose(mixed.a, [False, False])
    assert list(mixed.b.c) == [2.0, None]
    assert list(mixed.b.d) == [1, None]
    assert np.allclose(mixed.b.e, [2, 0])
    assert torch.allclose(mixed.t, torch.tensor([4, 5, 6, 0.0]))

    mixed[0].empty_()  # which will fail in a, b.c, b.d, b.e, c
    assert torch.allclose(mixed.t, torch.tensor([0.0, 5, 6, 0]))

    # emptying by index on the batch itself reaches every leaf
    mixed.empty_(index=0)
    assert np.allclose(mixed.c, [0, 3, 0])
    assert list(mixed.b.c) == [None, None]
    assert list(mixed.b.d) == [None, None]
    assert list(mixed.b.e) == [0, 0]

    # emptying an already empty batch is harmless
    blank = Batch()
    blank.empty_()
    assert blank.shape == []
def test_batch_standard_compatibility() -> None:
    """Check that `np.mean` works on a `Batch` and that indexing an empty batch raises."""
    source = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
    averaged = np.mean(source)
    assert isinstance(averaged, Batch)  # type: ignore  # mypy doesn't know but it works, cf. `batch.rst`
    assert sorted(averaged.get_keys()) == ["a", "b", "c"]  # type: ignore
    # the averaged batch has no length any more
    with pytest.raises(TypeError):
        len(averaged)
    expected_a = np.mean(source.a, axis=0)
    expected_c = np.mean(source.c, axis=0)
    assert np.all(averaged.a == expected_a)
    assert averaged.c == expected_c
    with pytest.raises(IndexError):
        Batch()[0]
class TestBatchEquality:
    """Tests for `==` and `!=` between `Batch` instances."""

    @staticmethod
    def test_keys_different() -> None:
        """Batches with different key sets are unequal."""
        left = Batch(a=[1, 2], b=[100, 50])
        right = Batch(b=[1, 2], c=[100, 50])
        assert left != right

    @staticmethod
    def test_keys_missing() -> None:
        """Removing a key from one of two identical batches makes them unequal."""
        complete = Batch(a=[1, 2], b=[2, 3, 4])
        reduced = Batch(a=[1, 2], b=[2, 3, 4])
        reduced.pop("b")
        assert complete != reduced

    @staticmethod
    def test_types_keys_different() -> None:
        """An array value and a nested batch under the same key are unequal."""
        flat = Batch(a=[1, 2, 3], b=[4, 5])
        nested = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
        assert flat != nested

    @staticmethod
    def test_array_types_different() -> None:
        """A numpy array and a torch tensor with the same numbers are unequal."""
        with_numpy = Batch(a=[1, 2, 3], b=np.array([4, 5]))
        with_torch = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
        assert with_numpy != with_torch

    @staticmethod
    def test_nested_values_different() -> None:
        """A differing value inside a nested batch makes the outer batches unequal."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_nested_shapes_different() -> None:
        """A differing shape inside a nested batch makes the outer batches unequal."""
        longer = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        shorter = Batch(a=Batch(a=[1, 4]), b=[4, 5])
        assert longer != shorter

    @staticmethod
    def test_slice_equal() -> None:
        """Two identical slices of the same batch are equal."""
        source = Batch(a=[1, 2, 3])
        assert source[:2] == source[:2]

    @staticmethod
    def test_slice_ellipsis_equal() -> None:
        """Two identical ellipsis slices of the same batch are equal."""
        source = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
        assert source[..., 1:] == source[..., 1:]

    @staticmethod
    def test_empty_batches() -> None:
        """Two empty batches are equal."""
        assert Batch() == Batch()

    @staticmethod
    def test_different_order_keys() -> None:
        """Key order does not matter for equality."""
        forward = Batch(a=1, b=2)
        backward = Batch(b=2, a=1)
        assert forward == backward

    @staticmethod
    def test_tuple_and_list_types() -> None:
        """A tuple and a list with the same elements give equal batches."""
        from_tuple = Batch(a=(1, 2))
        from_list = Batch(a=[1, 2])
        assert from_tuple == from_list

    @staticmethod
    def test_subbatch_dict_and_batch_types() -> None:
        """A dict value and an equivalent nested batch give equal batches."""
        from_dict = Batch(a={"x": 1})
        from_batch = Batch(a=Batch(x=1))
        assert from_dict == from_batch
class TestBatchToDict:
    """Tests for `Batch.to_dict()`, with and without `recursive=True`."""

    @staticmethod
    def test_to_dict_empty_batch_no_recurse() -> None:
        """An empty batch converts to an empty dict."""
        expected: dict[Any, Any] = {}
        assert Batch().to_dict() == expected

    @staticmethod
    def test_to_dict_with_simple_values_recurse() -> None:
        """Scalars, strings and arrays are converted when recursing."""
        subject = Batch(a=1, b="two", c=np.array([3, 4]))
        expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
        difference = DeepDiff(subject.to_dict(recursive=True), expected)
        assert not difference

    @staticmethod
    def test_to_dict_simple() -> None:
        """Default conversion of a flat batch without array-valued entries."""
        subject = Batch(a=1, b="two")
        expected = {"a": np.asanyarray(1), "b": "two"}
        assert subject.to_dict() == expected

    @staticmethod
    def test_to_dict_nested_batch_no_recurse() -> None:
        """Without recursion a nested batch is kept as a `Batch` value."""
        child = Batch(c=3)
        parent = Batch(a=1, b=child)
        expected = {"a": np.asanyarray(1), "b": child}
        difference = DeepDiff(parent.to_dict(recursive=False), expected)
        assert not difference

    @staticmethod
    def test_to_dict_nested_batch_recurse() -> None:
        """With recursion a nested batch becomes a nested dict."""
        child = Batch(c=3)
        parent = Batch(a=1, b=child)
        expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
        difference = DeepDiff(parent.to_dict(recursive=True), expected)
        assert not difference

    @staticmethod
    def test_to_dict_multiple_nested_batch_recurse() -> None:
        """Recursion descends through more than one level of nesting."""
        child = Batch(c=Batch(e=3), d=[100, 200, 300])
        parent = Batch(a=1, b=child)
        expected = {
            "a": np.asanyarray(1),
            "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
        }
        difference = DeepDiff(parent.to_dict(recursive=True), expected)
        assert not difference

    @staticmethod
    def test_to_dict_array() -> None:
        """An array value is carried over into the dict."""
        subject = Batch(a=np.array([1, 2, 3]))
        expected = {"a": np.array([1, 2, 3])}
        difference = DeepDiff(subject.to_dict(), expected)
        assert not difference

    @staticmethod
    def test_to_dict_nested_batch_with_array() -> None:
        """An array inside a nested batch is carried over when recursing."""
        child = Batch(c=np.array([4, 5]))
        parent = Batch(a=1, b=child)
        expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
        difference = DeepDiff(parent.to_dict(recursive=True), expected)
        assert not difference

    @staticmethod
    def test_to_dict_torch_tensor() -> None:
        """A value that started as a torch tensor (converted to numpy first) is carried over."""
        stored = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        subject = Batch(a=stored)
        reference = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        expected = {"a": reference}
        difference = DeepDiff(subject.to_dict(), expected)
        assert not difference

    @staticmethod
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
        """Same as above, with the converted tensor sitting in a nested batch."""
        child = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
        parent = Batch(a=1, b=child)
        expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
        difference = DeepDiff(parent.to_dict(recursive=True), expected)
        assert not difference
class TestToNumpy:
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""

    @staticmethod
    def test_to_numpy() -> None:
        """`to_numpy` returns a new batch with numpy leaves and leaves the source untouched."""
        source = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        converted: Batch = Batch.to_numpy(source)
        assert converted is not source
        # the source still holds tensors ...
        assert isinstance(source.b, torch.Tensor)
        assert isinstance(source.c.d, torch.Tensor)
        # ... while the result holds arrays
        assert isinstance(converted.b, np.ndarray)
        assert isinstance(converted.c.d, np.ndarray)

    @staticmethod
    def test_to_numpy_() -> None:
        """`to_numpy_` converts the leaves of the batch it is called on."""
        subject = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        identity_before = id(subject)
        subject.to_numpy_()
        assert id(subject) == identity_before
        assert isinstance(subject.b, np.ndarray)
        assert isinstance(subject.c.d, np.ndarray)
class TestToTorch:
    """Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` ."""

    @staticmethod
    def test_to_torch() -> None:
        """`to_torch` returns a new batch with tensor leaves and leaves the source untouched."""
        source = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
        converted: Batch = Batch.to_torch(source)
        assert converted is not source
        # the source still holds arrays ...
        assert isinstance(source.b, np.ndarray)
        assert isinstance(source.c.d, np.ndarray)
        # ... while the result holds tensors
        assert isinstance(converted.b, torch.Tensor)
        assert isinstance(converted.c.d, torch.Tensor)

    @staticmethod
    def test_to_torch_() -> None:
        """`to_torch_` converts the leaves of the batch it is called on."""
        subject = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
        identity_before = id(subject)
        subject.to_torch_()
        assert id(subject) == identity_before
        assert isinstance(subject.b, torch.Tensor)
        assert isinstance(subject.c.d, torch.Tensor)