parent
c8be85b240
commit
bc222e87a6
@ -8,7 +8,7 @@ exclude =
|
||||
dist
|
||||
*.egg-info
|
||||
max-line-length = 87
|
||||
ignore = B305,W504,B006,B008,B024,W503
|
||||
ignore = B305,W504,B006,B008,B024,W503,B028
|
||||
|
||||
[yapf]
|
||||
based_on_style = pep8
|
||||
|
5
setup.py
5
setup.py
@ -22,7 +22,6 @@ def get_install_requires() -> str:
|
||||
"torch>=1.4.0",
|
||||
"numba>=0.51.0",
|
||||
"h5py>=2.10.0", # to match tensorflow's minimal requirements
|
||||
"protobuf~=3.19.0", # breaking change, sphinx fail
|
||||
"packaging",
|
||||
]
|
||||
|
||||
@ -30,9 +29,9 @@ def get_install_requires() -> str:
|
||||
def get_extras_require() -> str:
|
||||
req = {
|
||||
"dev": [
|
||||
"sphinx<4",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"jinja2<3.1", # temporary fix
|
||||
"jinja2",
|
||||
"sphinxcontrib-bibtex",
|
||||
"flake8",
|
||||
"flake8-bugbear",
|
||||
|
@ -437,6 +437,38 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
|
||||
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
|
||||
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
|
||||
|
||||
# Another test case for cycled indices
|
||||
env_size = 99
|
||||
bufsize = 15
|
||||
env = MyGoalEnv(env_size, array_state=False)
|
||||
buf = HERReplayBuffer(
|
||||
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
|
||||
)
|
||||
buf.future_p = 1
|
||||
for x, ep_len in enumerate([10, 20]):
|
||||
obs, _ = env.reset()
|
||||
for i in range(ep_len):
|
||||
act = 1
|
||||
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||
batch = Batch(
|
||||
obs=obs,
|
||||
act=[act],
|
||||
rew=rew,
|
||||
terminated=(i == ep_len - 1),
|
||||
truncated=(i == ep_len - 1),
|
||||
obs_next=obs_next,
|
||||
info=info
|
||||
)
|
||||
if x == 1 and obs["observation"] < 10:
|
||||
obs = obs_next
|
||||
continue
|
||||
buf.add(batch)
|
||||
obs = obs_next
|
||||
buf._restore_cache()
|
||||
sample_indices = np.array([10]) # Suppose the sampled indices is [10]
|
||||
buf.rewrite_transitions(sample_indices)
|
||||
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
|
||||
|
||||
|
||||
def test_update():
|
||||
buf1 = ReplayBuffer(4, stack_num=2)
|
||||
|
@ -120,9 +120,10 @@ class HERReplayBuffer(ReplayBuffer):
|
||||
# Calculate future timestep to use
|
||||
current = indices[0]
|
||||
terminal = indices[-1]
|
||||
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
|
||||
future_offset = future_offset.astype(int)
|
||||
future_t = (current + future_offset)
|
||||
episodes_len = (terminal - current + self.maxsize) % self.maxsize
|
||||
future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
|
||||
future_offset = np.round(future_offset).astype(int)
|
||||
future_t = (current + future_offset) % self.maxsize
|
||||
|
||||
# Compute indices
|
||||
# open indices are used to find longest, unique trajectories among
|
||||
|
Loading…
x
Reference in New Issue
Block a user