Hindsight Experience Replay as a replay buffer (#753)

## implementation
I implemented HER solely as a replay buffer. It works by temporarily
re-writing the transition storage (`self._meta`) in place during the
`sample_indices()` call. The original transitions are cached and restored
at the beginning of the next sampling call, or whenever another method
(e.g. `add`, `save_hdf5`) is invoked. This makes sure that, for example,
n-step return calculation can be done without altering the policy.
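
A minimal sketch of that cache/restore cycle (names follow the `her.py` file added later in this diff; the actual goal re-writing is elided here):

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer


class HERBufferSketch(ReplayBuffer):
    """Sketch of the cache/restore cycle only; see tianshou/data/buffer/her.py below."""

    def __init__(self, size: int, **kwargs):
        super().__init__(size, **kwargs)
        self._original_meta = Batch()          # cached, unaltered transitions
        self._altered_indices = np.array([])   # which indices were re-written

    def _restore_cache(self) -> None:
        # Put the cached originals back into self._meta before anything else
        # (add, save_hdf5, set_batch, the next sampling call, ...) touches the buffer.
        if self._altered_indices.size == 0:
            return
        self._meta[self._altered_indices] = self._original_meta
        self._original_meta = Batch()
        self._altered_indices = np.array([])

    def sample_indices(self, batch_size: int) -> np.ndarray:
        self._restore_cache()  # undo the re-writing done by the previous sample
        indices = super().sample_indices(batch_size=batch_size)
        # The real buffer now caches the touched episodes into _original_meta /
        # _altered_indices and re-writes their goals in self._meta in place.
        return indices
```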

There is also a problem with the original index sampling: the sampled
indices are not guaranteed to come from different episodes. So I decided
to perform the re-writing per episode, which guarantees that sampled
transitions from the same episode share the same re-written goal. This
also makes the re-writing ratio differ slightly from the paper, but the
difference is negligible when the buffer contains many episodes.
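
Concretely, a fraction `future_p = 1 - 1 / future_k` of the distinct episodes appearing in each sampled batch gets a new goal (see `her.py` further down):

```python
future_k = 8                 # default used in the example script
future_p = 1 - 1 / future_k  # fraction of sampled episodes that are re-written
print(future_p)              # 0.875, i.e. roughly 7 re-written episodes per unaltered one
```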

In the current commit, the HER replay buffer only supports the 'future'
strategy with online sampling, which is the best HER variant in terms of
performance and memory efficiency.
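
For intuition, the episode-level 'future' relabeling described above boils down to something like the following standalone sketch (plain numpy, not the `Batch`-based buffer code; `compute_reward_fn` is whatever the user supplies):

```python
import numpy as np


def relabel_episode(achieved_goals_next, desired_goals, compute_reward_fn, rng):
    """Give one whole episode a new goal taken from a randomly chosen future step.

    achieved_goals_next / desired_goals: arrays of shape (T, goal_dim) for one episode.
    Returns the relabeled desired goals and the rewards recomputed against them.
    """
    T = achieved_goals_next.shape[0]
    future_t = rng.integers(0, T)             # one future step for the whole episode
    new_goal = achieved_goals_next[future_t]
    new_desired = np.broadcast_to(new_goal, desired_goals.shape).copy()
    new_rewards = compute_reward_fn(achieved_goals_next, new_desired)  # shape (T,)
    return new_desired, new_rewards


# toy usage: sparse reward, 3-dimensional goals, an episode of length 5
def reward_fn(ag, g):
    return -(np.linalg.norm(ag - g, axis=-1) > 0.05).astype(np.float32)


rng = np.random.default_rng(0)
ag_next = rng.normal(size=(5, 3))
dg = np.zeros((5, 3))
new_dg, new_rew = relabel_episode(ag_next, dg, reward_fn, rng)
```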

I also added a few convenience classes: replay buffers
(`HERVectorReplayBuffer`, `HERReplayBufferManager`), a test env
(`MyGoalEnv`), a gym wrapper (`TruncatedAsTerminated`), unit tests, and a
simple example (`examples/offline/fetch_her_ddpg.py`).
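
Wiring the new pieces together looks roughly like this (condensed from the example script further down; `FetchReach-v3` needs the extra Gymnasium-Robotics installation):

```python
import gym
import numpy as np

from tianshou.data import HERVectorReplayBuffer
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated

task = "FetchReach-v3"
env = TruncatedAsTerminated(gym.make(task))
train_envs = ShmemVectorEnv(
    [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(4)]
)


def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # delegate to the goal-based env's own reward, as the example script does
    return env.compute_reward(achieved_goal, desired_goal, {})


buffer = HERVectorReplayBuffer(
    100_000,                 # total buffer size
    len(train_envs),         # one sub-buffer per env
    compute_reward_fn=compute_reward_fn,
    horizon=50,              # max episode length of FetchReach
    future_k=8,
)
```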

## verification
I have added unit tests for almost everything I implemented.
The HER replay buffer was also tested using DDPG on the [`FetchReach-v3`
env](https://github.com/Farama-Foundation/Gymnasium-Robotics). I used the
default DDPG parameters from the mujoco example and didn't tune anything
further to get this result (train script:
`examples/offline/fetch_her_ddpg.py`).


![Training result of DDPG with HER on FetchReach-v3](https://user-images.githubusercontent.com/42699114/193454066-0dd0c65c-fd5f-4587-8912-b441d39de88a.png)


@ -39,6 +39,7 @@
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
Here are Tianshou's other features:


@ -30,6 +30,14 @@ PrioritizedReplayBuffer
:undoc-members:
:show-inheritance:
HERReplayBuffer
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.data.HERReplayBuffer
:members:
:undoc-members:
:show-inheritance:
ReplayBufferManager
~~~~~~~~~~~~~~~~~~~
@ -46,6 +54,15 @@ PrioritizedReplayBufferManager
:undoc-members:
:show-inheritance:
HERReplayBufferManager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.data.HERReplayBufferManager
:members:
:undoc-members:
:show-inheritance:
VectorReplayBuffer
~~~~~~~~~~~~~~~~~~
@ -62,6 +79,14 @@ PrioritizedVectorReplayBuffer
:undoc-members:
:show-inheritance:
HERVectorReplayBuffer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.data.HERVectorReplayBuffer
:members:
:undoc-members:
:show-inheritance:
CachedReplayBuffer
~~~~~~~~~~~~~~~~~~


@ -40,6 +40,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay <https://arxiv.org/pdf/1707.01495.pdf>`_
Here is Tianshou's other features:


@ -52,6 +52,7 @@ mujoco
jit
nstep
preprocess
preprocessing
repo
ReLU
namespace


@ -20,6 +20,7 @@ Supported algorithms are listed below:
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
- [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495)
## EnvPool
@ -304,6 +305,18 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are.
2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general.
## Others
### HER
| Environment | DDPG without HER | DDPG with HER |
| :--------------------: | :--------------: | :--------------: |
| FetchReach              | -49.9±0.2        | **-17.6±21.7**   |
#### Hints for HER
1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is ``FetchReach-v3`` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics).
2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters for *DDPG with HER* are used in the evaluation of both settings.
3. The scores are the mean reward ± 1 standard deviation over 16 seeds. The minimum reward for ``FetchReach-v3`` is -50, which implies that *DDPG without HER* performs about as well as a random policy. *DDPG with HER* has a much better mean reward, but its standard deviation is quite high, because in this setting the agent either fails completely (-50 reward) or successfully learns the task (close to 0 reward). This means the agent learned successfully in about 70% of the 16 seeds (a rough consistency check is sketched below).
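
A rough consistency check of that ~70% figure (assuming, purely for illustration, that a successful seed ends near -2 reward and a failed one stays at -50):

```python
import numpy as np

scores = np.array([-2.0] * 11 + [-50.0] * 5)  # 11 of 16 hypothetical seeds succeed
print(scores.mean(), scores.std())            # about -17.0 and 22.2, close to the reported -17.6±21.7
print(11 / 16)                                # 0.6875, i.e. roughly 70% of the seeds
```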
## Note
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.

examples/offline/fetch_her_ddpg.py (new file)

@ -0,0 +1,228 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import gym
import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import (
Collector,
HERReplayBuffer,
HERVectorReplayBuffer,
ReplayBuffer,
VectorReplayBuffer,
)
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="FetchReach-v3")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
parser.add_argument("--actor-lr", type=float, default=1e-3)
parser.add_argument("--critic-lr", type=float, default=3e-3)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--exploration-noise", type=float, default=0.1)
parser.add_argument("--start-timesteps", type=int, default=25000)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=5000)
parser.add_argument("--step-per-collect", type=int, default=1)
parser.add_argument("--update-per-step", type=int, default=1)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument(
"--replay-buffer", type=str, default="her", choices=["normal", "her"]
)
parser.add_argument("--her-horizon", type=int, default=50)
parser.add_argument("--her-future-k", type=int, default=8)
parser.add_argument("--training-num", type=int, default=1)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="HER-benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only",
)
return parser.parse_args()
def make_fetch_env(task, training_num, test_num):
env = TruncatedAsTerminated(gym.make(task))
train_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)]
)
return env, train_envs, test_envs
def test_ddpg(args=get_args()):
# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ddpg"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
logger.wandb_run.config.setdefaults(vars(args))
args = argparse.Namespace(**wandb.config)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
env, train_envs, test_envs = make_fetch_env(
args.task, args.training_num, args.test_num
)
args.state_shape = {
'observation': env.observation_space['observation'].shape,
'achieved_goal': env.observation_space['achieved_goal'].shape,
'desired_goal': env.observation_space['desired_goal'].shape,
}
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
args.exploration_noise = args.exploration_noise * args.max_action
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# model
dict_state_dec, flat_state_shape = get_dict_state_decorator(
state_shape=args.state_shape,
keys=['observation', 'achieved_goal', 'desired_goal']
)
net_a = dict_state_dec(Net)(
flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device
)
actor = dict_state_dec(Actor)(
net_a, args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c = dict_state_dec(Net)(
flat_state_shape,
action_shape=args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor,
actor_optim,
critic,
critic_optim,
tau=args.tau,
gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step,
action_space=env.action_space,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
def compute_reward_fn(ag: np.ndarray, g: np.ndarray):
return env.compute_reward(ag, g, {})
if args.replay_buffer == "normal":
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
else:
if args.training_num > 1:
buffer = HERVectorReplayBuffer(
args.buffer_size,
len(train_envs),
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
else:
buffer = HERReplayBuffer(
args.buffer_size,
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
if not args.watch:
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
)
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == "__main__":
test_ddpg()


@ -166,3 +166,48 @@ class NXEnv(gym.Env):
for i in range(self.size):
self.graph.nodes[i]["data"] = next_graph_state[i]
return self._encode_obs(), 1.0, 0, 0, {}
class MyGoalEnv(MyTestEnv):
def __init__(self, *args, **kwargs):
assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \
"dict_state / recurse_state not supported"
super().__init__(*args, **kwargs)
obs, _ = super().reset(state=0)
obs, _, _, _, _ = super().step(1)
self._goal = obs * self.size
super_obsv = self.observation_space
self.observation_space = gym.spaces.Dict(
{
'observation': super_obsv,
'achieved_goal': super_obsv,
'desired_goal': super_obsv,
}
)
def reset(self, *args, **kwargs):
obs, info = super().reset(*args, **kwargs)
new_obs = {
'observation': obs,
'achieved_goal': obs,
'desired_goal': self._goal
}
return new_obs, info
def step(self, *args, **kwargs):
obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs)
new_obs_next = {
'observation': obs_next,
'achieved_goal': obs_next,
'desired_goal': self._goal
}
return new_obs_next, rew, terminated, truncated, info
def compute_reward_fn(
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict
) -> np.ndarray:
axis = -1
if self.array_state:
axis = (-3, -2, -1)
return (achieved_goal == desired_goal).all(axis=axis)


@ -11,6 +11,8 @@ import torch
from tianshou.data import (
Batch,
CachedReplayBuffer,
HERReplayBuffer,
HERVectorReplayBuffer,
PrioritizedReplayBuffer,
PrioritizedVectorReplayBuffer,
ReplayBuffer,
@ -20,9 +22,9 @@ from tianshou.data import (
from tianshou.data.utils.converter import to_hdf5
if __name__ == '__main__':
from env import MyTestEnv
from env import MyGoalEnv, MyTestEnv
else: # pytest
from test.base.env import MyTestEnv
from test.base.env import MyGoalEnv, MyTestEnv
def test_replaybuffer(size=10, bufsize=20):
@ -300,6 +302,142 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
env_size = size
env = MyGoalEnv(env_size, array_state=True)
def compute_reward_fn(ag, g):
return env.compute_reward_fn(ag, g, {})
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf2 = HERVectorReplayBuffer(
bufsize,
buffer_num=3,
compute_reward_fn=compute_reward_fn,
horizon=30,
future_k=8
)
# Apply HER to every sampled episode (hacky but necessary for a deterministic test)
buf.future_p = 1
for buf2_buf in buf2.buffers:
buf2_buf.future_p = 1
obs, _ = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=terminated,
truncated=truncated,
obs_next=obs_next,
info=info
)
buf.add(batch)
buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2])
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert len(buf2) == min(bufsize, 3 * (i + 1))
batch, indices = buf.sample(sample_sz)
# Check that goals are the same for the episode (only 1 ep in buffer)
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf[tmp_indices].obs
obs_next = buf[tmp_indices].obs_next
rew = buf[tmp_indices].rew
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == g[0])
assert np.all(g_next == g_next[0])
assert np.all(rew == (ag_next == g).astype(np.float32))
tmp_indices = buf.next(tmp_indices)
# Check that goals are correctly restored
buf._restore_cache()
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf[tmp_indices].obs
obs_next = buf[tmp_indices].obs_next
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == env_size)
assert np.all(g_next == g_next[0])
assert np.all(g == g[0])
tmp_indices = buf.next(tmp_indices)
# Test vector buffer
batch, indices = buf2.sample(sample_sz)
# Check that goals are the same for the episode (only 1 ep in buffer)
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf2[tmp_indices].obs
obs_next = buf2[tmp_indices].obs_next
rew = buf2[tmp_indices].rew
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == g_next)
assert np.all(rew == (ag_next == g).astype(np.float32))
tmp_indices = buf2.next(tmp_indices)
# Check that goals are correctly restored
buf2._restore_cache()
tmp_indices = indices.copy()
for _ in range(2 * env_size):
obs = buf2[tmp_indices].obs
obs_next = buf2[tmp_indices].obs_next
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
assert np.all(g == env_size)
assert np.all(g_next == g_next[0])
assert np.all(g == g[0])
tmp_indices = buf2.next(tmp_indices)
# Test handling cycled indices
env_size = size
bufsize = 15
env = MyGoalEnv(env_size, array_state=False)
def compute_reward_fn(ag, g):
return env.compute_reward_fn(ag, g, {})
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf._index = 5 # shifted start index
buf.future_p = 1
action_list = [1] * 10
for ep_len in [5, 10]:
obs, _ = env.reset()
for i in range(ep_len):
act = 1
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=(i == ep_len - 1),
truncated=(i == ep_len - 1),
obs_next=obs_next,
info=info
)
buf.add(batch)
obs = obs_next
batch, indices = buf.sample(0)
assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal)
assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal)
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
buf2 = ReplayBuffer(4, stack_num=2)
@ -1180,3 +1318,4 @@ if __name__ == '__main__':
test_multibuf_stack()
test_multibuf_hdf5()
test_from_data()
test_herreplaybuffer()


@ -16,6 +16,7 @@ from tianshou.env import (
SubprocVectorEnv,
VectorEnvNormObs,
)
from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.utils import RunningMeanStd
if __name__ == "__main__":
@ -347,6 +348,10 @@ def test_gym_wrappers():
self.action_space = gym.spaces.Box(
low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
)
self.observation_space = gym.spaces.Discrete(2)
def step(self, act):
return self.observation_space.sample(), -1, False, True, {}
bsz = 10
action_per_branch = [4, 6, 10, 7]
@ -374,6 +379,14 @@ def test_gym_wrappers():
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
np.array([env_m.action_space.nvec - 1] * bsz),
)
# check truncate is True when terminated
try:
env_t = TruncatedAsTerminated(env)
except EnvironmentError:
env_t = None
if env_t is not None:
_, _, truncated, _, _ = env_t.step(env_t.action_space.sample())
assert truncated
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")


@ -6,13 +6,16 @@ from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer.base import ReplayBuffer
from tianshou.data.buffer.prio import PrioritizedReplayBuffer
from tianshou.data.buffer.her import HERReplayBuffer
from tianshou.data.buffer.manager import (
ReplayBufferManager,
PrioritizedReplayBufferManager,
HERReplayBufferManager,
)
from tianshou.data.buffer.vecbuf import (
VectorReplayBuffer,
HERVectorReplayBuffer,
PrioritizedVectorReplayBuffer,
VectorReplayBuffer,
)
from tianshou.data.buffer.cached import CachedReplayBuffer
from tianshou.data.collector import Collector, AsyncCollector
@ -25,10 +28,13 @@ __all__ = [
"SegmentTree",
"ReplayBuffer",
"PrioritizedReplayBuffer",
"HERReplayBuffer",
"ReplayBufferManager",
"PrioritizedReplayBufferManager",
"HERReplayBufferManager",
"VectorReplayBuffer",
"PrioritizedVectorReplayBuffer",
"HERVectorReplayBuffer",
"CachedReplayBuffer",
"Collector",
"AsyncCollector",

tianshou/data/buffer/her.py (new file, 186 lines)

@ -0,0 +1,186 @@
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
from tianshou.data import Batch, ReplayBuffer
class HERReplayBuffer(ReplayBuffer):
"""Implementation of Hindsight Experience Replay. arXiv:1707.01495.
HERReplayBuffer is to be used with a goal-based environment where the
observation is a dictionary with keys ``observation``, ``achieved_goal`` and
``desired_goal``. It currently supports only HER's 'future' strategy with online sampling.
:param int size: the size of the replay buffer.
:param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
``achieved_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
The two arguments are of shape (batch_size, ...original_shape) and the returned
rewards must be of shape (batch_size,).
:param int horizon: the maximum number of steps in an episode.
:param int future_k: the 'k' parameter introduced in the paper. In short, there
will be at most k episodes that are re-written for every 1 unaltered episode
during the sampling.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(
self,
size: int,
compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
horizon: int,
future_k: float = 8.0,
**kwargs: Any,
) -> None:
super().__init__(size, **kwargs)
self.horizon = horizon
self.future_p = 1 - 1 / future_k
self.compute_reward_fn = compute_reward_fn
self._original_meta = Batch()
self._altered_indices = np.array([])
def _restore_cache(self) -> None:
"""Write cached original meta back to `self._meta`.
It's called every time before 'writing', 'sampling' or 'saving' the buffer.
"""
if not hasattr(self, '_altered_indices'):
return
if self._altered_indices.size == 0:
return
self._meta[self._altered_indices] = self._original_meta
# Clean
self._original_meta = Batch()
self._altered_indices = np.array([])
def reset(self, keep_statistics: bool = False) -> None:
self._restore_cache()
return super().reset(keep_statistics)
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
self._restore_cache()
return super().save_hdf5(path, compression)
def set_batch(self, batch: Batch) -> None:
self._restore_cache()
return super().set_batch(batch)
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
self._restore_cache()
return super().update(buffer)
def add(
self,
batch: Batch,
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self._restore_cache()
return super().add(batch, buffer_ids)
def sample_indices(self, batch_size: int) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an \
empty numpy array if batch_size < 0 or no available index can be sampled. \
Additionally, some episodes of the sampled transitions will be re-written \
according to HER.
"""
self._restore_cache()
indices = super().sample_indices(batch_size=batch_size)
self.rewrite_transitions(indices.copy())
return indices
def rewrite_transitions(self, indices: np.ndarray) -> None:
"""Re-write the goal of some sampled transitions' episodes according to HER.
Currently applies only HER's 'future' strategy. The new goals will be written \
directly to the internal batch data temporarily and will be restored right \
before the next sampling or when using some of the buffer's method (e.g. \
`add`, `save_hdf5`, etc.). This is to make sure that n-step returns \
calculation etc., performs correctly without additional alteration.
"""
if indices.size == 0:
return
# Sort indices keeping chronological order
indices[indices < self._index] += self.maxsize
indices = np.sort(indices)
indices[indices >= self.maxsize] -= self.maxsize
# Construct episode trajectories
indices = [indices]
for _ in range(self.horizon - 1):
indices.append(self.next(indices[-1]))
indices = np.stack(indices)
# Calculate future timestep to use
current = indices[0]
terminal = indices[-1]
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
future_offset = future_offset.astype(int)
future_t = (current + future_offset)
# Compute indices
# open indices are used to find longest, unique trajectories among
# presented episodes
unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1])
unique_ep_indices = indices[:, unique_ep_open_indices]
# close indices are used to find max future_t among presented episodes
unique_ep_close_indices = np.hstack(
[(unique_ep_open_indices - 1)[1:],
len(terminal) - 1]
)
# episode indices that will be altered
her_ep_indices = np.random.choice(
len(unique_ep_open_indices),
size=int(len(unique_ep_open_indices) * self.future_p),
replace=False
)
# Cache original meta
self._altered_indices = unique_ep_indices.copy()
self._original_meta = self._meta[self._altered_indices].copy()
# Copy original obs, ep_rew (and obs_next), and obs of future time step
ep_obs = self[unique_ep_indices].obs
ep_rew = self[unique_ep_indices].rew
if self._save_obs_next:
ep_obs_next = self[unique_ep_indices].obs_next
future_obs = self[future_t[unique_ep_close_indices]].obs_next
else:
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
# Re-assign goals and rewards via broadcast assignment
ep_obs.desired_goal[:, her_ep_indices] = \
future_obs.achieved_goal[None, her_ep_indices]
if self._save_obs_next:
ep_obs_next.desired_goal[:, her_ep_indices] = \
future_obs.achieved_goal[None, her_ep_indices]
ep_rew[:, her_ep_indices] = \
self._compute_reward(ep_obs_next)[:, her_ep_indices]
else:
tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
ep_rew[:, her_ep_indices] = \
self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
# Sanity check
assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape
assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape
assert ep_rew.shape == unique_ep_indices.shape
# Re-write meta
self._meta.obs[unique_ep_indices] = ep_obs
if self._save_obs_next:
self._meta.obs_next[unique_ep_indices] = ep_obs_next
self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)
def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray:
lead_shape = obs.observation.shape[:lead_dims]
g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:])
ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:])
rewards = self.compute_reward_fn(ag, g)
return rewards.reshape(*lead_shape, *rewards.shape[1:])
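
For reference, a `compute_reward_fn` that satisfies the contract in the docstring above (batched inputs, rewards of shape `(batch_size,)`) could look like this sparse-reward sketch (the 0.05 success threshold is an assumption borrowed from the Fetch tasks, not something the buffer requires):

```python
import numpy as np

from tianshou.data import HERReplayBuffer


def sparse_reward(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # inputs: (batch_size, ...goal_shape); output: (batch_size,)
    dist = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -(dist > 0.05).astype(np.float32)


buf = HERReplayBuffer(
    size=100_000,
    compute_reward_fn=sparse_reward,
    horizon=50,   # maximum episode length
    future_k=8,
)
```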


@ -3,7 +3,7 @@ from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from numba import njit
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer
from tianshou.data.batch import _alloc_by_keys_diff, _create_value
@ -21,7 +21,9 @@ class ReplayBufferManager(ReplayBuffer):
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
def __init__(
self, buffer_list: Union[List[ReplayBuffer], List[HERReplayBuffer]]
) -> None:
self.buffer_num = len(buffer_list)
self.buffers = np.array(buffer_list, dtype=object)
offset, size = [], 0
@ -212,6 +214,48 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage
PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs)
class HERReplayBufferManager(ReplayBufferManager):
"""HERReplayBufferManager contains a list of HERReplayBuffer with \
exactly the same configuration.
These replay buffers have contiguous memory layout, and the storage space each
buffer has is a shallow copy of the topmost memory.
:param buffer_list: a list of HERReplayBuffer needed to be handled.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(self, buffer_list: List[HERReplayBuffer]) -> None:
super().__init__(buffer_list)
def _restore_cache(self) -> None:
for buf in self.buffers:
buf._restore_cache()
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
self._restore_cache()
return super().save_hdf5(path, compression)
def set_batch(self, batch: Batch) -> None:
self._restore_cache()
return super().set_batch(batch)
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
self._restore_cache()
return super().update(buffer)
def add(
self,
batch: Batch,
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self._restore_cache()
return super().add(batch, buffer_ids)
@njit
def _prev_index(
index: np.ndarray,


@ -3,6 +3,8 @@ from typing import Any
import numpy as np
from tianshou.data import (
HERReplayBuffer,
HERReplayBufferManager,
PrioritizedReplayBuffer,
PrioritizedReplayBufferManager,
ReplayBuffer,
@ -64,3 +66,26 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
def set_beta(self, beta: float) -> None:
for buffer in self.buffers:
buffer.set_beta(beta)
class HERVectorReplayBuffer(HERReplayBufferManager):
"""HERVectorReplayBuffer contains n HERReplayBuffer with same size.
It is used for storing transition from different environments yet keeping the order
of time.
:param int total_size: the total size of HERVectorReplayBuffer.
:param int buffer_num: the number of HERReplayBuffer it uses, which are
under the same configuration.
Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
assert buffer_num > 0
size = int(np.ceil(total_size / buffer_num))
buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
super().__init__(buffer_list)


@ -1,6 +1,10 @@
"""Env package."""
from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete
from tianshou.env.gym_wrappers import (
ContinuousToDiscrete,
MultiDiscreteToDiscrete,
TruncatedAsTerminated,
)
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
from tianshou.env.venvs import (
BaseVectorEnv,
@ -26,4 +30,5 @@ __all__ = [
"PettingZooEnv",
"ContinuousToDiscrete",
"MultiDiscreteToDiscrete",
"TruncatedAsTerminated",
]


@ -1,7 +1,8 @@
from typing import List, Union
from typing import Any, Dict, List, Tuple, Union
import gym
import numpy as np
from packaging import version
class ContinuousToDiscrete(gym.ActionWrapper):
@ -55,3 +56,25 @@ class MultiDiscreteToDiscrete(gym.ActionWrapper):
converted_act.append(act // b)
act = act % b
return np.array(converted_act).transpose()
class TruncatedAsTerminated(gym.Wrapper):
"""A wrapper that set ``terminated = terminated or truncated`` for ``step()``.
It's intended to use with ``gym.wrappers.TimeLimit``.
:param gym.Env env: gym environment.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
if not version.parse(gym.__version__) >= version.parse('0.26.0'):
raise EnvironmentError(
f"TruncatedAsTerminated is not applicable with gym version \
{gym.__version__}"
)
def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]:
observation, reward, terminated, truncated, info = super().step(act)
terminated = (terminated or truncated)
return observation, reward, terminated, truncated, info


@ -89,6 +89,7 @@ class BaseLogger(ABC):
self.write("update/gradient_step", step, log_data)
self.last_log_update_step = step
@abstractmethod
def save_data(
self,
epoch: int,
@ -106,6 +107,7 @@ class BaseLogger(ABC):
"""
pass
@abstractmethod
def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.
@ -126,3 +128,15 @@ class LazyLogger(BaseLogger):
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
"""The LazyLogger writes nothing."""
pass
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
) -> None:
pass
def restore_data(self) -> Tuple[int, int, int]:
pass


@ -1,5 +1,6 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@ -14,6 +15,8 @@ import numpy as np
import torch
from torch import nn
from tianshou.data.batch import Batch
ModuleType = Type[nn.Module]
@ -262,7 +265,7 @@ class Recurrent(nn.Module):
"""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
@ -453,3 +456,61 @@ class BranchingNet(nn.Module):
action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
logits = value_out + action_scores
return logits, state
def get_dict_state_decorator(
state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str]
) -> Tuple[Callable, int]:
"""A helper function to make Net or equivalent classes (e.g. Actor, Critic) \
applicable to dict state.
The first return item, ``decorator_fn``, alters the forward function of the given
class by preprocessing the observation. The preprocessing basically flattens the
observations and concatenates them in ``keys`` order. The batch dimension is
preserved if present. The resulting observation shape will be equal to
``new_state_shape``, the second return item.
:param state_shape: A dictionary indicating each state's shape
:param keys: A list of state's keys. The flatten observation will be according to \
this list order.
:returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape``
"""
original_shape = state_shape
flat_state_shapes = []
for k in keys:
flat_state_shapes.append(int(np.prod(state_shape[k])))
new_state_shape = sum(flat_state_shapes)
def preprocess_obs(
obs: Union[Batch, dict, torch.Tensor, np.ndarray]
) -> torch.Tensor:
if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs):
if original_shape[keys[0]] == obs[keys[0]].shape:
# No batch dim
new_obs = torch.Tensor([obs[k] for k in keys]).flatten()
# new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1)
else:
bsz = obs[keys[0]].shape[0]
new_obs = torch.cat(
[torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1
)
else:
new_obs = torch.Tensor(obs)
return new_obs
@no_type_check
def decorator_fn(net_class):
class new_net_class(net_class):
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
*args,
**kwargs,
) -> Any:
return super().forward(preprocess_obs(obs), *args, **kwargs)
return new_net_class
return decorator_fn, new_state_shape
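
A usage sketch of `get_dict_state_decorator`, with made-up shapes (the example script earlier in this commit shows the real FetchReach wiring):

```python
from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor

# hypothetical goal-env shapes, purely for illustration
state_shape = {"observation": (10,), "achieved_goal": (3,), "desired_goal": (3,)}
dict_state_dec, flat_state_shape = get_dict_state_decorator(
    state_shape=state_shape,
    keys=["observation", "achieved_goal", "desired_goal"],
)
print(flat_state_shape)  # 16 = 10 + 3 + 3

net = dict_state_dec(Net)(flat_state_shape, hidden_sizes=[256, 256])
actor = dict_state_dec(Actor)(net, action_shape=(4,), max_action=1.0)
# `actor` now accepts dict observations with the three keys above and flattens
# them internally before the usual forward pass.
```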


@ -124,13 +124,13 @@ class Critic(nn.Module):
"""Mapping: (s, a) -> logits -> Q(s, a)."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
).flatten(1)
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
@ -266,7 +266,7 @@ class RecurrentActorProb(nn.Module):
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
@ -339,7 +339,7 @@ class RecurrentCritic(nn.Module):
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
@ -352,7 +352,7 @@ class RecurrentCritic(nn.Module):
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
device=self.device,
dtype=torch.float32,
)
obs = torch.cat([obs, act], dim=1)