parent
c8be85b240
commit
bc222e87a6
@ -8,7 +8,7 @@ exclude =
|
|||||||
dist
|
dist
|
||||||
*.egg-info
|
*.egg-info
|
||||||
max-line-length = 87
|
max-line-length = 87
|
||||||
ignore = B305,W504,B006,B008,B024,W503
|
ignore = B305,W504,B006,B008,B024,W503,B028
|
||||||
|
|
||||||
[yapf]
|
[yapf]
|
||||||
based_on_style = pep8
|
based_on_style = pep8
|
||||||
|
5
setup.py
5
setup.py
@ -22,7 +22,6 @@ def get_install_requires() -> str:
|
|||||||
"torch>=1.4.0",
|
"torch>=1.4.0",
|
||||||
"numba>=0.51.0",
|
"numba>=0.51.0",
|
||||||
"h5py>=2.10.0", # to match tensorflow's minimal requirements
|
"h5py>=2.10.0", # to match tensorflow's minimal requirements
|
||||||
"protobuf~=3.19.0", # breaking change, sphinx fail
|
|
||||||
"packaging",
|
"packaging",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -30,9 +29,9 @@ def get_install_requires() -> str:
|
|||||||
def get_extras_require() -> str:
|
def get_extras_require() -> str:
|
||||||
req = {
|
req = {
|
||||||
"dev": [
|
"dev": [
|
||||||
"sphinx<4",
|
"sphinx",
|
||||||
"sphinx_rtd_theme",
|
"sphinx_rtd_theme",
|
||||||
"jinja2<3.1", # temporary fix
|
"jinja2",
|
||||||
"sphinxcontrib-bibtex",
|
"sphinxcontrib-bibtex",
|
||||||
"flake8",
|
"flake8",
|
||||||
"flake8-bugbear",
|
"flake8-bugbear",
|
||||||
|
@ -437,6 +437,38 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
|
|||||||
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
|
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
|
||||||
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
|
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
|
||||||
|
|
||||||
|
# Another test case for cycled indices
|
||||||
|
env_size = 99
|
||||||
|
bufsize = 15
|
||||||
|
env = MyGoalEnv(env_size, array_state=False)
|
||||||
|
buf = HERReplayBuffer(
|
||||||
|
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
|
||||||
|
)
|
||||||
|
buf.future_p = 1
|
||||||
|
for x, ep_len in enumerate([10, 20]):
|
||||||
|
obs, _ = env.reset()
|
||||||
|
for i in range(ep_len):
|
||||||
|
act = 1
|
||||||
|
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||||
|
batch = Batch(
|
||||||
|
obs=obs,
|
||||||
|
act=[act],
|
||||||
|
rew=rew,
|
||||||
|
terminated=(i == ep_len - 1),
|
||||||
|
truncated=(i == ep_len - 1),
|
||||||
|
obs_next=obs_next,
|
||||||
|
info=info
|
||||||
|
)
|
||||||
|
if x == 1 and obs["observation"] < 10:
|
||||||
|
obs = obs_next
|
||||||
|
continue
|
||||||
|
buf.add(batch)
|
||||||
|
obs = obs_next
|
||||||
|
buf._restore_cache()
|
||||||
|
sample_indices = np.array([10]) # Suppose the sampled indices is [10]
|
||||||
|
buf.rewrite_transitions(sample_indices)
|
||||||
|
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
|
||||||
|
|
||||||
|
|
||||||
def test_update():
|
def test_update():
|
||||||
buf1 = ReplayBuffer(4, stack_num=2)
|
buf1 = ReplayBuffer(4, stack_num=2)
|
||||||
|
@ -120,9 +120,10 @@ class HERReplayBuffer(ReplayBuffer):
|
|||||||
# Calculate future timestep to use
|
# Calculate future timestep to use
|
||||||
current = indices[0]
|
current = indices[0]
|
||||||
terminal = indices[-1]
|
terminal = indices[-1]
|
||||||
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
|
episodes_len = (terminal - current + self.maxsize) % self.maxsize
|
||||||
future_offset = future_offset.astype(int)
|
future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
|
||||||
future_t = (current + future_offset)
|
future_offset = np.round(future_offset).astype(int)
|
||||||
|
future_t = (current + future_offset) % self.maxsize
|
||||||
|
|
||||||
# Compute indices
|
# Compute indices
|
||||||
# open indices are used to find longest, unique trajectories among
|
# open indices are used to find longest, unique trajectories among
|
||||||
|
Loading…
x
Reference in New Issue
Block a user