Fix num_envs=test_num (#653)

* fix num_envs=test_num

* fix mypy
This commit is contained in:
Jiayi Weng 2022-05-29 23:38:47 -05:00 committed by GitHub
parent 277138ca5b
commit 109875d43d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 4 additions and 4 deletions

View File

@ -279,7 +279,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
)
test_envs = envpool.make_gym(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=training_num,
num_envs=test_num,
seed=seed,
episodic_life=False,
reward_clip=False,

View File

@ -19,7 +19,7 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
"""
if envpool is not None:
train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
test_envs = envpool.make_gym(task, num_envs=training_num, seed=seed)
test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
else:
warnings.warn(
"Recommend using envpool (pip install envpool) "

View File

@ -347,7 +347,7 @@ class ReplayBuffer:
indices = self.sample_indices(0) if index == slice(None) \
else self._indices[:len(self)][index]
else:
indices = index
indices = index # type: ignore
# raise KeyError first instead of AttributeError,
# to support np.array([ReplayBuffer()])
obs = self.get(indices, "obs")

View File

@ -94,7 +94,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
indices = self.sample_indices(0) if index == slice(None) \
else self._indices[:len(self)][index]
else:
indices = index
indices = index # type: ignore
batch = super().__getitem__(indices)
weight = self.get_weight(indices)
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154