# Tianshou/test/base/test_batch.py
# Tests for tianshou.data.Batch and tianshou.data.to_torch.

import torch
import copy
import pickle
import pytest
import numpy as np
from tianshou.data import Batch, to_torch
def test_batch():
    """Exercise the core Batch API: attribute/key access, ``cat_``,
    ``split``, integer/slice indexing, construction from lists of dicts,
    element-wise arithmetic and item assignment."""
    buf = Batch(obs=[0], np=np.zeros([3, 4]))
    assert buf.obs == buf["obs"]
    buf.obs = [1]
    assert buf.obs == [1]
    # Concatenating the batch with itself doubles both fields.
    buf.cat_(buf)
    assert buf.obs == [1, 1]
    assert buf.np.shape == (6, 4)
    assert buf[0].obs == buf[1].obs

    # From here on ``obs`` has 5 entries while ``np`` still has 6 rows.
    buf.obs = np.arange(5)
    for idx, piece in enumerate(buf.split(1, shuffle=False)):
        if idx == 5:
            # NOTE(review): only reached if split() yields a sixth piece,
            # which depends on how Batch measures its length -- confirm.
            with pytest.raises(AttributeError):
                buf[idx].obs
            with pytest.raises(AttributeError):
                piece.obs
            continue
        assert piece.obs == buf[idx].obs
    print(buf)

    # Indexing a batch built from a list of dicts gives back each leaf
    # with its original type and value.
    raw = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    item = Batch({'a': [raw]})[0]
    for key, expected_type in (('b', np.ndarray),
                               ('c', float),
                               ('d', torch.Tensor)):
        leaf = getattr(item.a, key)
        assert isinstance(leaf, expected_type)
        assert leaf == raw[key]

    nested = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(nested) == 1
    assert Batch().size == 0
    assert nested.size == 1
    for out_of_range in (-2, 1):
        with pytest.raises(IndexError):
            nested[out_of_range]
    # A single element reports size 1 but supports neither indexing
    # nor len().
    assert nested[0].size == 1
    with pytest.raises(TypeError):
        nested[0][0]
    with pytest.raises(TypeError):
        len(nested[0])
    assert isinstance(nested[0].a.c, np.ndarray)
    assert isinstance(nested[0].a.b, np.float64)
    assert isinstance(nested[0].a.d.e, np.float64)

    # Accessors for the three leaves of ``nested``, reused below.
    leaves = (lambda x: x.a.b, lambda x: x.a.c, lambda x: x.a.d.e)

    # Rebuilding from the batch's own elements (via ``list`` and via a
    # comprehension) must round-trip.
    from_list = Batch(list(nested))
    from_comp = Batch([e for e in nested])
    for rebuilt in (from_list, from_comp):
        for leaf_of in leaves:
            assert leaf_of(rebuilt) == leaf_of(nested)
    # Equivalent slices covering the whole (length-1) batch match it.
    for part in (nested[slice(0, 1)], nested[:1], nested[0:]):
        for leaf_of in leaves:
            assert leaf_of(part) == leaf_of(nested)
    # Arithmetic is applied to every leaf.
    shifted = (nested + 1.0) * 2
    for leaf_of in leaves:
        assert leaf_of(shifted) == (leaf_of(nested) + 1.0) * 2

    # Item assignment updates the nested batch in place; assigning a key
    # the target lacks raises KeyError.
    target = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    inner = target.a.d
    inner[0] = {'e': 4.0}
    assert inner.e[0] == 4.0
    inner[0] = Batch(f=5.0)
    assert inner.f[0] == 5.0
    with pytest.raises(KeyError):
        inner[0] = Batch(f=5.0, g=0.0)


def test_batch_over_batch():
    """Nest a Batch (or a plain dict) inside a Batch and check item access
    and in-place concatenation of the nested fields."""
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch({'c': [6, 7, 8], 'b': inner})
    outer.b.b[-1] = 0
    print(outer)
    for key, value in outer.items():
        assert outer[key] == value
    assert outer[-1].b.b == 0
    # The appended copy is also expected to end in 0, which implies that
    # ``inner`` shares its ``b`` data with ``outer.b`` rather than a copy.
    outer.cat_(Batch(c=[6, 7, 8], b=inner))
    assert outer.c == [6, 7, 8] * 2
    assert outer.b.a == [3, 4, 5] * 2
    assert outer.b.b == [4, 5, 0] * 2

    # The same concatenation works when the nested field is a plain dict.
    spec = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    from_dict = Batch(c=[6, 7, 8], b=spec)
    from_dict.cat_(Batch(c=[6, 7, 8], b=spec))
    assert from_dict.c == [6, 7, 8] * 2
    assert from_dict.b.a == [3, 4, 5] * 2
    assert from_dict.b.b == [4, 5, 6] * 2

    # A tuple holding one nested dict adds a batch dimension to the leaf.
    wrapped = Batch(({'a': {'b': np.array([1.0])}},))
    leaf = wrapped.a.b
    assert leaf.ndim == 2
    assert leaf[0, 0] == 1.0


def test_batch_cat_and_stack():
    """Check ``Batch.cat`` against in-place ``cat_``, and the shapes that
    ``Batch.stack`` produces, including stacking along ``axis=1``."""
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    # cat_ works in place, so concatenate into a deep copy and leave b1
    # untouched for the stack check below.
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # Out-of-place and in-place concatenation must agree on every leaf.
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # Concatenating the two length-1 batches yields a 1-D leaf ...
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    # ... whereas stacking them adds a dimension.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    # Stacking along axis=1 must match numpy's own np.stack ...
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    # ... and pair up nested list leaves row by row.
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))


def test_batch_over_batch_to_torch():
    """``Batch.to_torch`` turns every nested leaf into a tensor, leaves the
    float64 inputs as float64 when no dtype is given, and casts when an
    explicit ``dtype`` is passed."""
    data = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1,), dtype=torch.float64),
        ),
    )

    def leaves():
        # Re-read on every call so the values reflect the latest conversion.
        return data.a, data.b.c, data.b.d

    data.to_torch()
    assert all(isinstance(leaf, torch.Tensor) for leaf in leaves())
    assert all(leaf.dtype == torch.float64 for leaf in leaves())
    data.to_torch(dtype=torch.float32)
    assert all(leaf.dtype == torch.float32 for leaf in leaves())


def test_utils_to_torch():
    """The standalone ``to_torch`` helper casts both a bare array and a
    whole nested Batch to the requested dtype."""
    source = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1,), dtype=torch.float64),
        ),
    )
    for wanted in (torch.float32, torch.float64):
        assert to_torch(source.a, dtype=wanted).dtype == wanted
    converted = to_torch(source, dtype=torch.float32)
    for leaf in (converted.a, converted.b.c, converted.b.d):
        assert leaf.dtype == torch.float32


def test_batch_pickle():
    """A nested Batch with float, tensor and array leaves survives a pickle
    round trip with equal values."""
    original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                     np=np.zeros([3, 4]))
    payload = pickle.dumps(original)
    restored = pickle.loads(payload)
    assert original.obs.a == restored.obs.a
    assert torch.all(original.obs.c == restored.obs.c)
    assert np.all(original.np == restored.np)


def test_batch_from_to_numpy_without_copy():
    """A numpy -> torch -> numpy round trip leaves each leaf pointing at the
    same underlying data buffer."""
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))

    def data_ptrs():
        # Address of each leaf's first byte, per the numpy array interface.
        return (batch.a.__array_interface__['data'][0],
                batch.b.c.__array_interface__['data'][0])

    before = data_ptrs()
    batch.to_torch()
    batch.to_numpy()
    assert data_ptrs() == before


def test_batch_numpy_compatibility():
    """``np.mean`` on a Batch returns a Batch whose fields are each reduced
    along axis 0; the empty nested Batch key is kept."""
    sample = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
                   b=Batch(),
                   c=np.array([5.0, 6.0]))
    averaged = np.mean(sample)
    assert isinstance(averaged, Batch)
    assert sorted(averaged.keys()) == ['a', 'b', 'c']
    # ``c`` reduces to a scalar, and len() of the result is expected to
    # raise TypeError.
    with pytest.raises(TypeError):
        len(averaged)
    assert np.all(averaged.a == np.mean(sample.a, axis=0))
    assert averaged.c == np.mean(sample.c, axis=0)


if __name__ == '__main__':
    # Run every test defined in this module when executed directly as a
    # script, in definition order.
    test_batch()
    test_batch_over_batch()
    test_batch_cat_and_stack()
    test_batch_over_batch_to_torch()
    test_utils_to_torch()
    test_batch_pickle()
    test_batch_from_to_numpy_without_copy()
    test_batch_numpy_compatibility()