fix append batch over batch

This commit is contained in:
Trinkle23897 2020-06-20 22:03:22 +08:00
parent 268f9d0533
commit aff0f9aee0
4 changed files with 18 additions and 0 deletions

BIN
docs/_static/images/concepts_arch2.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -8,6 +8,12 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
:height: 300
Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network:
.. image:: /_static/images/concepts_arch2.png
:align: center
:height: 300
Data Batch
----------

View File

@ -34,6 +34,16 @@ def test_batch_over_batch():
print(batch2)
assert batch2.values()[-1] == batch2.c
assert batch2[-1].b.b == 0
batch2.append(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.append(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
def test_batch_over_batch_to_torch():

View File

@ -234,6 +234,8 @@ class Batch:
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, list):
self.__dict__[k] += v
elif isinstance(v, Batch):
self.__dict__[k].append(v)
else:
s = f'No support for append with type \
{type(v)} in class Batch.'