From bc222e87a636f13575d65b5603ad43af5e39090e Mon Sep 17 00:00:00 2001 From: sunkafei <82038764+sunkafei@users.noreply.github.com> Date: Sat, 4 Mar 2023 08:57:04 +0800 Subject: [PATCH] Fix #811 (#817) --- setup.cfg | 2 +- setup.py | 5 ++--- test/base/test_buffer.py | 32 ++++++++++++++++++++++++++++++++ tianshou/data/buffer/her.py | 7 ++++--- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/setup.cfg b/setup.cfg index 2960359..0e11d54 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ exclude = dist *.egg-info max-line-length = 87 -ignore = B305,W504,B006,B008,B024,W503 +ignore = B305,W504,B006,B008,B024,W503,B028 [yapf] based_on_style = pep8 diff --git a/setup.py b/setup.py index 5912cd1..253be20 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,6 @@ def get_install_requires() -> str: "torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements - "protobuf~=3.19.0", # breaking change, sphinx fail "packaging", ] @@ -30,9 +29,9 @@ def get_install_requires() -> str: def get_extras_require() -> str: req = { "dev": [ - "sphinx<4", + "sphinx", "sphinx_rtd_theme", - "jinja2<3.1", # temporary fix + "jinja2", "sphinxcontrib-bibtex", "flake8", "flake8-bugbear", diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index cf011b7..c4d3c6b 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -437,6 +437,38 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + # Another test case for cycled indices + env_size = 99 + bufsize = 15 + env = MyGoalEnv(env_size, array_state=False) + buf = HERReplayBuffer( + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 + ) + buf.future_p = 1 + for x, ep_len in enumerate([10, 20]): + obs, _ = env.reset() + for i in range(ep_len): + act = 1 + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info + ) + if x == 1 and obs["observation"] < 10: + obs = obs_next + continue + buf.add(batch) + obs = obs_next + buf._restore_cache() + sample_indices = np.array([10]) # Suppose the sampled indices is [10] + buf.rewrite_transitions(sample_indices) + assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 8c5c371..fc18243 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -120,9 +120,10 @@ class HERReplayBuffer(ReplayBuffer): # Calculate future timestep to use current = indices[0] terminal = indices[-1] - future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current) - future_offset = future_offset.astype(int) - future_t = (current + future_offset) + episodes_len = (terminal - current + self.maxsize) % self.maxsize + future_offset = np.random.uniform(size=len(indices[0])) * episodes_len + future_offset = np.round(future_offset).astype(int) + future_t = (current + future_offset) % self.maxsize # Compute indices # open indices are used to find longest, unique trajectories among