Hindsight Experience Replay as a replay buffer (#753)
## implementation I implemented HER solely as a replay buffer. It is done by temporarily directly re-writing transitions storage (`self._meta`) during the `sample_indices()` call. The original transitions are cached and will be restored at the beginning of the next sampling or when other methods is called. This will make sure that. for example, n-step return calculation can be done without altering the policy. There is also a problem with the original indices sampling. The sampled indices are not guaranteed to be from different episodes. So I decided to perform re-writing based on the episode. This guarantees that the sampled transitions from the same episode will have the same re-written goal. This also make the re-writing ratio calculation slightly differ from the paper, but it won't be too different if there are many episodes in the buffer. In the current commit, HER replay buffer only support 'future' strategy and online sampling. This is the best of HER in term of performance and memory efficiency. I also add a few more convenient replay buffers (`HERVectorReplayBuffer`, `HERReplayBufferManager`), test env (`MyGoalEnv`), gym wrapper (`TruncatedAsTerminated`), unit tests, and a simple example (examples/offline/fetch_her_ddpg.py). ## verification I have added unit tests for almost everything I have implemented. HER replay buffer was also tested using DDPG on [`FetchReach-v3` env](https://github.com/Farama-Foundation/Gymnasium-Robotics). I used default DDPG parameters from mujoco example and didn't tune anything further to get this good result! (train script: examples/offline/fetch_her_ddpg.py). 
This commit is contained in:
parent
41ae3461f6
commit
d42a5fb354
@ -39,6 +39,7 @@
|
||||
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
|
||||
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
|
||||
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
|
||||
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
|
||||
|
||||
Here are Tianshou's other features:
|
||||
|
||||
|
||||
@ -30,6 +30,14 @@ PrioritizedReplayBuffer
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
HERReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ReplayBufferManager
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -46,6 +54,15 @@ PrioritizedReplayBufferManager
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
HERReplayBufferManager
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERReplayBufferManager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
VectorReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -62,6 +79,14 @@ PrioritizedVectorReplayBuffer
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
HERVectorReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERVectorReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
CachedReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
|
||||
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
|
||||
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
|
||||
* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay <https://arxiv.org/pdf/1707.01495.pdf>`_
|
||||
|
||||
Here is Tianshou's other features:
|
||||
|
||||
|
||||
@ -52,6 +52,7 @@ mujoco
|
||||
jit
|
||||
nstep
|
||||
preprocess
|
||||
preprocessing
|
||||
repo
|
||||
ReLU
|
||||
namespace
|
||||
|
||||
@ -20,6 +20,7 @@ Supported algorithms are listed below:
|
||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
|
||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
|
||||
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
|
||||
- [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495)
|
||||
|
||||
## EnvPool
|
||||
|
||||
@ -304,6 +305,18 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
|
||||
1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are.
|
||||
2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general.
|
||||
|
||||
## Others
|
||||
|
||||
### HER
|
||||
| Environment | DDPG without HER | DDPG with HER |
|
||||
| :--------------------: | :--------------: | :--------------: |
|
||||
| FetchReach | -49.9±0.2. | **-17.6±21.7** |
|
||||
|
||||
#### Hints for HER
|
||||
1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is ``FetchReach-v3`` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics).
|
||||
2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters for *DDPG with HER* are used in the evaluation of both settings.
|
||||
3. The scores are the mean reward ± 1 standard deviation of 16 seeds. The minimum reward for ``FetchReach-v3`` is -50 which we can imply that *DDPG without HER* performs as good as a random policy. *DDPG with HER* although has a better mean reward, the standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds.
|
||||
|
||||
## Note
|
||||
|
||||
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
|
||||
|
||||
228
examples/mujoco/fetch_her_ddpg.py
Normal file
228
examples/mujoco/fetch_her_ddpg.py
Normal file
@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import wandb
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import (
|
||||
Collector,
|
||||
HERReplayBuffer,
|
||||
HERVectorReplayBuffer,
|
||||
ReplayBuffer,
|
||||
VectorReplayBuffer,
|
||||
)
|
||||
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.common import Net, get_dict_state_decorator
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="FetchReach-v3")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
parser.add_argument("--actor-lr", type=float, default=1e-3)
|
||||
parser.add_argument("--critic-lr", type=float, default=3e-3)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--exploration-noise", type=float, default=0.1)
|
||||
parser.add_argument("--start-timesteps", type=int, default=25000)
|
||||
parser.add_argument("--epoch", type=int, default=10)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=5000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=1)
|
||||
parser.add_argument("--update-per-step", type=int, default=1)
|
||||
parser.add_argument("--n-step", type=int, default=1)
|
||||
parser.add_argument("--batch-size", type=int, default=512)
|
||||
parser.add_argument(
|
||||
"--replay-buffer", type=str, default="her", choices=["normal", "her"]
|
||||
)
|
||||
parser.add_argument("--her-horizon", type=int, default=50)
|
||||
parser.add_argument("--her-future-k", type=int, default=8)
|
||||
parser.add_argument("--training-num", type=int, default=1)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="HER-benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_fetch_env(task, training_num, test_num):
|
||||
env = TruncatedAsTerminated(gym.make(task))
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)]
|
||||
)
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)]
|
||||
)
|
||||
return env, train_envs, test_envs
|
||||
|
||||
|
||||
def test_ddpg(args=get_args()):
|
||||
# log
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "ddpg"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
logger.wandb_run.config.setdefaults(vars(args))
|
||||
args = argparse.Namespace(**wandb.config)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
env, train_envs, test_envs = make_fetch_env(
|
||||
args.task, args.training_num, args.test_num
|
||||
)
|
||||
args.state_shape = {
|
||||
'observation': env.observation_space['observation'].shape,
|
||||
'achieved_goal': env.observation_space['achieved_goal'].shape,
|
||||
'desired_goal': env.observation_space['desired_goal'].shape,
|
||||
}
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
args.exploration_noise = args.exploration_noise * args.max_action
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
# model
|
||||
dict_state_dec, flat_state_shape = get_dict_state_decorator(
|
||||
state_shape=args.state_shape,
|
||||
keys=['observation', 'achieved_goal', 'desired_goal']
|
||||
)
|
||||
net_a = dict_state_dec(Net)(
|
||||
flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device
|
||||
)
|
||||
actor = dict_state_dec(Actor)(
|
||||
net_a, args.action_shape, max_action=args.max_action, device=args.device
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
net_c = dict_state_dec(Net)(
|
||||
flat_state_shape,
|
||||
action_shape=args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device)
|
||||
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||
policy = DDPGPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic,
|
||||
critic_optim,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||
estimation_step=args.n_step,
|
||||
action_space=env.action_space,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
def compute_reward_fn(ag: np.ndarray, g: np.ndarray):
|
||||
return env.compute_reward(ag, g, {})
|
||||
|
||||
if args.replay_buffer == "normal":
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
else:
|
||||
if args.training_num > 1:
|
||||
buffer = HERVectorReplayBuffer(
|
||||
args.buffer_size,
|
||||
len(train_envs),
|
||||
compute_reward_fn=compute_reward_fn,
|
||||
horizon=args.her_horizon,
|
||||
future_k=args.her_future_k,
|
||||
)
|
||||
else:
|
||||
buffer = HERReplayBuffer(
|
||||
args.buffer_size,
|
||||
compute_reward_fn=compute_reward_fn,
|
||||
horizon=args.her_horizon,
|
||||
future_k=args.her_future_k,
|
||||
)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
if not args.watch:
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
update_per_step=args.update_per_step,
|
||||
test_in_train=False,
|
||||
)
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ddpg()
|
||||
@ -166,3 +166,48 @@ class NXEnv(gym.Env):
|
||||
for i in range(self.size):
|
||||
self.graph.nodes[i]["data"] = next_graph_state[i]
|
||||
return self._encode_obs(), 1.0, 0, 0, {}
|
||||
|
||||
|
||||
class MyGoalEnv(MyTestEnv):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \
|
||||
"dict_state / recurse_state not supported"
|
||||
super().__init__(*args, **kwargs)
|
||||
obs, _ = super().reset(state=0)
|
||||
obs, _, _, _, _ = super().step(1)
|
||||
self._goal = obs * self.size
|
||||
super_obsv = self.observation_space
|
||||
self.observation_space = gym.spaces.Dict(
|
||||
{
|
||||
'observation': super_obsv,
|
||||
'achieved_goal': super_obsv,
|
||||
'desired_goal': super_obsv,
|
||||
}
|
||||
)
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
obs, info = super().reset(*args, **kwargs)
|
||||
new_obs = {
|
||||
'observation': obs,
|
||||
'achieved_goal': obs,
|
||||
'desired_goal': self._goal
|
||||
}
|
||||
return new_obs, info
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs)
|
||||
new_obs_next = {
|
||||
'observation': obs_next,
|
||||
'achieved_goal': obs_next,
|
||||
'desired_goal': self._goal
|
||||
}
|
||||
return new_obs_next, rew, terminated, truncated, info
|
||||
|
||||
def compute_reward_fn(
|
||||
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict
|
||||
) -> np.ndarray:
|
||||
axis = -1
|
||||
if self.array_state:
|
||||
axis = (-3, -2, -1)
|
||||
return (achieved_goal == desired_goal).all(axis=axis)
|
||||
|
||||
@ -11,6 +11,8 @@ import torch
|
||||
from tianshou.data import (
|
||||
Batch,
|
||||
CachedReplayBuffer,
|
||||
HERReplayBuffer,
|
||||
HERVectorReplayBuffer,
|
||||
PrioritizedReplayBuffer,
|
||||
PrioritizedVectorReplayBuffer,
|
||||
ReplayBuffer,
|
||||
@ -20,9 +22,9 @@ from tianshou.data import (
|
||||
from tianshou.data.utils.converter import to_hdf5
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
from env import MyGoalEnv, MyTestEnv
|
||||
else: # pytest
|
||||
from test.base.env import MyTestEnv
|
||||
from test.base.env import MyGoalEnv, MyTestEnv
|
||||
|
||||
|
||||
def test_replaybuffer(size=10, bufsize=20):
|
||||
@ -300,6 +302,142 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
|
||||
|
||||
|
||||
def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4):
|
||||
env_size = size
|
||||
env = MyGoalEnv(env_size, array_state=True)
|
||||
|
||||
def compute_reward_fn(ag, g):
|
||||
return env.compute_reward_fn(ag, g, {})
|
||||
|
||||
buf = HERReplayBuffer(
|
||||
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
|
||||
)
|
||||
buf2 = HERVectorReplayBuffer(
|
||||
bufsize,
|
||||
buffer_num=3,
|
||||
compute_reward_fn=compute_reward_fn,
|
||||
horizon=30,
|
||||
future_k=8
|
||||
)
|
||||
# Apply her on every episodes sampled (Hacky but necessary for deterministic test)
|
||||
buf.future_p = 1
|
||||
for buf2_buf in buf2.buffers:
|
||||
buf2_buf.future_p = 1
|
||||
|
||||
obs, _ = env.reset()
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||
for i, act in enumerate(action_list):
|
||||
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||
batch = Batch(
|
||||
obs=obs,
|
||||
act=[act],
|
||||
rew=rew,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
obs_next=obs_next,
|
||||
info=info
|
||||
)
|
||||
buf.add(batch)
|
||||
buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2])
|
||||
obs = obs_next
|
||||
assert len(buf) == min(bufsize, i + 1)
|
||||
assert len(buf2) == min(bufsize, 3 * (i + 1))
|
||||
|
||||
batch, indices = buf.sample(sample_sz)
|
||||
|
||||
# Check that goals are the same for the episode (only 1 ep in buffer)
|
||||
tmp_indices = indices.copy()
|
||||
for _ in range(2 * env_size):
|
||||
obs = buf[tmp_indices].obs
|
||||
obs_next = buf[tmp_indices].obs_next
|
||||
rew = buf[tmp_indices].rew
|
||||
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
|
||||
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
assert np.all(g == g[0])
|
||||
assert np.all(g_next == g_next[0])
|
||||
assert np.all(rew == (ag_next == g).astype(np.float32))
|
||||
tmp_indices = buf.next(tmp_indices)
|
||||
|
||||
# Check that goals are correctly restored
|
||||
buf._restore_cache()
|
||||
tmp_indices = indices.copy()
|
||||
for _ in range(2 * env_size):
|
||||
obs = buf[tmp_indices].obs
|
||||
obs_next = buf[tmp_indices].obs_next
|
||||
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
assert np.all(g == env_size)
|
||||
assert np.all(g_next == g_next[0])
|
||||
assert np.all(g == g[0])
|
||||
tmp_indices = buf.next(tmp_indices)
|
||||
|
||||
# Test vector buffer
|
||||
batch, indices = buf2.sample(sample_sz)
|
||||
|
||||
# Check that goals are the same for the episode (only 1 ep in buffer)
|
||||
tmp_indices = indices.copy()
|
||||
for _ in range(2 * env_size):
|
||||
obs = buf2[tmp_indices].obs
|
||||
obs_next = buf2[tmp_indices].obs_next
|
||||
rew = buf2[tmp_indices].rew
|
||||
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
|
||||
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
assert np.all(g == g_next)
|
||||
assert np.all(rew == (ag_next == g).astype(np.float32))
|
||||
tmp_indices = buf2.next(tmp_indices)
|
||||
|
||||
# Check that goals are correctly restored
|
||||
buf2._restore_cache()
|
||||
tmp_indices = indices.copy()
|
||||
for _ in range(2 * env_size):
|
||||
obs = buf2[tmp_indices].obs
|
||||
obs_next = buf2[tmp_indices].obs_next
|
||||
g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
|
||||
assert np.all(g == env_size)
|
||||
assert np.all(g_next == g_next[0])
|
||||
assert np.all(g == g[0])
|
||||
tmp_indices = buf2.next(tmp_indices)
|
||||
|
||||
# Test handling cycled indices
|
||||
env_size = size
|
||||
bufsize = 15
|
||||
env = MyGoalEnv(env_size, array_state=False)
|
||||
|
||||
def compute_reward_fn(ag, g):
|
||||
return env.compute_reward_fn(ag, g, {})
|
||||
|
||||
buf = HERReplayBuffer(
|
||||
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
|
||||
)
|
||||
buf._index = 5 # shifted start index
|
||||
buf.future_p = 1
|
||||
action_list = [1] * 10
|
||||
for ep_len in [5, 10]:
|
||||
obs, _ = env.reset()
|
||||
for i in range(ep_len):
|
||||
act = 1
|
||||
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||
batch = Batch(
|
||||
obs=obs,
|
||||
act=[act],
|
||||
rew=rew,
|
||||
terminated=(i == ep_len - 1),
|
||||
truncated=(i == ep_len - 1),
|
||||
obs_next=obs_next,
|
||||
info=info
|
||||
)
|
||||
buf.add(batch)
|
||||
obs = obs_next
|
||||
batch, indices = buf.sample(0)
|
||||
assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal)
|
||||
assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal)
|
||||
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
|
||||
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
|
||||
|
||||
|
||||
def test_update():
|
||||
buf1 = ReplayBuffer(4, stack_num=2)
|
||||
buf2 = ReplayBuffer(4, stack_num=2)
|
||||
@ -1180,3 +1318,4 @@ if __name__ == '__main__':
|
||||
test_multibuf_stack()
|
||||
test_multibuf_hdf5()
|
||||
test_from_data()
|
||||
test_herreplaybuffer()
|
||||
|
||||
@ -16,6 +16,7 @@ from tianshou.env import (
|
||||
SubprocVectorEnv,
|
||||
VectorEnvNormObs,
|
||||
)
|
||||
from tianshou.env.gym_wrappers import TruncatedAsTerminated
|
||||
from tianshou.utils import RunningMeanStd
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -347,6 +348,10 @@ def test_gym_wrappers():
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
|
||||
)
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
|
||||
def step(self, act):
|
||||
return self.observation_space.sample(), -1, False, True, {}
|
||||
|
||||
bsz = 10
|
||||
action_per_branch = [4, 6, 10, 7]
|
||||
@ -374,6 +379,14 @@ def test_gym_wrappers():
|
||||
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
|
||||
np.array([env_m.action_space.nvec - 1] * bsz),
|
||||
)
|
||||
# check truncate is True when terminated
|
||||
try:
|
||||
env_t = TruncatedAsTerminated(env)
|
||||
except EnvironmentError:
|
||||
env_t = None
|
||||
if env_t is not None:
|
||||
_, _, truncated, _, _ = env_t.step(env_t.action_space.sample())
|
||||
assert truncated
|
||||
|
||||
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
|
||||
@ -6,13 +6,16 @@ from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
|
||||
from tianshou.data.utils.segtree import SegmentTree
|
||||
from tianshou.data.buffer.base import ReplayBuffer
|
||||
from tianshou.data.buffer.prio import PrioritizedReplayBuffer
|
||||
from tianshou.data.buffer.her import HERReplayBuffer
|
||||
from tianshou.data.buffer.manager import (
|
||||
ReplayBufferManager,
|
||||
PrioritizedReplayBufferManager,
|
||||
HERReplayBufferManager,
|
||||
)
|
||||
from tianshou.data.buffer.vecbuf import (
|
||||
VectorReplayBuffer,
|
||||
HERVectorReplayBuffer,
|
||||
PrioritizedVectorReplayBuffer,
|
||||
VectorReplayBuffer,
|
||||
)
|
||||
from tianshou.data.buffer.cached import CachedReplayBuffer
|
||||
from tianshou.data.collector import Collector, AsyncCollector
|
||||
@ -25,10 +28,13 @@ __all__ = [
|
||||
"SegmentTree",
|
||||
"ReplayBuffer",
|
||||
"PrioritizedReplayBuffer",
|
||||
"HERReplayBuffer",
|
||||
"ReplayBufferManager",
|
||||
"PrioritizedReplayBufferManager",
|
||||
"HERReplayBufferManager",
|
||||
"VectorReplayBuffer",
|
||||
"PrioritizedVectorReplayBuffer",
|
||||
"HERVectorReplayBuffer",
|
||||
"CachedReplayBuffer",
|
||||
"Collector",
|
||||
"AsyncCollector",
|
||||
|
||||
186
tianshou/data/buffer/her.py
Normal file
186
tianshou/data/buffer/her.py
Normal file
@ -0,0 +1,186 @@
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
|
||||
|
||||
class HERReplayBuffer(ReplayBuffer):
|
||||
"""Implementation of Hindsight Experience Replay. arXiv:1707.01495.
|
||||
|
||||
HERReplayBuffer is to be used with goal-based environment where the
|
||||
observation is a dictionary with keys ``observation``, ``achieved_goal`` and
|
||||
``desired_goal``. Currently support only HER's future strategy, online sampling.
|
||||
|
||||
:param int size: the size of the replay buffer.
|
||||
:param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
|
||||
``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
|
||||
The two arguments are of shape (batch_size, ...original_shape) and the returned
|
||||
rewards must be of shape (batch_size,).
|
||||
:param int horizon: the maximum number of steps in an episode.
|
||||
:param int future_k: the 'k' parameter introduced in the paper. In short, there
|
||||
will be at most k episodes that are re-written for every 1 unaltered episode
|
||||
during the sampling.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||
horizon: int,
|
||||
future_k: float = 8.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(size, **kwargs)
|
||||
self.horizon = horizon
|
||||
self.future_p = 1 - 1 / future_k
|
||||
self.compute_reward_fn = compute_reward_fn
|
||||
self._original_meta = Batch()
|
||||
self._altered_indices = np.array([])
|
||||
|
||||
def _restore_cache(self) -> None:
|
||||
"""Write cached original meta back to `self._meta`.
|
||||
|
||||
It's called everytime before 'writing', 'sampling' or 'saving' the buffer.
|
||||
"""
|
||||
if not hasattr(self, '_altered_indices'):
|
||||
return
|
||||
|
||||
if self._altered_indices.size == 0:
|
||||
return
|
||||
self._meta[self._altered_indices] = self._original_meta
|
||||
# Clean
|
||||
self._original_meta = Batch()
|
||||
self._altered_indices = np.array([])
|
||||
|
||||
def reset(self, keep_statistics: bool = False) -> None:
|
||||
self._restore_cache()
|
||||
return super().reset(keep_statistics)
|
||||
|
||||
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
|
||||
self._restore_cache()
|
||||
return super().save_hdf5(path, compression)
|
||||
|
||||
def set_batch(self, batch: Batch) -> None:
|
||||
self._restore_cache()
|
||||
return super().set_batch(batch)
|
||||
|
||||
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
|
||||
self._restore_cache()
|
||||
return super().update(buffer)
|
||||
|
||||
def add(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
self._restore_cache()
|
||||
return super().add(batch, buffer_ids)
|
||||
|
||||
def sample_indices(self, batch_size: int) -> np.ndarray:
|
||||
"""Get a random sample of index with size = batch_size.
|
||||
|
||||
Return all available indices in the buffer if batch_size is 0; return an \
|
||||
empty numpy array if batch_size < 0 or no available index can be sampled. \
|
||||
Additionally, some episodes of the sampled transitions will be re-written \
|
||||
according to HER.
|
||||
"""
|
||||
self._restore_cache()
|
||||
indices = super().sample_indices(batch_size=batch_size)
|
||||
self.rewrite_transitions(indices.copy())
|
||||
return indices
|
||||
|
||||
def rewrite_transitions(self, indices: np.ndarray) -> None:
|
||||
"""Re-write the goal of some sampled transitions' episodes according to HER.
|
||||
|
||||
Currently applies only HER's 'future' strategy. The new goals will be written \
|
||||
directly to the internal batch data temporarily and will be restored right \
|
||||
before the next sampling or when using some of the buffer's method (e.g. \
|
||||
`add`, `save_hdf5`, etc.). This is to make sure that n-step returns \
|
||||
calculation etc., performs correctly without additional alteration.
|
||||
"""
|
||||
if indices.size == 0:
|
||||
return
|
||||
|
||||
# Sort indices keeping chronological order
|
||||
indices[indices < self._index] += self.maxsize
|
||||
indices = np.sort(indices)
|
||||
indices[indices >= self.maxsize] -= self.maxsize
|
||||
|
||||
# Construct episode trajectories
|
||||
indices = [indices]
|
||||
for _ in range(self.horizon - 1):
|
||||
indices.append(self.next(indices[-1]))
|
||||
indices = np.stack(indices)
|
||||
|
||||
# Calculate future timestep to use
|
||||
current = indices[0]
|
||||
terminal = indices[-1]
|
||||
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
|
||||
future_offset = future_offset.astype(int)
|
||||
future_t = (current + future_offset)
|
||||
|
||||
# Compute indices
|
||||
# open indices are used to find longest, unique trajectories among
|
||||
# presented episodes
|
||||
unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1])
|
||||
unique_ep_indices = indices[:, unique_ep_open_indices]
|
||||
# close indices are used to find max future_t among presented episodes
|
||||
unique_ep_close_indices = np.hstack(
|
||||
[(unique_ep_open_indices - 1)[1:],
|
||||
len(terminal) - 1]
|
||||
)
|
||||
# episode indices that will be altered
|
||||
her_ep_indices = np.random.choice(
|
||||
len(unique_ep_open_indices),
|
||||
size=int(len(unique_ep_open_indices) * self.future_p),
|
||||
replace=False
|
||||
)
|
||||
|
||||
# Cache original meta
|
||||
self._altered_indices = unique_ep_indices.copy()
|
||||
self._original_meta = self._meta[self._altered_indices].copy()
|
||||
|
||||
# Copy original obs, ep_rew (and obs_next), and obs of future time step
|
||||
ep_obs = self[unique_ep_indices].obs
|
||||
ep_rew = self[unique_ep_indices].rew
|
||||
if self._save_obs_next:
|
||||
ep_obs_next = self[unique_ep_indices].obs_next
|
||||
future_obs = self[future_t[unique_ep_close_indices]].obs_next
|
||||
else:
|
||||
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
|
||||
|
||||
# Re-assign goals and rewards via broadcast assignment
|
||||
ep_obs.desired_goal[:, her_ep_indices] = \
|
||||
future_obs.achieved_goal[None, her_ep_indices]
|
||||
if self._save_obs_next:
|
||||
ep_obs_next.desired_goal[:, her_ep_indices] = \
|
||||
future_obs.achieved_goal[None, her_ep_indices]
|
||||
ep_rew[:, her_ep_indices] = \
|
||||
self._compute_reward(ep_obs_next)[:, her_ep_indices]
|
||||
else:
|
||||
tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
|
||||
ep_rew[:, her_ep_indices] = \
|
||||
self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
|
||||
|
||||
# Sanity check
|
||||
assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape
|
||||
assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape
|
||||
assert ep_rew.shape == unique_ep_indices.shape
|
||||
|
||||
# Re-write meta
|
||||
self._meta.obs[unique_ep_indices] = ep_obs
|
||||
if self._save_obs_next:
|
||||
self._meta.obs_next[unique_ep_indices] = ep_obs_next
|
||||
self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)
|
||||
|
||||
def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray:
|
||||
lead_shape = obs.observation.shape[:lead_dims]
|
||||
g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:])
|
||||
ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:])
|
||||
rewards = self.compute_reward_fn(ag, g)
|
||||
return rewards.reshape(*lead_shape, *rewards.shape[1:])
|
||||
@ -3,7 +3,7 @@ from typing import List, Optional, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
|
||||
from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer
|
||||
from tianshou.data.batch import _alloc_by_keys_diff, _create_value
|
||||
|
||||
|
||||
@ -21,7 +21,9 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
|
||||
def __init__(
|
||||
self, buffer_list: Union[List[ReplayBuffer], List[HERReplayBuffer]]
|
||||
) -> None:
|
||||
self.buffer_num = len(buffer_list)
|
||||
self.buffers = np.array(buffer_list, dtype=object)
|
||||
offset, size = [], 0
|
||||
@ -212,6 +214,48 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage
|
||||
PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs)
|
||||
|
||||
|
||||
class HERReplayBufferManager(ReplayBufferManager):
|
||||
"""HERReplayBufferManager contains a list of HERReplayBuffer with \
|
||||
exactly the same configuration.
|
||||
|
||||
These replay buffers have contiguous memory layout, and the storage space each
|
||||
buffer has is a shallow copy of the topmost memory.
|
||||
|
||||
:param buffer_list: a list of HERReplayBuffer needed to be handled.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_list: List[HERReplayBuffer]) -> None:
|
||||
super().__init__(buffer_list)
|
||||
|
||||
def _restore_cache(self) -> None:
|
||||
for buf in self.buffers:
|
||||
buf._restore_cache()
|
||||
|
||||
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
|
||||
self._restore_cache()
|
||||
return super().save_hdf5(path, compression)
|
||||
|
||||
def set_batch(self, batch: Batch) -> None:
|
||||
self._restore_cache()
|
||||
return super().set_batch(batch)
|
||||
|
||||
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
|
||||
self._restore_cache()
|
||||
return super().update(buffer)
|
||||
|
||||
def add(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
self._restore_cache()
|
||||
return super().add(batch, buffer_ids)
|
||||
|
||||
|
||||
@njit
|
||||
def _prev_index(
|
||||
index: np.ndarray,
|
||||
|
||||
@ -3,6 +3,8 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from tianshou.data import (
|
||||
HERReplayBuffer,
|
||||
HERReplayBufferManager,
|
||||
PrioritizedReplayBuffer,
|
||||
PrioritizedReplayBufferManager,
|
||||
ReplayBuffer,
|
||||
@ -64,3 +66,26 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
|
||||
def set_beta(self, beta: float) -> None:
|
||||
for buffer in self.buffers:
|
||||
buffer.set_beta(beta)
|
||||
|
||||
|
||||
class HERVectorReplayBuffer(HERReplayBufferManager):
|
||||
"""HERVectorReplayBuffer contains n HERReplayBuffer with same size.
|
||||
|
||||
It is used for storing transition from different environments yet keeping the order
|
||||
of time.
|
||||
|
||||
:param int total_size: the total size of HERVectorReplayBuffer.
|
||||
:param int buffer_num: the number of HERReplayBuffer it uses, which are
|
||||
under the same configuration.
|
||||
|
||||
Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`.
|
||||
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
|
||||
assert buffer_num > 0
|
||||
size = int(np.ceil(total_size / buffer_num))
|
||||
buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
|
||||
super().__init__(buffer_list)
|
||||
|
||||
7
tianshou/env/__init__.py
vendored
7
tianshou/env/__init__.py
vendored
@ -1,6 +1,10 @@
|
||||
"""Env package."""
|
||||
|
||||
from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete
|
||||
from tianshou.env.gym_wrappers import (
|
||||
ContinuousToDiscrete,
|
||||
MultiDiscreteToDiscrete,
|
||||
TruncatedAsTerminated,
|
||||
)
|
||||
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
|
||||
from tianshou.env.venvs import (
|
||||
BaseVectorEnv,
|
||||
@ -26,4 +30,5 @@ __all__ = [
|
||||
"PettingZooEnv",
|
||||
"ContinuousToDiscrete",
|
||||
"MultiDiscreteToDiscrete",
|
||||
"TruncatedAsTerminated",
|
||||
]
|
||||
|
||||
25
tianshou/env/gym_wrappers.py
vendored
25
tianshou/env/gym_wrappers.py
vendored
@ -1,7 +1,8 @@
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
|
||||
class ContinuousToDiscrete(gym.ActionWrapper):
|
||||
@ -55,3 +56,25 @@ class MultiDiscreteToDiscrete(gym.ActionWrapper):
|
||||
converted_act.append(act // b)
|
||||
act = act % b
|
||||
return np.array(converted_act).transpose()
|
||||
|
||||
|
||||
class TruncatedAsTerminated(gym.Wrapper):
|
||||
"""A wrapper that set ``terminated = terminated or truncated`` for ``step()``.
|
||||
|
||||
It's intended to use with ``gym.wrappers.TimeLimit``.
|
||||
|
||||
:param gym.Env env: gym environment.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
super().__init__(env)
|
||||
if not version.parse(gym.__version__) >= version.parse('0.26.0'):
|
||||
raise EnvironmentError(
|
||||
f"TruncatedAsTerminated is not applicable with gym version \
|
||||
{gym.__version__}"
|
||||
)
|
||||
|
||||
def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]:
|
||||
observation, reward, terminated, truncated, info = super().step(act)
|
||||
terminated = (terminated or truncated)
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
@ -89,6 +89,7 @@ class BaseLogger(ABC):
|
||||
self.write("update/gradient_step", step, log_data)
|
||||
self.last_log_update_step = step
|
||||
|
||||
@abstractmethod
|
||||
def save_data(
|
||||
self,
|
||||
epoch: int,
|
||||
@ -106,6 +107,7 @@ class BaseLogger(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def restore_data(self) -> Tuple[int, int, int]:
|
||||
"""Return the metadata from existing log.
|
||||
|
||||
@ -126,3 +128,15 @@ class LazyLogger(BaseLogger):
|
||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||
"""The LazyLogger writes nothing."""
|
||||
pass
|
||||
|
||||
def save_data(
|
||||
self,
|
||||
epoch: int,
|
||||
env_step: int,
|
||||
gradient_step: int,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def restore_data(self) -> Tuple[int, int, int]:
|
||||
pass
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
@ -14,6 +15,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tianshou.data.batch import Batch
|
||||
|
||||
ModuleType = Type[nn.Module]
|
||||
|
||||
|
||||
@ -262,7 +265,7 @@ class Recurrent(nn.Module):
|
||||
"""
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
@ -453,3 +456,61 @@ class BranchingNet(nn.Module):
|
||||
action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
|
||||
logits = value_out + action_scores
|
||||
return logits, state
|
||||
|
||||
|
||||
def get_dict_state_decorator(
|
||||
state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str]
|
||||
) -> Tuple[Callable, int]:
|
||||
"""A helper function to make Net or equivalent classes (e.g. Actor, Critic) \
|
||||
applicable to dict state.
|
||||
|
||||
The first return item, ``decorator_fn``, will alter the implementation of forward
|
||||
function of the given class by preprocessing the observation. The preprocessing is
|
||||
basically flatten the observation and concatenate them based on the ``keys`` order.
|
||||
The batch dimension is preserved if presented. The result observation shape will
|
||||
be equal to ``new_state_shape``, the second return item.
|
||||
|
||||
:param state_shape: A dictionary indicating each state's shape
|
||||
:param keys: A list of state's keys. The flatten observation will be according to \
|
||||
this list order.
|
||||
:returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape``
|
||||
"""
|
||||
original_shape = state_shape
|
||||
flat_state_shapes = []
|
||||
for k in keys:
|
||||
flat_state_shapes.append(int(np.prod(state_shape[k])))
|
||||
new_state_shape = sum(flat_state_shapes)
|
||||
|
||||
def preprocess_obs(
|
||||
obs: Union[Batch, dict, torch.Tensor, np.ndarray]
|
||||
) -> torch.Tensor:
|
||||
if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs):
|
||||
if original_shape[keys[0]] == obs[keys[0]].shape:
|
||||
# No batch dim
|
||||
new_obs = torch.Tensor([obs[k] for k in keys]).flatten()
|
||||
# new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1)
|
||||
else:
|
||||
bsz = obs[keys[0]].shape[0]
|
||||
new_obs = torch.cat(
|
||||
[torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1
|
||||
)
|
||||
else:
|
||||
new_obs = torch.Tensor(obs)
|
||||
return new_obs
|
||||
|
||||
@no_type_check
|
||||
def decorator_fn(net_class):
|
||||
|
||||
class new_net_class(net_class):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: Union[np.ndarray, torch.Tensor],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
return super().forward(preprocess_obs(obs), *args, **kwargs)
|
||||
|
||||
return new_net_class
|
||||
|
||||
return decorator_fn, new_state_shape
|
||||
|
||||
@ -124,13 +124,13 @@ class Critic(nn.Module):
|
||||
"""Mapping: (s, a) -> logits -> Q(s, a)."""
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
).flatten(1)
|
||||
if act is not None:
|
||||
act = torch.as_tensor(
|
||||
act,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
).flatten(1)
|
||||
obs = torch.cat([obs, act], dim=1)
|
||||
@ -266,7 +266,7 @@ class RecurrentActorProb(nn.Module):
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
@ -339,7 +339,7 @@ class RecurrentCritic(nn.Module):
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
@ -352,7 +352,7 @@ class RecurrentCritic(nn.Module):
|
||||
if act is not None:
|
||||
act = torch.as_tensor(
|
||||
act,
|
||||
device=self.device, # type: ignore
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
obs = torch.cat([obs, act], dim=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user