From 0c7117dd557e6e4b0ad3f61bb42accaaa08b1660 Mon Sep 17 00:00:00 2001 From: n+e Date: Sat, 20 Mar 2021 21:46:36 +0800 Subject: [PATCH] fix concepts.rst with regard to new buffer behavior (#316) fix #315 --- docs/tutorials/concepts.rst | 58 ++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 888e7bf..65b5075 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -77,10 +77,10 @@ The following code snippet illustrates its usage, including: :: >>> import pickle, numpy as np - >>> from tianshou.data import ReplayBuffer + >>> from tianshou.data import Batch, ReplayBuffer >>> buf = ReplayBuffer(size=20) >>> for i in range(3): - ... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}) + ... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})) >>> buf.obs # since we set size = 20, len(buf.obs) == 20. @@ -96,7 +96,7 @@ The following code snippet illustrates its usage, including: >>> buf2 = ReplayBuffer(size=10) >>> for i in range(15): ... done = i % 4 == 0 - ... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}) + ... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})) >>> len(buf2) 10 >>> buf2.obs @@ -147,25 +147,26 @@ The following code snippet illustrates its usage, including: >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) >>> for i in range(16): ... done = i % 5 == 0 - ... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i, - ... done=done, obs_next={'id': i + 1}) + ... ptr, ep_rew, ep_len, ep_idx = buf.add( + ... Batch(obs={'id': i}, act=i, rew=i, + ... done=done, obs_next={'id': i + 1})) ... print(i, ep_len, ep_rew) - 0 1 0.0 - 1 0 0.0 - 2 0 0.0 - 3 0 0.0 - 4 0 0.0 - 5 5 15.0 - 6 0 0.0 - 7 0 0.0 - 8 0 0.0 - 9 0 0.0 - 10 5 40.0 - 11 0 0.0 - 12 0 0.0 - 13 0 0.0 - 14 0 0.0 - 15 5 65.0 + 0 [1] [0.] + 1 [0] [0.] + 2 [0] [0.] + 3 [0] [0.] + 4 [0] [0.] 
+ 5 [5] [15.] + 6 [0] [0.] + 7 [0] [0.] + 8 [0] [0.] + 9 [0] [0.] + 10 [5] [40.] + 11 [0] [0.] + 12 [0] [0.] + 13 [0] [0.] + 14 [0] [0.] + 15 [5] [65.] >>> print(buf) # you can see obs_next is not saved in buf ReplayBuffer( obs: Batch( @@ -175,8 +176,6 @@ The following code snippet illustrates its usage, including: rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), done: array([False, True, False, False, False, False, True, False, False]), - info: Batch(), - policy: Batch(), ) >>> index = np.arange(len(buf)) >>> print(buf.get(index, 'obs').id) @@ -194,16 +193,21 @@ The following code snippet illustrates its usage, including: >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() 0 >>> # we can get obs_next through __getitem__, even if it doesn't exist + >>> # however, [:] selects the items in timestamp order, + >>> # which is equivalent to index == [7, 8, 0, 1, 2, 3, 4, 5, 6] >>> print(buf[:].obs_next.id) - [[ 7 8 9 10] + [[ 7 7 7 8] + [ 7 7 8 9] + [ 7 8 9 10] [ 7 8 9 10] [11 11 11 12] [11 11 12 13] [11 12 13 14] [12 13 14 15] - [12 13 14 15] - [ 7 7 7 8] - [ 7 7 8 9]] + [12 13 14 15]] + >>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6]) + >>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id) + True .. raw:: html