2020-05-29 14:45:21 +02:00
|
|
|
import torch
|
2020-06-01 08:30:09 +08:00
|
|
|
import pickle
|
|
|
|
import pytest
|
2020-03-13 17:49:22 +08:00
|
|
|
import numpy as np
|
|
|
|
|
2020-05-30 15:40:31 +02:00
|
|
|
from tianshou.data import Batch, to_torch
|
2020-03-13 17:49:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_batch():
    """Exercise attribute/key access, append, indexing and split on a Batch.

    Appending the batch to itself must give the list field ``obs`` two
    entries and grow the 3x4 ``np`` array to six rows. ``obs`` is then
    replaced by a 5-element array. Any split piece at index 5 (past the end
    of ``obs``) is expected to expose no ``obs`` attribute, and neither is
    the batch indexed at 5.
    """
    data = Batch(obs=[0], np=np.zeros([3, 4]))
    assert data.obs == data["obs"]

    data.obs = [1]
    assert data.obs == [1]

    data.append(data)
    assert data.obs == [1, 1]
    assert data.np.shape == (6, 4)
    assert data[0].obs == data[1].obs

    data.obs = np.arange(5)
    for idx, piece in enumerate(data.split(1, shuffle=False)):
        if idx == 5:
            # obs only has five entries, so position 5 must not carry it.
            with pytest.raises(AttributeError):
                data[idx].obs
            with pytest.raises(AttributeError):
                piece.obs
            continue
        assert piece.obs == data[idx].obs

    # Smoke-test the string representation.
    print(data)
|
2020-03-13 17:49:22 +08:00
|
|
|
|
|
|
|
|
2020-05-27 11:02:23 +08:00
|
|
|
def test_batch_over_batch():
    """Check that a Batch nested inside another Batch stays addressable.

    Writing through ``outer.b.b`` must be visible when the outer batch is
    indexed. ``values()`` is expected to yield the ``c`` field last.
    """
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch(c=[6, 7, 8], b=inner)

    outer.b.b[-1] = 0
    print(outer)

    last_value = outer.values()[-1]
    assert last_value == outer.c
    assert outer[-1].b.b == 0
|
|
|
|
|
|
|
|
|
2020-05-30 15:40:31 +02:00
|
|
|
def test_batch_over_batch_to_torch():
    """Check ``Batch.to_torch`` on a nested batch, with and without ``dtype``.

    The batch mixes float64 numpy arrays with a float64 torch tensor. A plain
    ``to_torch()`` must leave every leaf as a float64 ``torch.Tensor``.
    ``to_torch(dtype=torch.float32)`` must then cast every leaf to float32.
    """
    batch = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1,), dtype=torch.float64),
        ),
    )

    def leaves():
        # Re-read the fields each time: to_torch replaces them.
        return batch.a, batch.b.c, batch.b.d

    batch.to_torch()
    for leaf in leaves():
        assert isinstance(leaf, torch.Tensor)
    for leaf in leaves():
        assert leaf.dtype == torch.float64

    batch.to_torch(dtype=torch.float32)
    for leaf in leaves():
        assert leaf.dtype == torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
def test_utils_to_torch():
    """Check the standalone ``to_torch`` helper on an array and on a Batch.

    For a single numpy array the requested ``dtype`` must be honoured, for
    both float32 and float64. For a whole Batch, every leaf of the result
    must come out as float32, including a leaf that was already a float64
    tensor.
    """
    source = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1,), dtype=torch.float64),
        ),
    )

    for wanted in (torch.float32, torch.float64):
        converted_leaf = to_torch(source.a, dtype=wanted)
        assert converted_leaf.dtype == wanted

    converted = to_torch(source, dtype=torch.float32)
    for leaf in (converted.a, converted.b.c, converted.b.d):
        assert leaf.dtype == torch.float32
|
|
|
|
|
|
|
|
|
2020-05-30 15:29:33 +02:00
|
|
|
def test_batch_pickle():
    """Check that a nested Batch survives a pickle round trip.

    The batch holds a scalar and a torch tensor in a nested batch, plus a
    numpy array at the top level. Each must compare equal after restoring.
    """
    original = Batch(
        obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
        np=np.zeros([3, 4]),
    )
    serialized = pickle.dumps(original)
    restored = pickle.loads(serialized)

    assert original.obs.a == restored.obs.a
    assert torch.all(original.obs.c == restored.obs.c)
    assert np.all(original.np == restored.np)
|
|
|
|
|
|
|
|
|
2020-05-29 14:45:21 +02:00
|
|
|
def test_batch_from_to_numpy_without_copy():
    """Check that a numpy -> torch -> numpy round trip keeps the same buffers.

    The data pointer of each numpy leaf is recorded before ``to_torch()`` and
    compared after ``to_numpy()``. Unchanged addresses show the conversion
    did not copy the underlying memory.
    """
    def buffer_addresses(b):
        # __array_interface__['data'] is (pointer, read_only); keep the pointer.
        return (
            b.a.__array_interface__['data'][0],
            b.b.c.__array_interface__['data'][0],
        )

    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_before, c_before = buffer_addresses(batch)

    batch.to_torch()
    batch.to_numpy()

    a_after, c_after = buffer_addresses(batch)
    assert a_after == a_before
    assert c_after == c_before
|
|
|
|
|
|
|
|
|
2020-03-13 17:49:22 +08:00
|
|
|
if __name__ == '__main__':
    # Run every test in this module, in declaration order, without pytest.
    for run in (
        test_batch,
        test_batch_over_batch,
        test_batch_over_batch_to_torch,
        test_utils_to_torch,
        test_batch_pickle,
        test_batch_from_to_numpy_without_copy,
    ):
        run()
|