627 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			627 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import numpy as np
 | |
| import pytest
 | |
| import tqdm
 | |
| from torch.utils.tensorboard import SummaryWriter
 | |
| 
 | |
| from tianshou.data import (
 | |
|     AsyncCollector,
 | |
|     Batch,
 | |
|     CachedReplayBuffer,
 | |
|     Collector,
 | |
|     PrioritizedReplayBuffer,
 | |
|     ReplayBuffer,
 | |
|     VectorReplayBuffer,
 | |
| )
 | |
| from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 | |
| from tianshou.policy import BasePolicy
 | |
| 
 | |
| if __name__ == '__main__':
 | |
|     from env import MyTestEnv, NXEnv
 | |
| else:  # pytest
 | |
|     from test.base.env import MyTestEnv, NXEnv
 | |
| 
 | |
| 
 | |
class MyPolicy(BasePolicy):
    """Trivial policy used by the collector tests: it always emits action ``1``."""

    def __init__(self, dict_state=False, need_state=True):
        """
        :param bool dict_state: if the observation of the environment is a dict
        :param bool need_state: if the policy needs the hidden state (for RNN)
        """
        super().__init__()
        self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None):
        """Return a ``Batch`` holding an all-ones action array and the hidden state.

        When ``need_state`` is set, a missing state is created as zeros of shape
        ``(len(batch.obs), 2)``; an existing state is incremented in place.
        """
        if self.need_state:
            if state is None:
                state = np.zeros((len(batch.obs), 2))
            else:
                # In-place on purpose: the caller's array is the one updated.
                state += 1
        # With dict observations the batch length is read from the 'index' entry.
        sized_obs = batch.obs['index'] if self.dict_state else batch.obs
        return Batch(act=np.ones(len(sized_obs)), state=state)

    def learn(self):
        """No-op: this policy has nothing to train."""
        pass
 | |
| 
 | |
| 
 | |
class Logger:
    """Provides the ``preprocess_fn`` hooks handed to the collectors in these tests."""

    def __init__(self, writer):
        """Store ``writer`` and start the scalar step counter ``cnt`` at zero."""
        self.cnt = 0
        self.writer = writer

    def preprocess_fn(self, **kwargs):
        """Copy ``rew`` into ``info`` and log ``info.key`` (when present) to the writer.

        The original notes describe the two call shapes:
        obs and env_id present -> reset; obs_next/rew/done/info/env_id present ->
        normal step. Only calls carrying ``rew`` modify anything; every other
        call returns an empty ``Batch``.
        """
        if 'rew' not in kwargs:
            return Batch()
        step_info = kwargs['info']
        step_info.rew = kwargs['rew']
        if 'key' in step_info.keys():
            mean_key = np.mean(step_info.key)
            self.writer.add_scalar('key', mean_key, global_step=self.cnt)
        # The counter advances on every step call, whether or not 'key' was logged.
        self.cnt += 1
        return Batch(info=step_info)

    @staticmethod
    def single_preprocess_fn(**kwargs):
        """Same as ``preprocess_fn`` but without any writer logging or counter."""
        if 'rew' not in kwargs:
            return Batch()
        step_info = kwargs['info']
        step_info.rew = kwargs['rew']
        return Batch(info=step_info)
 | |
| 
 | |
| 
 | |
def test_collector():
    """Check ``Collector`` with a single env, a subprocess venv and a dummy venv.

    The expected buffer contents are hard-coded for ``MyTestEnv`` instances of
    size 2, 3, 4 and 5 driven by ``MyPolicy``. The final NXEnv section only checks
    that the stored observations have ``object`` dtype.
    """
    tb_writer = SummaryWriter('log/collector')
    logger = Logger(tb_writer)
    env_fns = [lambda x=size: MyTestEnv(size=x, sleep=0) for size in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()

    def make_vector_buffer():
        # Every vectorised collector in this test uses the same buffer geometry.
        return VectorReplayBuffer(total_size=100, buffer_num=4)

    # --- single environment with a plain ReplayBuffer ---
    single = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
    single.collect(n_step=3)
    assert len(single.buffer) == 3
    assert np.allclose(single.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(single.buffer[:].obs_next[..., 0], [1, 2, 1])
    single.collect(n_episode=3)
    assert len(single.buffer) == 8
    assert np.allclose(single.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(single.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    single.collect(n_step=3, random=True)

    # --- subprocess vector env ---
    subproc = Collector(policy, venv, make_vector_buffer(), logger.preprocess_fn)
    subproc.collect(n_step=8)
    expected_obs = np.zeros(100)
    expected_obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]

    assert np.allclose(subproc.buffer.obs[:, 0], expected_obs)
    assert np.allclose(
        subproc.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]
    )
    subproc.collect(n_episode=4)
    assert len(subproc.buffer) == 16
    expected_obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(subproc.buffer.obs[:, 0], expected_obs)
    assert np.allclose(
        subproc.buffer[:].obs_next[..., 0],
        [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
    )
    subproc.collect(n_episode=4, random=True)

    # --- dummy vector env ---
    dummy = Collector(policy, dum, make_vector_buffer(), logger.preprocess_fn)
    dummy.collect(n_episode=7)
    # Two buffer layouts are accepted after 7 episodes.
    layout_a = expected_obs.copy()
    layout_a[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    layout_b = expected_obs.copy()
    layout_b[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    collected = dummy.buffer.obs[:, 0]
    assert any(np.all(collected == layout) for layout in (layout_a, layout_b))
    dummy.reset_env()
    dummy.reset_buffer()
    assert dummy.collect(n_episode=8)['n/ep'] == 8
    expected_obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(dummy.buffer.obs[:, 0] == expected_obs)
    dummy.collect(n_episode=4, random=True)

    # test corner case
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    with pytest.raises(TypeError):
        # collect() without n_step or n_episode is expected to be rejected.
        dummy.collect()

    # test NXEnv
    for obs_type in ["array", "object"]:
        nx_env_fns = [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
        envs = SubprocVectorEnv(nx_env_fns)
        nx_collector = Collector(policy, envs, make_vector_buffer())
        nx_collector.collect(n_step=6)
        assert nx_collector.buffer.obs.dtype == object
 | |
| 
 | |
| 
 | |
def test_collector_with_async():
    """Check ``AsyncCollector`` on a subprocess venv that waits for 3 of 4 envs.

    After each ``collect`` call, the per-env sub-buffers are checked to contain
    whole episodes ``0..env_len-1`` with matching ``env_id``.
    """
    env_lens = [2, 3, 4, 5]
    tb_writer = SummaryWriter('log/async_collector')
    logger = Logger(tb_writer)
    env_fns = [
        lambda x=length: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
        for length in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    bufsize = 60
    collector = AsyncCollector(
        policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
        logger.preprocess_fn
    )
    # Next write position of each env's sub-buffer, tracked modulo bufsize.
    write_pos = [0] * 4
    for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
        stats = collector.collect(n_episode=n_episode)
        assert stats["n/ep"] >= n_episode
        # check buffer data, obs and obs_next, env_id
        episodes_by_len = np.bincount(stats["lens"], minlength=6)[2:]
        for env_idx, count in enumerate(episodes_by_len):
            env_len = env_idx + 2
            total = env_len * count
            start = write_pos[env_idx]
            indices = np.arange(start, start + total) % bufsize
            write_pos[env_idx] = (start + total) % bufsize
            seq = np.arange(env_len)
            sub_buffer = collector.buffer.buffers[env_idx]
            assert np.all(sub_buffer.info.env_id[indices] == env_idx)
            assert np.all(sub_buffer.obs[indices].reshape(count, env_len) == seq)
            assert np.all(
                sub_buffer.obs_next[indices].reshape(count, env_len) == seq + 1
            )
    # test async n_step, for now the buffer should be full of data
    for n_step in tqdm.trange(1, 15, desc="test async n_step"):
        stats = collector.collect(n_step=n_step)
        assert stats["n/st"] >= n_step
        for env_idx, env_len in enumerate(env_lens):
            seq = np.arange(env_len)
            sub_buffer = collector.buffer.buffers[env_idx]
            assert np.all(sub_buffer.info.env_id == env_idx)
            assert np.all(sub_buffer.obs.reshape(-1, env_len) == seq)
            assert np.all(sub_buffer.obs_next.reshape(-1, env_len) == seq + 1)
    with pytest.raises(TypeError):
        collector.collect()
 | |
| 
 | |
| 
 | |
def test_collector_with_dict_state():
    """Check ``Collector`` when the environments return dict observations.

    Covers a single env, a seeded dummy venv, merging the vector buffer into the
    single buffer via ``update``, and sampling from a frame-stacked buffer.
    """

    def episodes(length, repeats):
        # Observation indices of `repeats` back-to-back episodes of `length` steps.
        return list(range(length)) * repeats

    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)
    single = Collector(
        policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn
    )
    single.collect(n_step=3)
    single.collect(n_episode=2)
    assert len(single.buffer) == 10

    env_fns = [
        lambda x=size: MyTestEnv(size=x, sleep=0, dict_state=True)
        for size in [2, 3, 4, 5]
    ]
    envs = DummyVectorEnv(env_fns)
    envs.seed(666)
    first_obs = envs.reset()
    # After seeding, the first two envs must not report the same 'rand' value.
    assert not np.isclose(first_obs[0]['rand'], first_obs[1]['rand'])

    vector = Collector(
        policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
        Logger.single_preprocess_fn
    )
    vector.collect(n_step=12)
    stats = vector.collect(n_episode=8)
    assert stats['n/ep'] == 8
    lens = np.bincount(stats['lens'])
    # Two step-count / episode-length distributions are accepted.
    assert (stats['n/st'] == 21 and np.all(lens == [0, 0, 2, 2, 2, 2])) or \
        (stats['n/st'] == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]))
    _batch, _indices = vector.buffer.sample(10)

    single.buffer.update(vector.buffer)
    assert len(single.buffer) in [42, 43]
    # The first 10 entries are the single-env data, the rest the per-env episodes.
    if len(single.buffer) == 42:
        expected_index = (
            episodes(5, 2) + episodes(2, 4) + episodes(3, 2) + episodes(4, 2) +
            episodes(5, 2)
        )
    else:
        expected_index = (
            episodes(5, 2) + episodes(2, 3) + episodes(3, 3) + episodes(4, 2) +
            episodes(5, 2)
        )
    actual_index = single.buffer[:].obs.index[..., 0]
    assert np.all(actual_index == expected_index), actual_index

    stacked = Collector(
        policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
        Logger.single_preprocess_fn
    )
    stacked.collect(n_episode=10)
    _batch, _indices = stacked.buffer.sample(10)
 | |
| 
 | |
| 
 | |
def test_collector_with_ma():
    """Check ``Collector`` with environments created with ``ma_rew=4``.

    Episode rewards are expected to have a trailing dimension of 4 and to equal 1
    in every component.
    """

    def episode_flags(length, repeats):
        # 0 for every step except the last of each episode, which is 1.
        return ([0] * (length - 1) + [1]) * repeats

    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
    policy = MyPolicy()
    single = Collector(
        policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn
    )
    # After only 3 steps no episode reward is expected to be reported yet.
    rew = single.collect(n_step=3)['rews']
    assert len(rew) == 0
    rew = single.collect(n_episode=2)['rews']
    assert rew.shape == (2, 4) and np.all(rew == 1)

    env_fns = [
        lambda x=size: MyTestEnv(size=x, sleep=0, ma_rew=4) for size in [2, 3, 4, 5]
    ]
    envs = DummyVectorEnv(env_fns)
    vector = Collector(
        policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
        Logger.single_preprocess_fn
    )
    rew = vector.collect(n_step=12)['rews']
    assert rew.shape == (2, 4) and np.all(rew == 1), rew
    rew = vector.collect(n_episode=8)['rews']
    assert rew.shape == (8, 4) and np.all(rew == 1)
    batch, _ = vector.buffer.sample(10)
    print(batch)

    single.buffer.update(vector.buffer)
    assert len(single.buffer) in [42, 43]
    # Expected per-transition reward/done flags for the two accepted buffer sizes.
    if len(single.buffer) == 42:
        expected = (
            episode_flags(5, 2) + episode_flags(2, 4) + episode_flags(3, 2) +
            episode_flags(4, 2) + episode_flags(5, 2)
        )
    else:
        expected = (
            episode_flags(5, 2) + episode_flags(2, 3) + episode_flags(3, 3) +
            episode_flags(4, 2) + episode_flags(5, 2)
        )
    merged = single.buffer[:]
    assert np.all(merged.rew == [[flag] * 4 for flag in expected])
    assert np.all(single.buffer[:].done == expected)

    stacked = Collector(
        policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
        Logger.single_preprocess_fn
    )
    rew = stacked.collect(n_episode=10)['rews']
    assert rew.shape == (10, 4) and np.all(rew == 1)
    _batch, _indices = stacked.buffer.sample(10)
 | |
| 
 | |
| 
 | |
def test_collector_with_atari_setting():
    """Check ``Collector`` with ``(4, 84, 84)`` array observations.

    Exercises plain, ``ignore_obs_next`` and ``save_only_last_obs`` buffers on a
    single env, then vector, frame-stacked and cached buffers on a dummy venv,
    and finally a collector created without a buffer.
    """
    pixel = np.arange(84)
    reference_obs = np.zeros([6, 4, 84, 84])
    for step in range(6):
        reference_obs[step, 3, pixel, pixel] = step
        reference_obs[step, 2, pixel] = step
        reference_obs[step, 1, :, pixel] = step
        reference_obs[step, 0] = step

    # atari single buffer
    env = MyTestEnv(size=5, sleep=0, array_state=True)
    policy = MyPolicy()
    plain = Collector(policy, env, ReplayBuffer(size=100))
    plain.collect(n_step=6)
    plain.collect(n_episode=2)
    assert plain.buffer.obs.shape == (100, 4, 84, 84)
    assert plain.buffer.obs_next.shape == (100, 4, 84, 84)
    assert len(plain.buffer) == 15
    expected = np.zeros_like(plain.buffer.obs)
    expected[:15] = reference_obs[np.arange(15) % 5]
    assert np.all(expected == plain.buffer.obs)

    no_next = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
    no_next.collect(n_episode=3)
    assert np.allclose(plain.buffer.obs, no_next.buffer.obs)
    # obs_next is not stored as an attribute, yet indexing the buffer provides it.
    with pytest.raises(AttributeError):
        no_next.buffer.obs_next
    assert np.all(
        reference_obs[[1, 2, 3, 4, 4] * 3] == no_next.buffer[:].obs_next
    )

    last_only = Collector(
        policy, env,
        ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)
    )
    last_only.collect(n_step=8)
    assert last_only.buffer.obs.shape == (100, 84, 84)
    expected = np.zeros_like(last_only.buffer.obs)
    expected[:8] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
    assert np.all(last_only.buffer.obs == expected)
    assert np.allclose(
        last_only.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]
    )

    # atari multi buffer
    env_fns = [
        lambda x=size: MyTestEnv(size=x, sleep=0, array_state=True)
        for size in [2, 3, 4, 5]
    ]
    envs = DummyVectorEnv(env_fns)
    # (start offset in the flat buffer, observation index of each stored step)
    layout = [
        (0, [0, 1, 0, 1, 0, 1, 0, 1]),
        (25, [0, 1, 2, 0, 1, 2, 0, 1, 2]),
        (50, [0, 1, 2, 3, 0, 1, 2, 3]),
        (75, [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]),
    ]

    vector = Collector(
        policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)
    )
    vector.collect(n_step=12)
    stats = vector.collect(n_episode=9)
    assert stats["n/ep"] == 9 and stats["n/st"] == 23
    assert vector.buffer.obs.shape == (100, 4, 84, 84)
    expected = np.zeros_like(vector.buffer.obs)
    for start, steps in layout:
        expected[start:start + len(steps)] = reference_obs[steps]
    assert np.all(expected == vector.buffer.obs)
    expected_next = np.zeros_like(vector.buffer.obs_next)
    for start, steps in layout:
        # With obs_next stored, each entry is simply the following observation.
        expected_next[start:start + len(steps)] = reference_obs[
            np.asarray(steps) + 1]
    assert np.all(expected_next == vector.buffer.obs_next)

    stacked = Collector(
        policy, envs,
        VectorReplayBuffer(
            total_size=100,
            buffer_num=4,
            stack_num=4,
            ignore_obs_next=True,
            save_only_last_obs=True
        )
    )
    stacked.collect(n_step=12)
    stats = stacked.collect(n_episode=9)
    assert stats["n/ep"] == 9 and stats["n/st"] == 23
    assert stacked.buffer.obs.shape == (100, 84, 84)
    expected = np.zeros_like(stacked.buffer.obs)
    slice_obs = reference_obs[:, -1]
    for start, steps in layout:
        expected[start:start + len(steps)] = slice_obs[steps]
    assert np.all(stacked.buffer.obs == expected)
    expected_next = np.zeros([len(stacked.buffer), 4, 84, 84])
    ref_index = np.array(
        [
            1, 1, 1, 1, 1, 1, 1, 1,
            1, 2, 2, 1, 2, 2, 1, 2, 2,
            1, 2, 3, 3, 1, 2, 3, 3,
            1, 2, 3, 4, 4, 1, 2, 3, 4, 4,
        ]
    )
    # Frame -k of the stack is the reference index shifted back by k - 1 steps,
    # clipped at 0 (same result as repeatedly decrementing and clipping).
    for k in range(1, 5):
        expected_next[:, -k] = slice_obs[np.maximum(ref_index - (k - 1), 0)]
    assert np.all(expected_next == stacked.buffer[:].obs_next)

    buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True)
    cached = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
    cached_step_stats = cached.collect(n_step=12)
    assert len(buf) == 5
    assert len(cached.buffer) == 12
    cached_episode_stats = cached.collect(n_episode=9)
    assert cached_episode_stats["n/ep"] == 9
    assert cached_episode_stats["n/st"] == 23
    assert len(buf) == 35
    main_obs_steps = [
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 0, 1, 0, 1,
        2, 3, 0, 1, 2, 0, 1, 2, 3, 4
    ]
    assert np.all(buf.obs[:len(buf)] == slice_obs[main_obs_steps])
    main_next_steps = [
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2,
        3, 3, 1, 2, 2, 1, 2, 3, 4, 4
    ]
    assert np.all(buf[:].obs_next[:, -1] == slice_obs[main_next_steps])
    assert len(buf) == len(cached.buffer)

    # test buffer=None
    stat_keys = ["n/ep", "n/st", "rews", "lens"]
    bufferless = Collector(policy, envs)
    step_stats = bufferless.collect(n_step=12)
    for key in stat_keys:
        assert np.allclose(step_stats[key], cached_step_stats[key])
    episode_stats = bufferless.collect(n_episode=9)
    for key in stat_keys:
        assert np.allclose(episode_stats[key], cached_episode_stats[key])
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Run every test directly, in a fixed order, when executed as a script.
    for run_test in (
        test_collector,
        test_collector_with_dict_state,
        test_collector_with_ma,
        test_collector_with_atari_setting,
        test_collector_with_async,
    ):
        run_test()
 |