implement REDQ based on original contribution by @Jimenius (#623)
Co-authored-by: Minhui Li <limh@lamda.nju.edu.cn>
This commit is contained in:
parent
41afc2584a
commit
dd16818ce4
@ -34,6 +34,7 @@
|
|||||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
||||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||||
|
- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf)
|
||||||
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
||||||
- Vanilla Imitation Learning
|
- Vanilla Imitation Learning
|
||||||
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
||||||
|
@ -96,6 +96,11 @@ Off-policy
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: tianshou.policy.REDQPolicy
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autoclass:: tianshou.policy.DiscreteSACPolicy
|
.. autoclass:: tianshou.policy.DiscreteSACPolicy
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
|
@ -25,6 +25,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||||
|
* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning <https://arxiv.org/pdf/2101.05982.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
||||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||||
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
||||||
|
@ -157,3 +157,4 @@ Nvidia
|
|||||||
Enduro
|
Enduro
|
||||||
Qbert
|
Qbert
|
||||||
Seaquest
|
Seaquest
|
||||||
|
subnets
|
||||||
|
192
examples/mujoco/mujoco_redq.py
Executable file
192
examples/mujoco/mujoco_redq.py
Executable file
@ -0,0 +1,192 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
|
from tianshou.env import SubprocVectorEnv
|
||||||
|
from tianshou.policy import REDQPolicy
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.utils import TensorboardLogger
|
||||||
|
from tianshou.utils.net.common import EnsembleLinear, Net
|
||||||
|
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='Ant-v3')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||||
|
parser.add_argument('--ensemble-size', type=int, default=10)
|
||||||
|
parser.add_argument('--subset-size', type=int, default=2)
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
|
parser.add_argument('--tau', type=float, default=0.005)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.2)
|
||||||
|
parser.add_argument('--auto-alpha', default=False, action='store_true')
|
||||||
|
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||||
|
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||||
|
parser.add_argument('--epoch', type=int, default=200)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||||
|
parser.add_argument('--update-per-step', type=int, default=20)
|
||||||
|
parser.add_argument('--n-step', type=int, default=1)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=256)
|
||||||
|
parser.add_argument(
|
||||||
|
'--target-mode', type=str, choices=('min', 'mean'), default='min'
|
||||||
|
)
|
||||||
|
parser.add_argument('--training-num', type=int, default=1)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
'--watch',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only'
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def test_redq(args=get_args()):
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0]
|
||||||
|
print("Observations shape:", args.state_shape)
|
||||||
|
print("Actions shape:", args.action_shape)
|
||||||
|
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
if args.training_num > 1:
|
||||||
|
train_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_envs = gym.make(args.task)
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||||
|
actor = ActorProb(
|
||||||
|
net_a,
|
||||||
|
args.action_shape,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
unbounded=True,
|
||||||
|
conditioned_sigma=True
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
def linear(x, y):
|
||||||
|
return EnsembleLinear(args.ensemble_size, x, y)
|
||||||
|
|
||||||
|
net_c = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
linear_layer=linear,
|
||||||
|
)
|
||||||
|
critics = Critic(
|
||||||
|
net_c,
|
||||||
|
device=args.device,
|
||||||
|
linear_layer=linear,
|
||||||
|
flatten_input=False,
|
||||||
|
).to(args.device)
|
||||||
|
critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
if args.auto_alpha:
|
||||||
|
target_entropy = -np.prod(env.action_space.shape)
|
||||||
|
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||||
|
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||||
|
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||||
|
|
||||||
|
policy = REDQPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critics,
|
||||||
|
critics_optim,
|
||||||
|
args.ensemble_size,
|
||||||
|
args.subset_size,
|
||||||
|
tau=args.tau,
|
||||||
|
gamma=args.gamma,
|
||||||
|
alpha=args.alpha,
|
||||||
|
estimation_step=args.n_step,
|
||||||
|
actor_delay=args.update_per_step,
|
||||||
|
target_mode=args.target_mode,
|
||||||
|
action_space=env.action_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
|
||||||
|
# collector
|
||||||
|
if args.training_num > 1:
|
||||||
|
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||||
|
else:
|
||||||
|
buffer = ReplayBuffer(args.buffer_size)
|
||||||
|
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||||
|
# log
|
||||||
|
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||||
|
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_redq'
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'redq', log_file)
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
|
def save_best_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
if not args.watch:
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy,
|
||||||
|
train_collector,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.step_per_collect,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
save_best_fn=save_best_fn,
|
||||||
|
logger=logger,
|
||||||
|
update_per_step=args.update_per_step,
|
||||||
|
test_in_train=False
|
||||||
|
)
|
||||||
|
pprint.pprint(result)
|
||||||
|
|
||||||
|
# Let's watch its performance!
|
||||||
|
policy.eval()
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
test_collector.reset()
|
||||||
|
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||||
|
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_redq()
|
178
test/continuous/test_redq.py
Normal file
178
test/continuous/test_redq.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
|
from tianshou.env import DummyVectorEnv
|
||||||
|
from tianshou.policy import REDQPolicy
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.utils import TensorboardLogger
|
||||||
|
from tianshou.utils.net.common import EnsembleLinear, Net
|
||||||
|
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='Pendulum-v1')
|
||||||
|
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
parser.add_argument('--ensemble-size', type=int, default=4)
|
||||||
|
parser.add_argument('--subset-size', type=int, default=2)
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-4)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
|
parser.add_argument('--tau', type=float, default=0.005)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.2)
|
||||||
|
parser.add_argument('--auto-alpha', action='store_true', default=False)
|
||||||
|
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||||
|
parser.add_argument("--start-timesteps", type=int, default=1000)
|
||||||
|
parser.add_argument('--epoch', type=int, default=5)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||||
|
parser.add_argument('--update-per-step', type=int, default=3)
|
||||||
|
parser.add_argument('--n-step', type=int, default=1)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
|
parser.add_argument(
|
||||||
|
'--target-mode', type=str, choices=('min', 'mean'), default='min'
|
||||||
|
)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||||
|
parser.add_argument('--training-num', type=int, default=8)
|
||||||
|
parser.add_argument('--test-num', type=int, default=100)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_redq(args=get_args()):
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0]
|
||||||
|
if args.reward_threshold is None:
|
||||||
|
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
|
||||||
|
args.reward_threshold = default_reward_threshold.get(
|
||||||
|
args.task, env.spec.reward_threshold
|
||||||
|
)
|
||||||
|
# you can also use tianshou.env.SubprocVectorEnv
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
train_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||||
|
)
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||||
|
actor = ActorProb(
|
||||||
|
net,
|
||||||
|
args.action_shape,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
unbounded=True,
|
||||||
|
conditioned_sigma=True
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
def linear(x, y):
|
||||||
|
return EnsembleLinear(args.ensemble_size, x, y)
|
||||||
|
|
||||||
|
net_c = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
linear_layer=linear,
|
||||||
|
)
|
||||||
|
critic = Critic(
|
||||||
|
net_c, device=args.device, linear_layer=linear, flatten_input=False
|
||||||
|
).to(args.device)
|
||||||
|
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
if args.auto_alpha:
|
||||||
|
target_entropy = -np.prod(env.action_space.shape)
|
||||||
|
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||||
|
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||||
|
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||||
|
|
||||||
|
policy = REDQPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic,
|
||||||
|
critic_optim,
|
||||||
|
args.ensemble_size,
|
||||||
|
args.subset_size,
|
||||||
|
tau=args.tau,
|
||||||
|
gamma=args.gamma,
|
||||||
|
alpha=args.alpha,
|
||||||
|
estimation_step=args.n_step,
|
||||||
|
actor_delay=args.update_per_step,
|
||||||
|
target_mode=args.target_mode,
|
||||||
|
action_space=env.action_space,
|
||||||
|
)
|
||||||
|
# collector
|
||||||
|
train_collector = Collector(
|
||||||
|
policy,
|
||||||
|
train_envs,
|
||||||
|
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||||
|
exploration_noise=True
|
||||||
|
)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'redq')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
|
def save_best_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= args.reward_threshold
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy,
|
||||||
|
train_collector,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.step_per_collect,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
update_per_step=args.update_per_step,
|
||||||
|
stop_fn=stop_fn,
|
||||||
|
save_best_fn=save_best_fn,
|
||||||
|
logger=logger
|
||||||
|
)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_redq()
|
9
tianshou/env/pettingzoo_env.py
vendored
9
tianshou/env/pettingzoo_env.py
vendored
@ -55,8 +55,8 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> dict:
|
def reset(self, *args: Any, **kwargs: Any) -> dict:
|
||||||
self.env.reset()
|
self.env.reset(*args, **kwargs)
|
||||||
observation = self.env.observe(self.env.agent_selection)
|
observation = self.env.observe(self.env.agent_selection)
|
||||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||||
return {
|
return {
|
||||||
@ -103,7 +103,10 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
self.env.close()
|
self.env.close()
|
||||||
|
|
||||||
def seed(self, seed: Any = None) -> None:
|
def seed(self, seed: Any = None) -> None:
|
||||||
self.env.seed(seed)
|
try:
|
||||||
|
self.env.seed(seed)
|
||||||
|
except NotImplementedError:
|
||||||
|
self.env.reset(seed=seed)
|
||||||
|
|
||||||
def render(self, mode: str = "human") -> Any:
|
def render(self, mode: str = "human") -> Any:
|
||||||
return self.env.render(mode)
|
return self.env.render(mode)
|
||||||
|
@ -17,6 +17,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
|
|||||||
from tianshou.policy.modelfree.trpo import TRPOPolicy
|
from tianshou.policy.modelfree.trpo import TRPOPolicy
|
||||||
from tianshou.policy.modelfree.td3 import TD3Policy
|
from tianshou.policy.modelfree.td3 import TD3Policy
|
||||||
from tianshou.policy.modelfree.sac import SACPolicy
|
from tianshou.policy.modelfree.sac import SACPolicy
|
||||||
|
from tianshou.policy.modelfree.redq import REDQPolicy
|
||||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||||
from tianshou.policy.imitation.base import ImitationPolicy
|
from tianshou.policy.imitation.base import ImitationPolicy
|
||||||
from tianshou.policy.imitation.bcq import BCQPolicy
|
from tianshou.policy.imitation.bcq import BCQPolicy
|
||||||
@ -46,6 +47,7 @@ __all__ = [
|
|||||||
"TRPOPolicy",
|
"TRPOPolicy",
|
||||||
"TD3Policy",
|
"TD3Policy",
|
||||||
"SACPolicy",
|
"SACPolicy",
|
||||||
|
"REDQPolicy",
|
||||||
"DiscreteSACPolicy",
|
"DiscreteSACPolicy",
|
||||||
"ImitationPolicy",
|
"ImitationPolicy",
|
||||||
"BCQPolicy",
|
"BCQPolicy",
|
||||||
|
200
tianshou/policy/modelfree/redq.py
Normal file
200
tianshou/policy/modelfree/redq.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
|
from tianshou.data import Batch, ReplayBuffer
|
||||||
|
from tianshou.exploration import BaseNoise
|
||||||
|
from tianshou.policy import DDPGPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class REDQPolicy(DDPGPolicy):
|
||||||
|
"""Implementation of REDQ. arXiv:2101.05982.
|
||||||
|
|
||||||
|
:param torch.nn.Module actor: the actor network following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||||
|
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||||
|
:param torch.nn.Module critics: critic ensemble networks.
|
||||||
|
:param torch.optim.Optimizer critics_optim: the optimizer for the critic networks.
|
||||||
|
:param int ensemble_size: Number of sub-networks in the critic ensemble.
|
||||||
|
Default to 10.
|
||||||
|
:param int subset_size: Number of networks in the subset. Default to 2.
|
||||||
|
:param float tau: param for soft update of the target network. Default to 0.005.
|
||||||
|
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
||||||
|
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
|
||||||
|
regularization coefficient. Default to 0.2.
|
||||||
|
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
|
||||||
|
alpha is automatically tuned.
|
||||||
|
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||||
|
Default to False.
|
||||||
|
:param int actor_delay: Number of critic updates before an actor update.
|
||||||
|
Default to 20.
|
||||||
|
:param BaseNoise exploration_noise: add a noise to action for exploration.
|
||||||
|
Default to None. This is useful when solving hard-exploration problem.
|
||||||
|
:param bool deterministic_eval: whether to use deterministic action (mean
|
||||||
|
of Gaussian policy) instead of stochastic action sampled by the policy.
|
||||||
|
Default to True.
|
||||||
|
:param str target_mode: methods to integrate critic values in the subset,
|
||||||
|
currently support minimum and average. Default to min.
|
||||||
|
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||||
|
[action_spaces.low, action_spaces.high]. Default to True.
|
||||||
|
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||||
|
either "clip" (for simply clipping the action) or empty string for no bounding.
|
||||||
|
Default to "clip".
|
||||||
|
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||||
|
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor: torch.nn.Module,
|
||||||
|
actor_optim: torch.optim.Optimizer,
|
||||||
|
critics: torch.nn.Module,
|
||||||
|
critics_optim: torch.optim.Optimizer,
|
||||||
|
ensemble_size: int = 10,
|
||||||
|
subset_size: int = 2,
|
||||||
|
tau: float = 0.005,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
|
||||||
|
reward_normalization: bool = False,
|
||||||
|
estimation_step: int = 1,
|
||||||
|
actor_delay: int = 20,
|
||||||
|
exploration_noise: Optional[BaseNoise] = None,
|
||||||
|
deterministic_eval: bool = True,
|
||||||
|
target_mode: str = "min",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
None, None, None, None, tau, gamma, exploration_noise,
|
||||||
|
reward_normalization, estimation_step, **kwargs
|
||||||
|
)
|
||||||
|
self.actor, self.actor_optim = actor, actor_optim
|
||||||
|
self.critics, self.critics_old = critics, deepcopy(critics)
|
||||||
|
self.critics_old.eval()
|
||||||
|
self.critics_optim = critics_optim
|
||||||
|
assert 0 < subset_size <= ensemble_size, \
|
||||||
|
"Invalid choice of ensemble size or subset size."
|
||||||
|
self.ensemble_size = ensemble_size
|
||||||
|
self.subset_size = subset_size
|
||||||
|
|
||||||
|
self._is_auto_alpha = False
|
||||||
|
self._alpha: Union[float, torch.Tensor]
|
||||||
|
if isinstance(alpha, tuple):
|
||||||
|
self._is_auto_alpha = True
|
||||||
|
self._target_entropy, self._log_alpha, self._alpha_optim = alpha
|
||||||
|
assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
|
||||||
|
self._alpha = self._log_alpha.detach().exp()
|
||||||
|
else:
|
||||||
|
self._alpha = alpha
|
||||||
|
|
||||||
|
if target_mode in ("min", "mean"):
|
||||||
|
self.target_mode = target_mode
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported mode of Q target computing.")
|
||||||
|
|
||||||
|
self.critic_gradient_step = 0
|
||||||
|
self.actor_delay = actor_delay
|
||||||
|
self._deterministic_eval = deterministic_eval
|
||||||
|
self.__eps = np.finfo(np.float32).eps.item()
|
||||||
|
|
||||||
|
def train(self, mode: bool = True) -> "REDQPolicy":
|
||||||
|
self.training = mode
|
||||||
|
self.actor.train(mode)
|
||||||
|
self.critics.train(mode)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def sync_weight(self) -> None:
|
||||||
|
for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
|
||||||
|
o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau)
|
||||||
|
|
||||||
|
def forward( # type: ignore
|
||||||
|
self,
|
||||||
|
batch: Batch,
|
||||||
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||||
|
input: str = "obs",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Batch:
|
||||||
|
obs = batch[input]
|
||||||
|
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||||
|
assert isinstance(logits, tuple)
|
||||||
|
dist = Independent(Normal(*logits), 1)
|
||||||
|
if self._deterministic_eval and not self.training:
|
||||||
|
act = logits[0]
|
||||||
|
else:
|
||||||
|
act = dist.rsample()
|
||||||
|
log_prob = dist.log_prob(act).unsqueeze(-1)
|
||||||
|
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||||
|
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||||
|
# in appendix C to get some understanding of this equation.
|
||||||
|
squashed_action = torch.tanh(act)
|
||||||
|
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
|
||||||
|
self.__eps).sum(-1, keepdim=True)
|
||||||
|
return Batch(
|
||||||
|
logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
|
||||||
|
)
|
||||||
|
|
||||||
|
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||||
|
batch = buffer[indices] # batch.obs: s_{t+n}
|
||||||
|
obs_next_result = self(batch, input="obs_next")
|
||||||
|
a_ = obs_next_result.act
|
||||||
|
sample_ensemble_idx = np.random.choice(
|
||||||
|
self.ensemble_size, self.subset_size, replace=False
|
||||||
|
)
|
||||||
|
qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
|
||||||
|
if self.target_mode == "min":
|
||||||
|
target_q, _ = torch.min(qs, dim=0)
|
||||||
|
elif self.target_mode == "mean":
|
||||||
|
target_q = torch.mean(qs, dim=0)
|
||||||
|
target_q -= self._alpha * obs_next_result.log_prob
|
||||||
|
|
||||||
|
return target_q
|
||||||
|
|
||||||
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
|
# critic ensemble
|
||||||
|
weight = getattr(batch, "weight", 1.0)
|
||||||
|
current_qs = self.critics(batch.obs, batch.act).flatten(1)
|
||||||
|
target_q = batch.returns.flatten()
|
||||||
|
td = current_qs - target_q
|
||||||
|
critic_loss = (td.pow(2) * weight).mean()
|
||||||
|
self.critics_optim.zero_grad()
|
||||||
|
critic_loss.backward()
|
||||||
|
self.critics_optim.step()
|
||||||
|
batch.weight = torch.mean(td, dim=0) # prio-buffer
|
||||||
|
self.critic_gradient_step += 1
|
||||||
|
|
||||||
|
# actor
|
||||||
|
if self.critic_gradient_step % self.actor_delay == 0:
|
||||||
|
obs_result = self(batch)
|
||||||
|
a = obs_result.act
|
||||||
|
current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
|
||||||
|
actor_loss = (self._alpha * obs_result.log_prob.flatten() -
|
||||||
|
current_qa).mean()
|
||||||
|
self.actor_optim.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self.actor_optim.step()
|
||||||
|
|
||||||
|
if self._is_auto_alpha:
|
||||||
|
log_prob = obs_result.log_prob.detach() + self._target_entropy
|
||||||
|
alpha_loss = -(self._log_alpha * log_prob).mean()
|
||||||
|
self._alpha_optim.zero_grad()
|
||||||
|
alpha_loss.backward()
|
||||||
|
self._alpha_optim.step()
|
||||||
|
self._alpha = self._log_alpha.detach().exp()
|
||||||
|
|
||||||
|
self.sync_weight()
|
||||||
|
|
||||||
|
result = {"loss/critics": critic_loss.item()}
|
||||||
|
if self.critic_gradient_step % self.actor_delay == 0:
|
||||||
|
result["loss/actor"] = actor_loss.item(),
|
||||||
|
if self._is_auto_alpha:
|
||||||
|
result["loss/alpha"] = alpha_loss.item()
|
||||||
|
result["alpha"] = self._alpha.item() # type: ignore
|
||||||
|
|
||||||
|
return result
|
@ -1,4 +1,14 @@
|
|||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
no_type_check,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -46,6 +56,7 @@ class MLP(nn.Module):
|
|||||||
nn.ReLU.
|
nn.ReLU.
|
||||||
:param device: which device to create this model on. Default to None.
|
:param device: which device to create this model on. Default to None.
|
||||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||||
|
:param bool flatten_input: whether to flatten input data. Default to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -57,6 +68,7 @@ class MLP(nn.Module):
|
|||||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
||||||
device: Optional[Union[str, int, torch.device]] = None,
|
device: Optional[Union[str, int, torch.device]] = None,
|
||||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||||
|
flatten_input: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -86,15 +98,15 @@ class MLP(nn.Module):
|
|||||||
model += [linear_layer(hidden_sizes[-1], output_dim)]
|
model += [linear_layer(hidden_sizes[-1], output_dim)]
|
||||||
self.output_dim = output_dim or hidden_sizes[-1]
|
self.output_dim = output_dim or hidden_sizes[-1]
|
||||||
self.model = nn.Sequential(*model)
|
self.model = nn.Sequential(*model)
|
||||||
|
self.flatten_input = flatten_input
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
||||||
if self.device is not None:
|
if self.device is not None:
|
||||||
obs = torch.as_tensor(
|
obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
|
||||||
obs,
|
if self.flatten_input:
|
||||||
device=self.device, # type: ignore
|
obs = obs.flatten(1)
|
||||||
dtype=torch.float32,
|
return self.model(obs)
|
||||||
)
|
|
||||||
return self.model(obs.flatten(1)) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
@ -129,6 +141,7 @@ class Net(nn.Module):
|
|||||||
pass a tuple of two dict (first for Q and second for V) stating
|
pass a tuple of two dict (first for Q and second for V) stating
|
||||||
self-defined arguments as stated in
|
self-defined arguments as stated in
|
||||||
class:`~tianshou.utils.net.common.MLP`. Default to None.
|
class:`~tianshou.utils.net.common.MLP`. Default to None.
|
||||||
|
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@ -152,6 +165,7 @@ class Net(nn.Module):
|
|||||||
concat: bool = False,
|
concat: bool = False,
|
||||||
num_atoms: int = 1,
|
num_atoms: int = 1,
|
||||||
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
||||||
|
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -164,7 +178,8 @@ class Net(nn.Module):
|
|||||||
self.use_dueling = dueling_param is not None
|
self.use_dueling = dueling_param is not None
|
||||||
output_dim = action_dim if not self.use_dueling and not concat else 0
|
output_dim = action_dim if not self.use_dueling and not concat else 0
|
||||||
self.model = MLP(
|
self.model = MLP(
|
||||||
input_dim, output_dim, hidden_sizes, norm_layer, activation, device
|
input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
|
||||||
|
linear_layer
|
||||||
)
|
)
|
||||||
self.output_dim = self.model.output_dim
|
self.output_dim = self.model.output_dim
|
||||||
if self.use_dueling: # dueling DQN
|
if self.use_dueling: # dueling DQN
|
||||||
@ -311,3 +326,40 @@ class DataParallelNet(nn.Module):
|
|||||||
if not isinstance(obs, torch.Tensor):
|
if not isinstance(obs, torch.Tensor):
|
||||||
obs = torch.as_tensor(obs, dtype=torch.float32)
|
obs = torch.as_tensor(obs, dtype=torch.float32)
|
||||||
return self.net(obs=obs.cuda(), *args, **kwargs)
|
return self.net(obs=obs.cuda(), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleLinear(nn.Module):
|
||||||
|
"""Linear Layer of Ensemble network.
|
||||||
|
|
||||||
|
:param int ensemble_size: Number of subnets in the ensemble.
|
||||||
|
:param int inp_feature: dimension of the input vector.
|
||||||
|
:param int out_feature: dimension of the output vector.
|
||||||
|
:param bool bias: whether to include an additive bias, default to be True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ensemble_size: int,
|
||||||
|
in_feature: int,
|
||||||
|
out_feature: int,
|
||||||
|
bias: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# To be consistent with PyTorch default initializer
|
||||||
|
k = np.sqrt(1. / in_feature)
|
||||||
|
weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k
|
||||||
|
self.weight = nn.Parameter(weight_data, requires_grad=True)
|
||||||
|
|
||||||
|
self.bias: Union[nn.Parameter, None]
|
||||||
|
if bias:
|
||||||
|
bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k
|
||||||
|
self.bias = nn.Parameter(bias_data, requires_grad=True)
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = torch.matmul(x, self.weight)
|
||||||
|
if self.bias is not None:
|
||||||
|
x = x + self.bias
|
||||||
|
return x
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -79,6 +79,9 @@ class Critic(nn.Module):
|
|||||||
only a single linear layer).
|
only a single linear layer).
|
||||||
:param int preprocess_net_output_dim: the output dimension of
|
:param int preprocess_net_output_dim: the output dimension of
|
||||||
preprocess_net.
|
preprocess_net.
|
||||||
|
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||||
|
:param bool flatten_input: whether to flatten input data for the last layer.
|
||||||
|
Default to True.
|
||||||
|
|
||||||
For advanced usage (how to customize the network), please refer to
|
For advanced usage (how to customize the network), please refer to
|
||||||
:ref:`build_the_network`.
|
:ref:`build_the_network`.
|
||||||
@ -95,6 +98,8 @@ class Critic(nn.Module):
|
|||||||
hidden_sizes: Sequence[int] = (),
|
hidden_sizes: Sequence[int] = (),
|
||||||
device: Union[str, int, torch.device] = "cpu",
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
preprocess_net_output_dim: Optional[int] = None,
|
preprocess_net_output_dim: Optional[int] = None,
|
||||||
|
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||||
|
flatten_input: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -105,7 +110,9 @@ class Critic(nn.Module):
|
|||||||
input_dim, # type: ignore
|
input_dim, # type: ignore
|
||||||
1,
|
1,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
device=self.device
|
device=self.device,
|
||||||
|
linear_layer=linear_layer,
|
||||||
|
flatten_input=flatten_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user