parent
ec23c7efe9
commit
0c7117dd55
@ -77,10 +77,10 @@ The following code snippet illustrates its usage, including:
|
|||||||
::
|
::
|
||||||
|
|
||||||
>>> import pickle, numpy as np
|
>>> import pickle, numpy as np
|
||||||
>>> from tianshou.data import ReplayBuffer
|
>>> from tianshou.data import Batch, ReplayBuffer
|
||||||
>>> buf = ReplayBuffer(size=20)
|
>>> buf = ReplayBuffer(size=20)
|
||||||
>>> for i in range(3):
|
>>> for i in range(3):
|
||||||
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
|
... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
|
||||||
|
|
||||||
>>> buf.obs
|
>>> buf.obs
|
||||||
# since we set size = 20, len(buf.obs) == 20.
|
# since we set size = 20, len(buf.obs) == 20.
|
||||||
@ -96,7 +96,7 @@ The following code snippet illustrates its usage, including:
|
|||||||
>>> buf2 = ReplayBuffer(size=10)
|
>>> buf2 = ReplayBuffer(size=10)
|
||||||
>>> for i in range(15):
|
>>> for i in range(15):
|
||||||
... done = i % 4 == 0
|
... done = i % 4 == 0
|
||||||
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
|
... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
|
||||||
>>> len(buf2)
|
>>> len(buf2)
|
||||||
10
|
10
|
||||||
>>> buf2.obs
|
>>> buf2.obs
|
||||||
@ -147,25 +147,26 @@ The following code snippet illustrates its usage, including:
|
|||||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
||||||
>>> for i in range(16):
|
>>> for i in range(16):
|
||||||
... done = i % 5 == 0
|
... done = i % 5 == 0
|
||||||
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
|
... ptr, ep_rew, ep_len, ep_idx = buf.add(
|
||||||
... done=done, obs_next={'id': i + 1})
|
... Batch(obs={'id': i}, act=i, rew=i,
|
||||||
|
... done=done, obs_next={'id': i + 1}))
|
||||||
... print(i, ep_len, ep_rew)
|
... print(i, ep_len, ep_rew)
|
||||||
0 1 0.0
|
0 [1] [0.]
|
||||||
1 0 0.0
|
1 [0] [0.]
|
||||||
2 0 0.0
|
2 [0] [0.]
|
||||||
3 0 0.0
|
3 [0] [0.]
|
||||||
4 0 0.0
|
4 [0] [0.]
|
||||||
5 5 15.0
|
5 [5] [15.]
|
||||||
6 0 0.0
|
6 [0] [0.]
|
||||||
7 0 0.0
|
7 [0] [0.]
|
||||||
8 0 0.0
|
8 [0] [0.]
|
||||||
9 0 0.0
|
9 [0] [0.]
|
||||||
10 5 40.0
|
10 [5] [40.]
|
||||||
11 0 0.0
|
11 [0] [0.]
|
||||||
12 0 0.0
|
12 [0] [0.]
|
||||||
13 0 0.0
|
13 [0] [0.]
|
||||||
14 0 0.0
|
14 [0] [0.]
|
||||||
15 5 65.0
|
15 [5] [65.]
|
||||||
>>> print(buf) # you can see obs_next is not saved in buf
|
>>> print(buf) # you can see obs_next is not saved in buf
|
||||||
ReplayBuffer(
|
ReplayBuffer(
|
||||||
obs: Batch(
|
obs: Batch(
|
||||||
@ -175,8 +176,6 @@ The following code snippet illustrates its usage, including:
|
|||||||
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||||
done: array([False, True, False, False, False, False, True, False,
|
done: array([False, True, False, False, False, False, True, False,
|
||||||
False]),
|
False]),
|
||||||
info: Batch(),
|
|
||||||
policy: Batch(),
|
|
||||||
)
|
)
|
||||||
>>> index = np.arange(len(buf))
|
>>> index = np.arange(len(buf))
|
||||||
>>> print(buf.get(index, 'obs').id)
|
>>> print(buf.get(index, 'obs').id)
|
||||||
@ -194,16 +193,21 @@ The following code snippet illustrates its usage, including:
|
|||||||
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
||||||
0
|
0
|
||||||
>>> # we can get obs_next through __getitem__, even if it doesn't exist
|
>>> # we can get obs_next through __getitem__, even if it doesn't exist
|
||||||
|
>>> # however, [:] will select the items in timestamp order,
|
||||||
|
>>> # which is equivalent to index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
|
||||||
>>> print(buf[:].obs_next.id)
|
>>> print(buf[:].obs_next.id)
|
||||||
[[ 7 8 9 10]
|
[[ 7 7 7 8]
|
||||||
|
[ 7 7 8 9]
|
||||||
|
[ 7 8 9 10]
|
||||||
[ 7 8 9 10]
|
[ 7 8 9 10]
|
||||||
[11 11 11 12]
|
[11 11 11 12]
|
||||||
[11 11 12 13]
|
[11 11 12 13]
|
||||||
[11 12 13 14]
|
[11 12 13 14]
|
||||||
[12 13 14 15]
|
[12 13 14 15]
|
||||||
[12 13 14 15]
|
[12 13 14 15]]
|
||||||
[ 7 7 7 8]
|
>>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
|
||||||
[ 7 7 8 9]]
|
>>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
|
||||||
|
True
|
||||||
|
|
||||||
.. raw:: html
|
.. raw:: html
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user