parent
277138ca5b
commit
109875d43d
@ -279,7 +279,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
)
|
||||
test_envs = envpool.make_gym(
|
||||
task.replace("NoFrameskip-v4", "-v5"),
|
||||
num_envs=training_num,
|
||||
num_envs=test_num,
|
||||
seed=seed,
|
||||
episodic_life=False,
|
||||
reward_clip=False,
|
||||
|
@ -19,7 +19,7 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
|
||||
"""
|
||||
if envpool is not None:
|
||||
train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
|
||||
test_envs = envpool.make_gym(task, num_envs=training_num, seed=seed)
|
||||
test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) "
|
||||
|
@ -347,7 +347,7 @@ class ReplayBuffer:
|
||||
indices = self.sample_indices(0) if index == slice(None) \
|
||||
else self._indices[:len(self)][index]
|
||||
else:
|
||||
indices = index
|
||||
indices = index # type: ignore
|
||||
# raise KeyError first instead of AttributeError,
|
||||
# to support np.array([ReplayBuffer()])
|
||||
obs = self.get(indices, "obs")
|
||||
|
@ -94,7 +94,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
indices = self.sample_indices(0) if index == slice(None) \
|
||||
else self._indices[:len(self)][index]
|
||||
else:
|
||||
indices = index
|
||||
indices = index # type: ignore
|
||||
batch = super().__getitem__(indices)
|
||||
weight = self.get_weight(indices)
|
||||
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
|
||||
|
Loading…
x
Reference in New Issue
Block a user