* Enable converting Batch data back to torch. * Add torch converter to collector. * Fix. * Move to_numpy/to_torch conversion into a dedicated utils.py. * Use to_numpy/to_torch to convert arrays. * Fix lint. * Fix. * Add unit test to check Batch conversion from/to numpy. * Fix Batch over Batch. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
import pytest
|
|
import torch
|
|
import numpy as np
|
|
|
|
from tianshou.data import Batch
|
|
|
|
|
|
def test_batch():
    """Exercise basic Batch attribute assignment, appending, indexing and splitting."""
    data = Batch(obs=[0], np=np.zeros([3, 4]))

    # Reassigning a field replaces its stored value.
    data.obs = [1]
    assert data.obs == [1]

    # Appending the batch to itself concatenates each field.
    data.append(data)
    assert data.obs == [1, 1]
    assert data.np.shape == (6, 4)
    assert data[0].obs == data[1].obs

    # Index 2 is expected to raise; presumably because `obs` holds only two
    # entries even though `np` holds six rows -- confirm against Batch.__getitem__.
    with pytest.raises(IndexError):
        _ = data[2]

    # Each size-1 chunk from an unshuffled split should line up with the
    # element at the same position.
    data.obs = np.arange(5)
    chunks = data.split(1, shuffle=False)
    for idx, chunk in enumerate(chunks):
        assert chunk.obs == data[idx].obs

    print(data)
|
|
|
|
|
|
def test_batch_over_batch():
    """Check that a Batch can store another Batch as one of its fields."""
    inner = Batch(a=[3, 4, 5], b=[4, 5, 6])
    outer = Batch(c=[6, 7, 8], b=inner)

    # Write through the nested path; the change must be observable when the
    # outer batch is indexed below.
    outer.b.b[-1] = 0
    print(outer)

    # NOTE(review): this relies on `values()` returning an indexable sequence
    # whose last element is the `c` field (e.g. values ordered by sorted key);
    # confirm against the Batch.values() implementation.
    assert outer.values()[-1] == outer.c
    assert outer[-1].b.b == 0
|
|
|
|
|
|
def test_batch_from_to_numpy_without_copy():
    """Round-trip a nested Batch numpy -> torch -> numpy and check buffers are reused."""
    def data_ptr(arr):
        # Address of the first element of the array's underlying data buffer.
        return arr.__array_interface__['data'][0]

    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_before = data_ptr(batch["a"])
    c_before = data_ptr(batch["b"]["c"])

    # In-place conversion must reach both top-level and nested fields.
    batch.to_torch()
    assert isinstance(batch["a"], torch.Tensor)
    assert isinstance(batch["b"]["c"], torch.Tensor)

    batch.to_numpy()
    a_after = data_ptr(batch["a"])
    c_after = data_ptr(batch["b"]["c"])

    # Unchanged addresses mean neither conversion copied the data.
    assert a_after == a_before
    assert c_after == c_before
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running the tests without pytest. Keep this list in sync with the
    # test functions defined above; previously the numpy/torch round-trip
    # test was missing here and silently skipped.
    test_batch()
    test_batch_over_batch()
    test_batch_from_to_numpy_without_copy()
|