This commit is contained in:
sunkafei 2023-03-04 08:57:04 +08:00 committed by GitHub
parent c8be85b240
commit bc222e87a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 39 additions and 7 deletions

View File

@ -8,7 +8,7 @@ exclude =
dist
*.egg-info
max-line-length = 87
ignore = B305,W504,B006,B008,B024,W503
ignore = B305,W504,B006,B008,B024,W503,B028
[yapf]
based_on_style = pep8

View File

@ -22,7 +22,6 @@ def get_install_requires() -> str:
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"protobuf~=3.19.0", # breaking change, sphinx fail
"packaging",
]
@ -30,9 +29,9 @@ def get_install_requires() -> str:
def get_extras_require() -> str:
req = {
"dev": [
"sphinx<4",
"sphinx",
"sphinx_rtd_theme",
"jinja2<3.1", # temporary fix
"jinja2",
"sphinxcontrib-bibtex",
"flake8",
"flake8-bugbear",

View File

@ -437,6 +437,38 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
# Another test case for cycled indices
env_size = 99
bufsize = 15
env = MyGoalEnv(env_size, array_state=False)
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf.future_p = 1
for x, ep_len in enumerate([10, 20]):
obs, _ = env.reset()
for i in range(ep_len):
act = 1
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=(i == ep_len - 1),
truncated=(i == ep_len - 1),
obs_next=obs_next,
info=info
)
if x == 1 and obs["observation"] < 10:
obs = obs_next
continue
buf.add(batch)
obs = obs_next
buf._restore_cache()
sample_indices = np.array([10]) # Suppose the sampled indices are [10]
buf.rewrite_transitions(sample_indices)
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
def test_update():
buf1 = ReplayBuffer(4, stack_num=2)

View File

@ -120,9 +120,10 @@ class HERReplayBuffer(ReplayBuffer):
# Calculate future timestep to use
current = indices[0]
terminal = indices[-1]
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
future_offset = future_offset.astype(int)
future_t = (current + future_offset)
episodes_len = (terminal - current + self.maxsize) % self.maxsize
future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
future_offset = np.round(future_offset).astype(int)
future_t = (current + future_offset) % self.maxsize
# Compute indices
# open indices are used to find longest, unique trajectories among