Closes: https://github.com/thu-ml/tianshou/issues/1086

### API Extensions
- Batch received a new method: `to_numpy_`. #1098
- `to_dict` in Batch now also supports non-recursive conversion. #1098
- Batch `__eq__` is now implemented; semantic equality checks of batches are now possible. #1098

### Breaking Changes
- The method `to_numpy` in `data.utils.batch.Batch` is no longer in-place. Instead, a new method `to_numpy_` performs the conversion in-place. #1098
741 lines
28 KiB
Python
741 lines
28 KiB
Python
import copy
|
|
import pickle
|
|
import sys
|
|
from itertools import starmap
|
|
from typing import Any, cast
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from deepdiff import DeepDiff
|
|
|
|
from tianshou.data import Batch, to_numpy, to_torch
|
|
|
|
|
|
def test_batch() -> None:
    """Exercise basic `Batch` behavior.

    Covers construction, emptiness checks, `update`/`pop`, `cat_`, `split`,
    indexing and slicing, arithmetic, item assignment, and automatic dtype
    conversion. The statements are order-dependent: several variables are
    rebound or mutated in place and later assertions rely on that state.
    """
    # --- emptiness, with and without recursion into nested batches ---
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={"c": {}}).is_empty()
    assert Batch(b={"c": {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    # --- non-numeric entries are expected to be stored with dtype object ---
    b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object
    assert b.f == Batch
    # --- update / pop ---
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({"a": 2}, a=3)
    assert "a" in b
    assert b.a == 3
    assert b.pop("a") == 3
    assert "a" not in b
    # a non-string key is expected to be rejected
    with pytest.raises(AssertionError):
        Batch({1: 2})
    # --- lists of arrays/tensors: which shape and type mixes are accepted ---
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    # constructing from an empty list must not raise
    Batch(a=[])
    # --- attribute vs. item access, in-place self-concatenation ---
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    # NOTE(review): the `else` branch runs only if the split yields a sixth
    # element (i == 5); whether that happens depends on `len(batch)` — confirm.
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs  # noqa: B018
            with pytest.raises(AttributeError):
                b.obs  # noqa: B018
    print(batch)
    # --- split: a size of 0 is rejected; table of (size, merge_last, expected chunks) ---
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    # --- indexing a batch built from a list of dicts keeps the leaf types ---
    batch_dict = {"b": np.array([1.0]), "c": 2.0, "d": torch.Tensor([3.0])}
    batch_item = Batch({"a": [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict["b"]
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict["c"]
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict["d"]
    # --- shape, len, membership and out-of-range indexing ---
    batch2 = Batch(a=[{"b": np.float64(1.0), "c": np.zeros(1), "d": Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a={1, 2}).shape == []
    assert batch2.shape[0] == 1
    assert "a" in batch2
    assert all(i in batch2.a for i in "bcd")
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # --- rebuilding a batch from its iterated items ---
    # Both variables below are built by the same expression.
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch(list(batch2))
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    # three equivalent ways of slicing out the single element
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    # --- arithmetic is applied to the leaves; the empty sub-batch `f` stays empty ---
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2  # type: ignore  # __add__ supports Number as input type
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]  # type: ignore  # error is raised explicitly
    # --- item assignment with a dict or Batch; unknown keys are expected to raise ---
    batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {"e": 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(["a", "b"]))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(["c", "d"]))
    assert list(batch4.a) == ["c", "d"]
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{"index": 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    # Assigning an attribute on a Batch changes the type of the value being stored: the array of
    # dicts assigned below is asserted to be a Batch afterwards. mypy cannot know this and would
    # rightly expect batch5.b to stay an array, hence the `cast` right after the assignment.
    batch5.b = np.array([{"index": 1}])
    batch5.b = cast(Batch, batch5.b)
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None
    assert a.a[1] is None
    assert a.b[0] is None
    assert a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
|
|
|
|
|
|
def test_batch_over_batch() -> None:
    """Exercise batches nested inside batches.

    Covers nested construction, in-place concatenation, `update` from a nested
    batch, and multi-dimensional ("advanced") slicing. Statements are
    order-dependent because `batch2` and `batch5` are mutated in place.
    """
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({"c": [6, 7, 8], "b": batch})
    # write through the nested batch; the assertions below expect the 0 to be visible
    batch2.b.b[-1] = 0
    print(batch2)
    for k, v in batch2.items():
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    # concatenating with a second batch wrapping `batch` doubles every leaf
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    # update from the nested batch: its keys `a` and `b` land on the top level
    # (so `batch2.b` is an array afterwards), plus the extra keyword `six`
    batch2.update(batch2.b, six=[6, 6, 6])
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
    assert np.allclose(batch2.six, [6, 6, 6])
    # same concatenation check, with the nested value given as a plain dict
    d = {"a": [3, 4, 5], "b": [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    # a one-element tuple of dicts: the leaf is expected to gain a leading axis
    batch4 = Batch(({"a": {"b": np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0
    # advanced slicing
    batch5 = Batch(a=[[1, 2]], b={"c": np.zeros([3, 2, 1])})
    assert batch5.shape == [1, 2]
    with pytest.raises(IndexError):
        batch5[2]
    with pytest.raises(IndexError):
        batch5[:, 3]
    with pytest.raises(IndexError):
        batch5[:, :, -1]
    # in-place addition on the last column of every leaf
    batch5[:, -1] += np.int_(1)
    assert np.allclose(batch5.a, [1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
    # assigning a bare scalar to a slice is expected to raise
    with pytest.raises(ValueError):
        batch5[:, -1] = 1
    # assigning a dict with only key `a` changes `a`; `b.c` is asserted unchanged
    batch5[:, 0] = {"a": -1}
    assert np.allclose(batch5.a, [-1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
|
|
|
|
|
|
def test_batch_cat_and_stack() -> None:
    """Exercise `Batch.cat`, `Batch.cat_`, `Batch.stack` and `Batch.stack_`.

    Covers compatible keys, partially missing keys (expected to be
    zero-filled), reserved keys (values that are empty `Batch()` objects),
    stacking along different axes, and invalid inputs.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # in-place and out-of-place concatenation must agree on every leaf
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # an empty nested batch in the middle must not contribute any rows
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a,
    )

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # in-place stack adds a new leading axis to the leaves
    b1.stack_([b2])
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=Batch(),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    # check the concatenation result itself, not the hand-built expectation
    assert test.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d, strict=True))))
    b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False)
    assert b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e["b"]["c"] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]["b"]["d"]
    # key "d" is missing from the second dict and is expected to be zero-filled
    assert b5.b.d[1] == 0.0

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(
        a=np.stack([b1.a, np.zeros((4, 4))]),
        b=torch.stack([torch.zeros(4, 6), b2.b]),
        common=Batch(c=np.stack([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])  # type: ignore  # cat() tested with invalid inp
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])  # type: ignore  # stack() tested with invalid inp

    # exceptions
    batch_cat: Batch = Batch.cat([])
    assert batch_cat.is_empty()
    batch_stack: Batch = Batch.stack([])
    assert batch_stack.is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
|
|
|
|
|
|
def test_batch_over_batch_to_torch() -> None:
    """Check that `Batch.to_torch` converts every leaf of a nested batch in place.

    Without an explicit dtype the leaves are expected to keep their precision;
    with `dtype=torch.float32` every leaf is expected to be cast to it.
    """
    nested = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    nested.b.__dict__["e"] = 1  # bypass the check
    nested.to_torch()
    for leaf in (nested.a, nested.b.c, nested.b.d, nested.b.e):
        assert isinstance(leaf, torch.Tensor)
    assert nested.a.dtype == torch.float64
    assert nested.b.c.dtype == torch.float32
    assert nested.b.d.dtype == torch.float64
    # the plain Python int is expected to become int32 on Windows and int64 elsewhere
    on_windows = sys.platform in ["win32", "cygwin"]
    expected_int_dtype = torch.int32 if on_windows else torch.int64
    assert nested.b.e.dtype == expected_int_dtype
    nested.to_torch(dtype=torch.float32)
    for leaf in (nested.a, nested.b.c, nested.b.d, nested.b.e):
        assert leaf.dtype == torch.float32
|
|
|
|
|
|
def test_utils_to_torch_numpy() -> None:
    """Exercise the free functions `to_torch` and `to_numpy`.

    Covers scalars, batches, lists, empty arrays, dicts and inputs that are
    expected to be rejected with `TypeError`.
    """
    source = Batch(
        a=np.float64(1.0),
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    # a scalar is converted to whichever dtype is requested
    for wanted_dtype in (torch.float32, torch.float64):
        assert to_torch(source.a, dtype=wanted_dtype).dtype == wanted_dtype
    # a whole batch is converted leaf by leaf
    source_as_float = to_torch(source, dtype=torch.float32)
    for leaf in (source_as_float.a, source_as_float.b.c, source_as_float.b.d):
        assert leaf.dtype == torch.float32
    # a list mixing nan and an int is expected to become float64
    assert to_torch([float("nan"), 1]).dtype == torch.float64
    # equally shaped arrays are stacked, and the round trip preserves the values
    stackable = [np.random.rand(3, 3), np.random.rand(3, 3)]
    assert to_torch(stackable).shape == (2, 3, 3)
    assert np.allclose(to_numpy(to_torch(stackable)), stackable)
    # differently shaped entries are rejected in both conversion directions
    mismatched_shape_pairs = [((3, 2), (3, 3)), ((2, 3), (3, 3)), (2, (3, 3))]
    for first_shape, second_shape in mismatched_shape_pairs:
        numpy_entries = [np.zeros(first_shape), np.zeros(second_shape)]
        torch_entries = [torch.zeros(first_shape), torch.zeros(second_shape)]
        with pytest.raises(TypeError):
            to_torch(numpy_entries)
        with pytest.raises(TypeError):
            to_numpy(torch_entries)
    # an empty selection keeps its trailing dimensions through both conversions
    values = np.random.rand(3, 2, 2)
    empty_tensor = to_torch(values[[]])
    assert isinstance(empty_tensor, torch.Tensor)
    assert empty_tensor.shape == (0, 2, 2)
    empty_array = to_numpy(empty_tensor)
    assert isinstance(empty_array, np.ndarray)
    assert empty_array.shape == (0, 2, 2)
    assert np.allclose(to_numpy(to_torch(values)), values)
    # additional test for to_numpy, for code-coverage
    for number in (1, 1.0):
        assert isinstance(to_numpy(number), np.ndarray)
    assert isinstance(to_numpy({"a": torch.tensor(1)})["a"], np.ndarray)
    assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray)
    assert to_numpy(None).item() is None
    assert to_numpy(to_numpy).item() == to_numpy
    # additional test for to_torch, for code-coverage
    assert isinstance(to_torch(1), torch.Tensor)
    on_windows = sys.platform in ["win32", "cygwin"]
    assert to_torch(1).dtype == (torch.int32 if on_windows else torch.int64)
    assert to_torch(1.0).dtype == torch.float64
    assert isinstance(to_torch({"a": [1]})["a"], torch.Tensor)
    with pytest.raises(TypeError):
        to_torch(None)
    with pytest.raises(TypeError):
        to_torch(np.array([{}, "2"]))
|
|
|
|
|
|
def test_batch_pickle() -> None:
    """Check that a nested batch survives a pickle round trip with equal leaves."""
    original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
    # the payload is produced right here, so unpickling it is safe
    payload = pickle.dumps(original)
    restored = pickle.loads(payload)
    assert original.obs.a == restored.obs.a
    assert torch.all(original.obs.c == restored.obs.c)
    assert np.all(original.np == restored.np)
|
|
|
|
|
|
def test_batch_from_to_numpy_without_copy() -> None:
    """Check that a torch round trip leaves the array data addresses unchanged."""

    def data_address(array: np.ndarray) -> int:
        # first entry of the array interface's "data" tuple
        return array.__array_interface__["data"][0]

    nested = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    address_a_before = data_address(nested.a)
    address_c_before = data_address(nested.b.c)
    # convert to torch and back, both in place
    nested.to_torch()
    nested.to_numpy_()
    address_a_after = data_address(nested.a)
    address_c_after = data_address(nested.b.c)
    assert address_a_after == address_a_before
    assert address_c_after == address_c_before
|
|
|
|
|
|
def test_batch_copy() -> None:
    """Check the `copy` flag of the `Batch` constructor.

    With `copy=False` the new batch is expected to share objects and memory
    with the source; with `copy=True` neither may be shared.
    """

    def data_address(array: np.ndarray) -> int:
        # first entry of the array interface's "data" tuple
        return array.__array_interface__["data"][0]

    def leaf_addresses(holder: Batch) -> tuple[int, int, int]:
        return (
            data_address(holder.c),
            data_address(holder.b.a),
            data_address(holder.b.b),
        )

    inner = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    source = Batch({"c": np.array([6, 7, 8]), "b": inner})
    source_addresses = leaf_addresses(source)

    # test with copy=False
    shared = Batch(copy=False, **source)
    shared_addresses = leaf_addresses(shared)
    assert source.c is shared.c
    assert source.b is shared.b
    assert source.b.a is shared.b.a
    assert source.b.b is shared.b.b
    for before, after in zip(source_addresses, shared_addresses, strict=True):
        assert before == after

    # test with copy=True
    copied = Batch(copy=True, **source)
    copied_addresses = leaf_addresses(copied)
    assert source.c is not copied.c
    assert source.b is not copied.b
    assert source.b.a is not copied.b.a
    assert source.b.b is not copied.b.b
    for before, after in zip(source_addresses, copied_addresses, strict=True):
        assert before != after
|
|
|
|
|
|
def test_batch_empty() -> None:
    """Exercise `Batch.empty` and `Batch.empty_`.

    The assertions pin the placeholder value per leaf type: 0 / False for
    numeric and boolean arrays and for tensors, None for object arrays.
    Statements are order-dependent because `b5` and `data` are modified in place.
    """
    b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}])
    b5 = Batch(b5_dict)
    # overwrite the second entry with an emptied copy of the first entry
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    # leaves of different kinds: bool list, object array, list with None,
    # float list with nan, int array and torch tensor
    data = Batch(
        a=[False, True],
        b={
            "c": np.array([2.0, "st"], dtype=object),
            "d": [1, None],
            "e": [2.0, float("nan")],
        },
        c=np.array([1, 3, 4], dtype=int),
        t=torch.tensor([4, 5, 6, 7.0]),
    )
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.0]))
    # emptying the indexed item: only the tensor `t` is asserted to change in `data`
    data[0].empty_()  # which will fail in a, b.c, b.d, b.e, c
    assert torch.allclose(data.t, torch.tensor([0.0, 5, 6, 0]))
    # emptying via the `index` argument is asserted to reach the remaining leaves
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    # emptying an already empty batch must not raise
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
|
|
|
|
|
|
def test_batch_standard_compatibility() -> None:
    """Check that `np.mean` works on a `Batch` and that indexing an empty batch raises."""
    source = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
    averaged = np.mean(source)
    assert isinstance(averaged, Batch)  # type: ignore  # mypy doesn't know but it works, cf. `batch.rst`
    assert sorted(averaged.get_keys()) == ["a", "b", "c"]  # type: ignore
    # the averaged batch is expected to have no length
    with pytest.raises(TypeError):
        len(averaged)
    expected_a = np.mean(source.a, axis=0)
    expected_c = np.mean(source.c, axis=0)
    assert np.all(averaged.a == expected_a)
    assert averaged.c == expected_c
    with pytest.raises(IndexError):
        Batch()[0]
|
|
|
|
|
|
class TestBatchEquality:
    """Equality and inequality comparisons between `Batch` instances."""

    @staticmethod
    def test_keys_different() -> None:
        """Batches with different key sets are asserted unequal."""
        left = Batch(a=[1, 2], b=[100, 50])
        right = Batch(b=[1, 2], c=[100, 50])
        assert left != right

    @staticmethod
    def test_keys_missing() -> None:
        """Removing a key from one of two identically built batches makes them unequal."""
        complete = Batch(a=[1, 2], b=[2, 3, 4])
        reduced = Batch(a=[1, 2], b=[2, 3, 4])
        reduced.pop("b")
        assert complete != reduced

    @staticmethod
    def test_types_keys_different() -> None:
        """An array under a key is asserted unequal to a sub-batch under the same key."""
        with_array = Batch(a=[1, 2, 3], b=[4, 5])
        with_subbatch = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
        assert with_array != with_subbatch

    @staticmethod
    def test_array_types_different() -> None:
        """A numpy array is asserted unequal to a torch tensor holding the same numbers."""
        with_numpy = Batch(a=[1, 2, 3], b=np.array([4, 5]))
        with_torch = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
        assert with_numpy != with_torch

    @staticmethod
    def test_nested_values_different() -> None:
        """A differing value inside a nested batch makes the outer batches unequal."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_nested_shapes_different() -> None:
        """A differing shape inside a nested batch makes the outer batches unequal."""
        longer = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        shorter = Batch(a=Batch(a=[1, 4]), b=[4, 5])
        assert longer != shorter

    @staticmethod
    def test_slice_equal() -> None:
        """Two identical slices of the same batch are asserted equal."""
        source = Batch(a=[1, 2, 3])
        assert source[:2] == source[:2]

    @staticmethod
    def test_slice_ellipsis_equal() -> None:
        """Two identical ellipsis slices of the same batch are asserted equal."""
        source = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
        assert source[..., 1:] == source[..., 1:]

    @staticmethod
    def test_empty_batches() -> None:
        """Two empty batches are asserted equal."""
        assert Batch() == Batch()

    @staticmethod
    def test_different_order_keys() -> None:
        """Keyword order at construction must not affect equality."""
        assert Batch(a=1, b=2) == Batch(b=2, a=1)

    @staticmethod
    def test_tuple_and_list_types() -> None:
        """A tuple and a list holding the same values yield equal batches."""
        assert Batch(a=(1, 2)) == Batch(a=[1, 2])

    @staticmethod
    def test_subbatch_dict_and_batch_types() -> None:
        """A dict value and an equivalent `Batch` value yield equal batches."""
        assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))
|
|
|
|
|
|
class TestBatchToDict:
    """Tests for `Batch.to_dict`, in recursive and non-recursive mode."""

    @staticmethod
    def test_to_dict_empty_batch_no_recurse() -> None:
        """An empty batch converts to an empty dict in non-recursive mode."""
        batch = Batch()
        expected: dict[Any, Any] = {}
        # pass the flag explicitly so the non-recursive path named by this test is exercised
        assert batch.to_dict(recursive=False) == expected

    @staticmethod
    def test_to_dict_with_simple_values_recurse() -> None:
        """Scalar, string and array leaves are returned as stored, in recursive mode."""
        batch = Batch(a=1, b="two", c=np.array([3, 4]))
        expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_simple() -> None:
        """A flat batch converts with the default arguments."""
        batch = Batch(a=1, b="two")
        expected = {"a": np.asanyarray(1), "b": "two"}
        assert batch.to_dict() == expected

    @staticmethod
    def test_to_dict_nested_batch_no_recurse() -> None:
        """In non-recursive mode a nested batch is kept as a `Batch` value."""
        nested_batch = Batch(c=3)
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": nested_batch}
        assert not DeepDiff(batch.to_dict(recursive=False), expected)

    @staticmethod
    def test_to_dict_nested_batch_recurse() -> None:
        """In recursive mode a nested batch becomes a nested dict."""
        nested_batch = Batch(c=3)
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_multiple_nested_batch_recurse() -> None:
        """Recursive conversion descends through several nesting levels."""
        nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300])
        batch = Batch(a=1, b=nested_batch)
        expected = {
            "a": np.asanyarray(1),
            "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
        }
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_array() -> None:
        """An array leaf is returned as an equal array."""
        batch = Batch(a=np.array([1, 2, 3]))
        expected = {"a": np.array([1, 2, 3])}
        assert not DeepDiff(batch.to_dict(), expected)

    @staticmethod
    def test_to_dict_nested_batch_with_array() -> None:
        """An array inside a nested batch is returned inside the nested dict."""
        nested_batch = Batch(c=np.array([4, 5]))
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)

    @staticmethod
    def test_to_dict_torch_tensor() -> None:
        """Data that originates from a torch tensor round-trips through `to_dict`.

        The tensor is converted to a numpy array before it is put into the batch.
        """
        t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        batch = Batch(a=t1)
        t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        expected = {"a": t2}
        assert not DeepDiff(batch.to_dict(), expected)

    @staticmethod
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
        """Same as the torch tensor test, with the data inside a nested batch."""
        nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
        batch = Batch(a=1, b=nested_batch)
        expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}}
        assert not DeepDiff(batch.to_dict(recursive=True), expected)
|
|
|
|
|
|
class TestToNumpy:
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""

    @staticmethod
    def test_to_numpy() -> None:
        """`to_numpy` returns a new batch and leaves the tensors of the source untouched."""
        source = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        converted: Batch = Batch.to_numpy(source)
        assert converted is not source
        # the source still holds torch tensors
        for untouched in (source.b, source.c.d):
            assert isinstance(untouched, torch.Tensor)

        # the returned batch holds numpy arrays
        for array in (converted.b, converted.c.d):
            assert isinstance(array, np.ndarray)

    @staticmethod
    def test_to_numpy_() -> None:
        """`to_numpy_` converts the leaves of the same batch object."""
        target = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        identity_before = id(target)
        target.to_numpy_()
        assert identity_before == id(target)
        for array in (target.b, target.c.d):
            assert isinstance(array, np.ndarray)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the module-level test functions directly (without pytest), one after another.
    for run_test in (
        test_batch,
        test_batch_over_batch,
        test_batch_over_batch_to_torch,
        test_utils_to_torch_numpy,
        test_batch_pickle,
        test_batch_from_to_numpy_without_copy,
        test_batch_standard_compatibility,
        test_batch_cat_and_stack,
        test_batch_copy,
        test_batch_empty,
    ):
        run_test()
|