parent
ec23c7efe9
commit
0c7117dd55
@ -77,10 +77,10 @@ The following code snippet illustrates its usage, including:
|
||||
::
|
||||
|
||||
>>> import pickle, numpy as np
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> from tianshou.data import Batch, ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
|
||||
... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
|
||||
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
@ -96,7 +96,7 @@ The following code snippet illustrates its usage, including:
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
... done = i % 4 == 0
|
||||
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
|
||||
... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
|
||||
>>> len(buf2)
|
||||
10
|
||||
>>> buf2.obs
|
||||
@ -147,25 +147,26 @@ The following code snippet illustrates its usage, including:
|
||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
||||
>>> for i in range(16):
|
||||
... done = i % 5 == 0
|
||||
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
|
||||
... done=done, obs_next={'id': i + 1})
|
||||
... ptr, ep_rew, ep_len, ep_idx = buf.add(
|
||||
... Batch(obs={'id': i}, act=i, rew=i,
|
||||
... done=done, obs_next={'id': i + 1}))
|
||||
... print(i, ep_len, ep_rew)
|
||||
0 1 0.0
|
||||
1 0 0.0
|
||||
2 0 0.0
|
||||
3 0 0.0
|
||||
4 0 0.0
|
||||
5 5 15.0
|
||||
6 0 0.0
|
||||
7 0 0.0
|
||||
8 0 0.0
|
||||
9 0 0.0
|
||||
10 5 40.0
|
||||
11 0 0.0
|
||||
12 0 0.0
|
||||
13 0 0.0
|
||||
14 0 0.0
|
||||
15 5 65.0
|
||||
0 [1] [0.]
|
||||
1 [0] [0.]
|
||||
2 [0] [0.]
|
||||
3 [0] [0.]
|
||||
4 [0] [0.]
|
||||
5 [5] [15.]
|
||||
6 [0] [0.]
|
||||
7 [0] [0.]
|
||||
8 [0] [0.]
|
||||
9 [0] [0.]
|
||||
10 [5] [40.]
|
||||
11 [0] [0.]
|
||||
12 [0] [0.]
|
||||
13 [0] [0.]
|
||||
14 [0] [0.]
|
||||
15 [5] [65.]
|
||||
>>> print(buf) # you can see obs_next is not saved in buf
|
||||
ReplayBuffer(
|
||||
obs: Batch(
|
||||
@ -175,8 +176,6 @@ The following code snippet illustrates its usage, including:
|
||||
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||
done: array([False, True, False, False, False, False, True, False,
|
||||
False]),
|
||||
info: Batch(),
|
||||
policy: Batch(),
|
||||
)
|
||||
>>> index = np.arange(len(buf))
|
||||
>>> print(buf.get(index, 'obs').id)
|
||||
@ -194,16 +193,21 @@ The following code snippet illustrates its usage, including:
|
||||
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
||||
0
|
||||
>>> # we can get obs_next through __getitem__, even though it is not stored in the buffer
|
||||
>>> # however, [:] will select the item according to timestamp,
|
||||
>>> # which is equivalent to index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
|
||||
>>> print(buf[:].obs_next.id)
|
||||
[[ 7 8 9 10]
|
||||
[[ 7 7 7 8]
|
||||
[ 7 7 8 9]
|
||||
[ 7 8 9 10]
|
||||
[ 7 8 9 10]
|
||||
[11 11 11 12]
|
||||
[11 11 12 13]
|
||||
[11 12 13 14]
|
||||
[12 13 14 15]
|
||||
[12 13 14 15]
|
||||
[ 7 7 7 8]
|
||||
[ 7 7 8 9]]
|
||||
[12 13 14 15]]
|
||||
>>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
|
||||
>>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
|
||||
True
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user