34 lines
769 B
Python
Raw Normal View History

2020-03-16 11:11:29 +08:00
import pytest
2020-03-13 17:49:22 +08:00
import numpy as np
from tianshou.data import Batch
def test_batch():
    """Exercise basic Batch behaviour: assignment, append, indexing, split."""
    data = Batch(obs=[0], np=np.zeros([3, 4]))
    data.obs = [1]
    assert data.obs == [1]

    # Appending the batch to itself: obs goes from [1] to [1, 1] and the
    # array field goes from 3 to 6 rows, as the asserts below check.
    data.append(data)
    assert data.obs == [1, 1]
    assert data.np.shape == (6, 4)

    assert data[0].obs == data[1].obs
    # obs now holds two entries, so index 2 is expected to raise.
    with pytest.raises(IndexError):
        data[2]

    data.obs = np.arange(5)
    # Each single-element piece from an unshuffled split should line up with
    # the corresponding indexed element of the full batch.
    pieces = data.split(1, shuffle=False)
    for idx, piece in enumerate(pieces):
        assert piece.obs == data[idx].obs
    # Printing exercises the Batch string representation.
    print(data)
2020-03-13 17:49:22 +08:00
2020-05-27 11:02:23 +08:00
def test_batch_over_batch():
    """Check that a Batch can hold another Batch as a field and be indexed through it."""
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch(b=inner, c=[6, 7, 8])
    # Mutate the innermost list via attribute access on the outer batch,
    # then read the value back through outer-level indexing.
    outer.b.b[-1] = 0
    print(outer)
    assert outer[-1].b.b == 0
2020-03-13 17:49:22 +08:00
if __name__ == '__main__':
    # Run the tests directly, in file order, when executed as a script.
    for run in (test_batch, test_batch_over_batch):
        run()