Tianshou/test/test_batch.py

17 lines
311 B
Python
Raw Normal View History

2020-03-13 17:49:22 +08:00
import numpy as np
from tianshou.data import Batch
def test_batch():
    """Check ``Batch.update`` and appending a ``Batch`` to itself.

    The batch starts with a list field ``obs`` and a NumPy array field ``np``.
    ``update`` is expected to replace ``obs``. Appending the batch to itself is
    then expected to concatenate each field with its own contents.
    """
    zeros = np.zeros([3, 4])
    data = Batch(obs=[0], np=zeros)

    # After update(), the list field should hold exactly the new value.
    data.update(obs=[1])
    assert data.obs == [1]

    # Self-append: the list field is expected to be concatenated with itself,
    # and the (3, 4) array is expected to grow along its first axis to (6, 4).
    data.append(data)
    assert data.obs == [1, 1]
    assert data.np.shape == (6, 4)
if __name__ == "__main__":
    # Allow running this module directly as a script.
    test_batch()