2020-03-16 11:11:29 +08:00
|
|
|
import pytest
|
2020-03-13 17:49:22 +08:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from tianshou.data import Batch
|
|
|
|
|
|
|
|
|
|
|
|
def test_batch():
    """Check Batch update, self-append, integer indexing and ordered split."""
    data = Batch(obs=[0], np=np.zeros([3, 4]))

    # update() replaces the value stored under an existing key.
    data.update(obs=[1])
    assert data.obs == [1]

    # Appending the batch to itself extends both fields: the obs list
    # doubles and the (3, 4) array gains three more rows.
    data.append(data)
    assert data.obs == [1, 1]
    assert data.np.shape == (6, 4)

    # Positions 0 and 1 hold equal obs; index 2 is out of range for the
    # two-element obs list, so indexing must raise IndexError.
    assert data[0].obs == data[1].obs
    with pytest.raises(IndexError):
        data[2]

    # Each size-1 chunk from an unshuffled split must match the item at
    # the same position in the original batch.
    data.obs = np.arange(5)
    position = 0
    for chunk in data.split(1, permute=False):
        assert chunk.obs == data[position].obs
        position += 1
|
2020-03-13 17:49:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow executing this module directly, without going through pytest.
    test_batch()
|