Hindsight Experience Replay as a replay buffer (#753)

## implementation

I implemented HER solely as a replay buffer. It works by temporarily re-writing the transition storage (`self._meta`) directly during the `sample_indices()` call. The original transitions are cached and restored at the beginning of the next sampling or whenever another method is called. This makes sure that, for example, n-step return calculation can be done without altering the policy.

There is also a problem with the original index sampling: the sampled indices are not guaranteed to come from different episodes. So I decided to perform the re-writing per episode, which guarantees that sampled transitions from the same episode share the same re-written goal. This also makes the re-writing ratio slightly different from the paper, but the difference is small when there are many episodes in the buffer.

In the current commit, the HER replay buffer supports only the 'future' strategy with online sampling. This is the best HER variant in terms of performance and memory efficiency.

I also added a few convenience replay buffers (`HERVectorReplayBuffer`, `HERReplayBufferManager`), a test env (`MyGoalEnv`), a gym wrapper (`TruncatedAsTerminated`), unit tests, and a simple example (examples/offline/fetch_her_ddpg.py).

## verification

I have added unit tests for almost everything I implemented. The HER replay buffer was also tested using DDPG on the [`FetchReach-v3` env](https://github.com/Farama-Foundation/Gymnasium-Robotics). I used the default DDPG parameters from the mujoco example and didn't tune anything further to get this good result! (train script: examples/offline/fetch_her_ddpg.py)

![Screen Shot 2022-10-02 at 19 22 53](https://user-images.githubusercontent.com/42699114/193454066-0dd0c65c-fd5f-4587-8912-b441d39de88a.png)
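As context for the source below, here is a minimal usage sketch of the buffer. It is not taken from the PR: the sparse reward and the 0.05 distance threshold are assumptions for a Fetch-style goal env, and `HERReplayBuffer` is assumed to be importable from `tianshou.data` like the other buffers.

```python
import numpy as np

from tianshou.data import HERReplayBuffer


def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # Batched sparse reward: inputs have shape (batch_size, goal_dim),
    # output has shape (batch_size,). The 0.05 threshold is an assumption.
    dist = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -(dist > 0.05).astype(np.float32)


buf = HERReplayBuffer(
    size=100_000,
    compute_reward_fn=compute_reward_fn,
    horizon=50,  # maximum episode length of the environment
    future_k=8,  # a fraction 1 - 1/k of sampled episodes is re-written
)
# The buffer can then be used with a Collector / off-policy trainer as usual;
# goals and rewards of sampled episodes are re-written on the fly during sampling.
```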
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
from tianshou.data import Batch, ReplayBuffer
class HERReplayBuffer(ReplayBuffer):
"""Implementation of Hindsight Experience Replay. arXiv:1707.01495.
    HERReplayBuffer is to be used with goal-based environments, where the
    observation is a dictionary with keys ``observation``, ``achieved_goal`` and
    ``desired_goal``. It currently supports only HER's 'future' strategy with
    online sampling.
:param int size: the size of the replay buffer.
:param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
        ``achieved_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
The two arguments are of shape (batch_size, ...original_shape) and the returned
rewards must be of shape (batch_size,).
:param int horizon: the maximum number of steps in an episode.
    :param int future_k: the 'k' parameter introduced in the paper. In every
        sampling, a fraction ``1 - 1/k`` of the sampled episodes is re-written,
        i.e. roughly k - 1 re-written episodes for every unaltered one.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(
self,
size: int,
compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
horizon: int,
future_k: float = 8.0,
**kwargs: Any,
) -> None:
super().__init__(size, **kwargs)
self.horizon = horizon
        self.future_p = 1 - 1 / future_k  # fraction of sampled episodes to re-write
self.compute_reward_fn = compute_reward_fn
self._original_meta = Batch()
self._altered_indices = np.array([])
def _restore_cache(self) -> None:
"""Write cached original meta back to `self._meta`.
        It is called every time before 'writing', 'sampling' or 'saving' the buffer.
"""
        # guard: ``_restore_cache`` can be reached (e.g. via ``reset`` during
        # ``__init__``) before ``_altered_indices`` has been set
        if not hasattr(self, '_altered_indices'):
return
if self._altered_indices.size == 0:
return
self._meta[self._altered_indices] = self._original_meta
# Clean
self._original_meta = Batch()
self._altered_indices = np.array([])
def reset(self, keep_statistics: bool = False) -> None:
self._restore_cache()
return super().reset(keep_statistics)
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
self._restore_cache()
return super().save_hdf5(path, compression)
def set_batch(self, batch: Batch) -> None:
self._restore_cache()
return super().set_batch(batch)
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
self._restore_cache()
return super().update(buffer)
def add(
self,
batch: Batch,
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self._restore_cache()
return super().add(batch, buffer_ids)
def sample_indices(self, batch_size: int) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an \
empty numpy array if batch_size < 0 or no available index can be sampled. \
Additionally, some episodes of the sampled transitions will be re-written \
according to HER.
"""
self._restore_cache()
indices = super().sample_indices(batch_size=batch_size)
self.rewrite_transitions(indices.copy())
return indices
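    # Illustrative flow (not part of the original code):
    #     batch, indices = buf.sample(256)
    # ``sample`` (inherited from ReplayBuffer) calls ``sample_indices`` above, which
    # first rolls back any previous re-write, draws the indices, and then re-writes
    # goals/rewards of the sampled episodes directly in ``self._meta``; the re-write
    # persists until the next call to sample/add/save_hdf5/... restores the cache.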
def rewrite_transitions(self, indices: np.ndarray) -> None:
"""Re-write the goal of some sampled transitions' episodes according to HER.
Currently applies only HER's 'future' strategy. The new goals will be written \
directly to the internal batch data temporarily and will be restored right \
        before the next sampling or when some of the buffer's methods are used (e.g. \
        `add`, `save_hdf5`, etc.). This is to make sure that n-step return \
        calculation, etc., performs correctly without additional alteration.
"""
if indices.size == 0:
return
# Sort indices keeping chronological order
indices[indices < self._index] += self.maxsize
indices = np.sort(indices)
indices[indices >= self.maxsize] -= self.maxsize
# Construct episode trajectories
indices = [indices]
for _ in range(self.horizon - 1):
indices.append(self.next(indices[-1]))
indices = np.stack(indices)
# Calculate future timestep to use
current = indices[0]
terminal = indices[-1]
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
future_offset = future_offset.astype(int)
future_t = (current + future_offset)
# Compute indices
# open indices are used to find longest, unique trajectories among
# presented episodes
unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1])
unique_ep_indices = indices[:, unique_ep_open_indices]
# close indices are used to find max future_t among presented episodes
unique_ep_close_indices = np.hstack(
[(unique_ep_open_indices - 1)[1:],
len(terminal) - 1]
)
# episode indices that will be altered
her_ep_indices = np.random.choice(
len(unique_ep_open_indices),
size=int(len(unique_ep_open_indices) * self.future_p),
replace=False
)
# Cache original meta
self._altered_indices = unique_ep_indices.copy()
self._original_meta = self._meta[self._altered_indices].copy()
# Copy original obs, ep_rew (and obs_next), and obs of future time step
ep_obs = self[unique_ep_indices].obs
ep_rew = self[unique_ep_indices].rew
if self._save_obs_next:
ep_obs_next = self[unique_ep_indices].obs_next
future_obs = self[future_t[unique_ep_close_indices]].obs_next
else:
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
# Re-assign goals and rewards via broadcast assignment
ep_obs.desired_goal[:, her_ep_indices] = \
future_obs.achieved_goal[None, her_ep_indices]
if self._save_obs_next:
ep_obs_next.desired_goal[:, her_ep_indices] = \
future_obs.achieved_goal[None, her_ep_indices]
ep_rew[:, her_ep_indices] = \
self._compute_reward(ep_obs_next)[:, her_ep_indices]
else:
tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
ep_rew[:, her_ep_indices] = \
self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
# Sanity check
assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape
assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape
assert ep_rew.shape == unique_ep_indices.shape
# Re-write meta
self._meta.obs[unique_ep_indices] = ep_obs
if self._save_obs_next:
self._meta.obs_next[unique_ep_indices] = ep_obs_next
self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)
def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray:
lead_shape = obs.observation.shape[:lead_dims]
g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:])
ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:])
rewards = self.compute_reward_fn(ag, g)
return rewards.reshape(*lead_shape, *rewards.shape[1:])
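    # Shape walk-through for the default ``lead_dims=2`` (illustrative comment):
    # ``obs`` is stacked as (horizon, n_episodes, ...), so
    #     desired_goal / achieved_goal: (horizon, n_episodes, goal_dim)
    #                                   -> reshaped to (horizon * n_episodes, goal_dim)
    #     compute_reward_fn(ag, g)      -> (horizon * n_episodes,)
    #     returned rewards              -> reshaped back to (horizon, n_episodes)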