parent
277138ca5b
commit
109875d43d
@ -279,7 +279,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
|||||||
)
|
)
|
||||||
test_envs = envpool.make_gym(
|
test_envs = envpool.make_gym(
|
||||||
task.replace("NoFrameskip-v4", "-v5"),
|
task.replace("NoFrameskip-v4", "-v5"),
|
||||||
num_envs=training_num,
|
num_envs=test_num,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
episodic_life=False,
|
episodic_life=False,
|
||||||
reward_clip=False,
|
reward_clip=False,
|
||||||
|
@ -19,7 +19,7 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
|
|||||||
"""
|
"""
|
||||||
if envpool is not None:
|
if envpool is not None:
|
||||||
train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
|
train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
|
||||||
test_envs = envpool.make_gym(task, num_envs=training_num, seed=seed)
|
test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
|
||||||
else:
|
else:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Recommend using envpool (pip install envpool) "
|
"Recommend using envpool (pip install envpool) "
|
||||||
|
@ -347,7 +347,7 @@ class ReplayBuffer:
|
|||||||
indices = self.sample_indices(0) if index == slice(None) \
|
indices = self.sample_indices(0) if index == slice(None) \
|
||||||
else self._indices[:len(self)][index]
|
else self._indices[:len(self)][index]
|
||||||
else:
|
else:
|
||||||
indices = index
|
indices = index # type: ignore
|
||||||
# raise KeyError first instead of AttributeError,
|
# raise KeyError first instead of AttributeError,
|
||||||
# to support np.array([ReplayBuffer()])
|
# to support np.array([ReplayBuffer()])
|
||||||
obs = self.get(indices, "obs")
|
obs = self.get(indices, "obs")
|
||||||
|
@ -94,7 +94,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
indices = self.sample_indices(0) if index == slice(None) \
|
indices = self.sample_indices(0) if index == slice(None) \
|
||||||
else self._indices[:len(self)][index]
|
else self._indices[:len(self)][index]
|
||||||
else:
|
else:
|
||||||
indices = index
|
indices = index # type: ignore
|
||||||
batch = super().__getitem__(indices)
|
batch = super().__getitem__(indices)
|
||||||
weight = self.get_weight(indices)
|
weight = self.get_weight(indices)
|
||||||
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
|
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
|
||||||
|
Loading…
x
Reference in New Issue
Block a user