implement REDQ based on original contribution by @Jimenius (#623)
Co-authored-by: Minhui Li <limh@lamda.nju.edu.cn>
This commit is contained in:
parent
41afc2584a
commit
dd16818ce4
@ -34,6 +34,7 @@
|
||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||
- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf)
|
||||
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
||||
- Vanilla Imitation Learning
|
||||
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
||||
|
@ -96,6 +96,11 @@ Off-policy
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.REDQPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteSACPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
@ -25,6 +25,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning <https://arxiv.org/pdf/2101.05982.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
||||
|
@ -157,3 +157,4 @@ Nvidia
|
||||
Enduro
|
||||
Qbert
|
||||
Seaquest
|
||||
subnets
|
||||
|
192
examples/mujoco/mujoco_redq.py
Executable file
192
examples/mujoco/mujoco_redq.py
Executable file
@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.policy import REDQPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import EnsembleLinear, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Ant-v3')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||
parser.add_argument('--ensemble-size', type=int, default=10)
|
||||
parser.add_argument('--subset-size', type=int, default=2)
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--auto-alpha', default=False, action='store_true')
|
||||
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||
parser.add_argument('--epoch', type=int, default=200)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||
parser.add_argument('--update-per-step', type=int, default=20)
|
||||
parser.add_argument('--n-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument(
|
||||
'--target-mode', type=str, choices=('min', 'mean'), default='min'
|
||||
)
|
||||
parser.add_argument('--training-num', type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_redq(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
if args.training_num > 1:
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||
)
|
||||
else:
|
||||
train_envs = gym.make(args.task)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = ActorProb(
|
||||
net_a,
|
||||
args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
unbounded=True,
|
||||
conditioned_sigma=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
def linear(x, y):
|
||||
return EnsembleLinear(args.ensemble_size, x, y)
|
||||
|
||||
net_c = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
linear_layer=linear,
|
||||
)
|
||||
critics = Critic(
|
||||
net_c,
|
||||
device=args.device,
|
||||
linear_layer=linear,
|
||||
flatten_input=False,
|
||||
).to(args.device)
|
||||
critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr)
|
||||
|
||||
if args.auto_alpha:
|
||||
target_entropy = -np.prod(env.action_space.shape)
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = REDQPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critics,
|
||||
critics_optim,
|
||||
args.ensemble_size,
|
||||
args.subset_size,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
alpha=args.alpha,
|
||||
estimation_step=args.n_step,
|
||||
actor_delay=args.update_per_step,
|
||||
target_mode=args.target_mode,
|
||||
action_space=env.action_space,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_redq'
|
||||
log_path = os.path.join(args.logdir, args.task, 'redq', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
if not args.watch:
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
update_per_step=args.update_per_step,
|
||||
test_in_train=False
|
||||
)
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_redq()
|
178
test/continuous/test_redq.py
Normal file
178
test/continuous/test_redq.py
Normal file
@ -0,0 +1,178 @@
|
||||
import argparse
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import REDQPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import EnsembleLinear, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v1')
|
||||
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--ensemble-size', type=int, default=4)
|
||||
parser.add_argument('--subset-size', type=int, default=2)
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-4)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--auto-alpha', action='store_true', default=False)
|
||||
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||
parser.add_argument("--start-timesteps", type=int, default=1000)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||
parser.add_argument('--update-per-step', type=int, default=3)
|
||||
parser.add_argument('--n-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument(
|
||||
'--target-mode', type=str, choices=('min', 'mean'), default='min'
|
||||
)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=8)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_redq(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
if args.reward_threshold is None:
|
||||
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
|
||||
args.reward_threshold = default_reward_threshold.get(
|
||||
args.task, env.spec.reward_threshold
|
||||
)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||
)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = ActorProb(
|
||||
net,
|
||||
args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
unbounded=True,
|
||||
conditioned_sigma=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
def linear(x, y):
|
||||
return EnsembleLinear(args.ensemble_size, x, y)
|
||||
|
||||
net_c = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
linear_layer=linear,
|
||||
)
|
||||
critic = Critic(
|
||||
net_c, device=args.device, linear_layer=linear, flatten_input=False
|
||||
).to(args.device)
|
||||
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||
|
||||
if args.auto_alpha:
|
||||
target_entropy = -np.prod(env.action_space.shape)
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = REDQPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic,
|
||||
critic_optim,
|
||||
args.ensemble_size,
|
||||
args.subset_size,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
alpha=args.alpha,
|
||||
estimation_step=args.n_step,
|
||||
actor_delay=args.update_per_step,
|
||||
target_mode=args.target_mode,
|
||||
action_space=env.action_space,
|
||||
)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'redq')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
update_per_step=args.update_per_step,
|
||||
stop_fn=stop_fn,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_redq()
|
9
tianshou/env/pettingzoo_env.py
vendored
9
tianshou/env/pettingzoo_env.py
vendored
@ -55,8 +55,8 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> dict:
|
||||
self.env.reset()
|
||||
def reset(self, *args: Any, **kwargs: Any) -> dict:
|
||||
self.env.reset(*args, **kwargs)
|
||||
observation = self.env.observe(self.env.agent_selection)
|
||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||
return {
|
||||
@ -103,7 +103,10 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
self.env.close()
|
||||
|
||||
def seed(self, seed: Any = None) -> None:
|
||||
self.env.seed(seed)
|
||||
try:
|
||||
self.env.seed(seed)
|
||||
except NotImplementedError:
|
||||
self.env.reset(seed=seed)
|
||||
|
||||
def render(self, mode: str = "human") -> Any:
|
||||
return self.env.render(mode)
|
||||
|
@ -17,6 +17,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
|
||||
from tianshou.policy.modelfree.trpo import TRPOPolicy
|
||||
from tianshou.policy.modelfree.td3 import TD3Policy
|
||||
from tianshou.policy.modelfree.sac import SACPolicy
|
||||
from tianshou.policy.modelfree.redq import REDQPolicy
|
||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
from tianshou.policy.imitation.base import ImitationPolicy
|
||||
from tianshou.policy.imitation.bcq import BCQPolicy
|
||||
@ -46,6 +47,7 @@ __all__ = [
|
||||
"TRPOPolicy",
|
||||
"TD3Policy",
|
||||
"SACPolicy",
|
||||
"REDQPolicy",
|
||||
"DiscreteSACPolicy",
|
||||
"ImitationPolicy",
|
||||
"BCQPolicy",
|
||||
|
200
tianshou/policy/modelfree/redq.py
Normal file
200
tianshou/policy/modelfree/redq.py
Normal file
@ -0,0 +1,200 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.policy import DDPGPolicy
|
||||
|
||||
|
||||
class REDQPolicy(DDPGPolicy):
|
||||
"""Implementation of REDQ. arXiv:2101.05982.
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critics: critic ensemble networks.
|
||||
:param torch.optim.Optimizer critics_optim: the optimizer for the critic networks.
|
||||
:param int ensemble_size: Number of sub-networks in the critic ensemble.
|
||||
Default to 10.
|
||||
:param int subset_size: Number of networks in the subset. Default to 2.
|
||||
:param float tau: param for soft update of the target network. Default to 0.005.
|
||||
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
||||
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
|
||||
regularization coefficient. Default to 0.2.
|
||||
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
|
||||
alpha is automatically tuned.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param int actor_delay: Number of critic updates before an actor update.
|
||||
Default to 20.
|
||||
:param BaseNoise exploration_noise: add a noise to action for exploration.
|
||||
Default to None. This is useful when solving hard-exploration problem.
|
||||
:param bool deterministic_eval: whether to use deterministic action (mean
|
||||
of Gaussian policy) instead of stochastic action sampled by the policy.
|
||||
Default to True.
|
||||
:param str target_mode: methods to integrate critic values in the subset,
|
||||
currently support minimum and average. Default to min.
|
||||
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||
[action_spaces.low, action_spaces.high]. Default to True.
|
||||
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||
either "clip" (for simply clipping the action) or empty string for no bounding.
|
||||
Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critics: torch.nn.Module,
|
||||
critics_optim: torch.optim.Optimizer,
|
||||
ensemble_size: int = 10,
|
||||
subset_size: int = 2,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
|
||||
reward_normalization: bool = False,
|
||||
estimation_step: int = 1,
|
||||
actor_delay: int = 20,
|
||||
exploration_noise: Optional[BaseNoise] = None,
|
||||
deterministic_eval: bool = True,
|
||||
target_mode: str = "min",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
None, None, None, None, tau, gamma, exploration_noise,
|
||||
reward_normalization, estimation_step, **kwargs
|
||||
)
|
||||
self.actor, self.actor_optim = actor, actor_optim
|
||||
self.critics, self.critics_old = critics, deepcopy(critics)
|
||||
self.critics_old.eval()
|
||||
self.critics_optim = critics_optim
|
||||
assert 0 < subset_size <= ensemble_size, \
|
||||
"Invalid choice of ensemble size or subset size."
|
||||
self.ensemble_size = ensemble_size
|
||||
self.subset_size = subset_size
|
||||
|
||||
self._is_auto_alpha = False
|
||||
self._alpha: Union[float, torch.Tensor]
|
||||
if isinstance(alpha, tuple):
|
||||
self._is_auto_alpha = True
|
||||
self._target_entropy, self._log_alpha, self._alpha_optim = alpha
|
||||
assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
|
||||
self._alpha = self._log_alpha.detach().exp()
|
||||
else:
|
||||
self._alpha = alpha
|
||||
|
||||
if target_mode in ("min", "mean"):
|
||||
self.target_mode = target_mode
|
||||
else:
|
||||
raise ValueError("Unsupported mode of Q target computing.")
|
||||
|
||||
self.critic_gradient_step = 0
|
||||
self.actor_delay = actor_delay
|
||||
self._deterministic_eval = deterministic_eval
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def train(self, mode: bool = True) -> "REDQPolicy":
|
||||
self.training = mode
|
||||
self.actor.train(mode)
|
||||
self.critics.train(mode)
|
||||
return self
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
|
||||
o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau)
|
||||
|
||||
def forward( # type: ignore
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
input: str = "obs",
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
obs = batch[input]
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
dist = Independent(Normal(*logits), 1)
|
||||
if self._deterministic_eval and not self.training:
|
||||
act = logits[0]
|
||||
else:
|
||||
act = dist.rsample()
|
||||
log_prob = dist.log_prob(act).unsqueeze(-1)
|
||||
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||
# in appendix C to get some understanding of this equation.
|
||||
squashed_action = torch.tanh(act)
|
||||
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
|
||||
self.__eps).sum(-1, keepdim=True)
|
||||
return Batch(
|
||||
logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
|
||||
)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
batch = buffer[indices] # batch.obs: s_{t+n}
|
||||
obs_next_result = self(batch, input="obs_next")
|
||||
a_ = obs_next_result.act
|
||||
sample_ensemble_idx = np.random.choice(
|
||||
self.ensemble_size, self.subset_size, replace=False
|
||||
)
|
||||
qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
|
||||
if self.target_mode == "min":
|
||||
target_q, _ = torch.min(qs, dim=0)
|
||||
elif self.target_mode == "mean":
|
||||
target_q = torch.mean(qs, dim=0)
|
||||
target_q -= self._alpha * obs_next_result.log_prob
|
||||
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
# critic ensemble
|
||||
weight = getattr(batch, "weight", 1.0)
|
||||
current_qs = self.critics(batch.obs, batch.act).flatten(1)
|
||||
target_q = batch.returns.flatten()
|
||||
td = current_qs - target_q
|
||||
critic_loss = (td.pow(2) * weight).mean()
|
||||
self.critics_optim.zero_grad()
|
||||
critic_loss.backward()
|
||||
self.critics_optim.step()
|
||||
batch.weight = torch.mean(td, dim=0) # prio-buffer
|
||||
self.critic_gradient_step += 1
|
||||
|
||||
# actor
|
||||
if self.critic_gradient_step % self.actor_delay == 0:
|
||||
obs_result = self(batch)
|
||||
a = obs_result.act
|
||||
current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
|
||||
actor_loss = (self._alpha * obs_result.log_prob.flatten() -
|
||||
current_qa).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optim.step()
|
||||
|
||||
if self._is_auto_alpha:
|
||||
log_prob = obs_result.log_prob.detach() + self._target_entropy
|
||||
alpha_loss = -(self._log_alpha * log_prob).mean()
|
||||
self._alpha_optim.zero_grad()
|
||||
alpha_loss.backward()
|
||||
self._alpha_optim.step()
|
||||
self._alpha = self._log_alpha.detach().exp()
|
||||
|
||||
self.sync_weight()
|
||||
|
||||
result = {"loss/critics": critic_loss.item()}
|
||||
if self.critic_gradient_step % self.actor_delay == 0:
|
||||
result["loss/actor"] = actor_loss.item(),
|
||||
if self._is_auto_alpha:
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["alpha"] = self._alpha.item() # type: ignore
|
||||
|
||||
return result
|
@ -1,4 +1,14 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
no_type_check,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -46,6 +56,7 @@ class MLP(nn.Module):
|
||||
nn.ReLU.
|
||||
:param device: which device to create this model on. Default to None.
|
||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||
:param bool flatten_input: whether to flatten input data. Default to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -57,6 +68,7 @@ class MLP(nn.Module):
|
||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
||||
device: Optional[Union[str, int, torch.device]] = None,
|
||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||
flatten_input: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -86,15 +98,15 @@ class MLP(nn.Module):
|
||||
model += [linear_layer(hidden_sizes[-1], output_dim)]
|
||||
self.output_dim = output_dim or hidden_sizes[-1]
|
||||
self.model = nn.Sequential(*model)
|
||||
self.flatten_input = flatten_input
|
||||
|
||||
@no_type_check
|
||||
def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
||||
if self.device is not None:
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device, # type: ignore
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return self.model(obs.flatten(1)) # type: ignore
|
||||
obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
|
||||
if self.flatten_input:
|
||||
obs = obs.flatten(1)
|
||||
return self.model(obs)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
@ -129,6 +141,7 @@ class Net(nn.Module):
|
||||
pass a tuple of two dict (first for Q and second for V) stating
|
||||
self-defined arguments as stated in
|
||||
class:`~tianshou.utils.net.common.MLP`. Default to None.
|
||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -152,6 +165,7 @@ class Net(nn.Module):
|
||||
concat: bool = False,
|
||||
num_atoms: int = 1,
|
||||
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -164,7 +178,8 @@ class Net(nn.Module):
|
||||
self.use_dueling = dueling_param is not None
|
||||
output_dim = action_dim if not self.use_dueling and not concat else 0
|
||||
self.model = MLP(
|
||||
input_dim, output_dim, hidden_sizes, norm_layer, activation, device
|
||||
input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
|
||||
linear_layer
|
||||
)
|
||||
self.output_dim = self.model.output_dim
|
||||
if self.use_dueling: # dueling DQN
|
||||
@ -311,3 +326,40 @@ class DataParallelNet(nn.Module):
|
||||
if not isinstance(obs, torch.Tensor):
|
||||
obs = torch.as_tensor(obs, dtype=torch.float32)
|
||||
return self.net(obs=obs.cuda(), *args, **kwargs)
|
||||
|
||||
|
||||
class EnsembleLinear(nn.Module):
|
||||
"""Linear Layer of Ensemble network.
|
||||
|
||||
:param int ensemble_size: Number of subnets in the ensemble.
|
||||
:param int inp_feature: dimension of the input vector.
|
||||
:param int out_feature: dimension of the output vector.
|
||||
:param bool bias: whether to include an additive bias, default to be True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ensemble_size: int,
|
||||
in_feature: int,
|
||||
out_feature: int,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# To be consistent with PyTorch default initializer
|
||||
k = np.sqrt(1. / in_feature)
|
||||
weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k
|
||||
self.weight = nn.Parameter(weight_data, requires_grad=True)
|
||||
|
||||
self.bias: Union[nn.Parameter, None]
|
||||
if bias:
|
||||
bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k
|
||||
self.bias = nn.Parameter(bias_data, requires_grad=True)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.matmul(x, self.weight)
|
||||
if self.bias is not None:
|
||||
x = x + self.bias
|
||||
return x
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -79,6 +79,9 @@ class Critic(nn.Module):
|
||||
only a single linear layer).
|
||||
:param int preprocess_net_output_dim: the output dimension of
|
||||
preprocess_net.
|
||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||
:param bool flatten_input: whether to flatten input data for the last layer.
|
||||
Default to True.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
@ -95,6 +98,8 @@ class Critic(nn.Module):
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
preprocess_net_output_dim: Optional[int] = None,
|
||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||
flatten_input: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -105,7 +110,9 @@ class Critic(nn.Module):
|
||||
input_dim, # type: ignore
|
||||
1,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
device=self.device,
|
||||
linear_layer=linear_layer,
|
||||
flatten_input=flatten_input,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
Loading…
x
Reference in New Issue
Block a user