diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 96fc6f7..a9a3f82 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -279,7 +279,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): ) test_envs = envpool.make_gym( task.replace("NoFrameskip-v4", "-v5"), - num_envs=training_num, + num_envs=test_num, seed=seed, episodic_life=False, reward_clip=False, diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index e0036bc..0c56757 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -19,7 +19,7 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm): """ if envpool is not None: train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed) - test_envs = envpool.make_gym(task, num_envs=training_num, seed=seed) + test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed) else: warnings.warn( "Recommend using envpool (pip install envpool) " diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 9ef99cd..e650cc5 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -347,7 +347,7 @@ class ReplayBuffer: indices = self.sample_indices(0) if index == slice(None) \ else self._indices[:len(self)][index] else: - indices = index + indices = index # type: ignore # raise KeyError first instead of AttributeError, # to support np.array([ReplayBuffer()]) obs = self.get(indices, "obs") diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index fa3c49b..dc30a5d 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -94,7 +94,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): indices = self.sample_indices(0) if index == slice(None) \ else self._indices[:len(self)][index] else: - indices = index + indices = index # type: ignore batch = super().__getitem__(indices) weight = self.get_weight(indices) # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154