implement sac for discrete action settings (#216)
Co-authored-by: n+e <trinkle23897@cmu.edu>
This commit is contained in:
parent
b284ace102
commit
a6ee979609
1
.gitignore
vendored
1
.gitignore
vendored
@ -137,7 +137,6 @@ dmypy.json
|
||||
.pytype/
|
||||
|
||||
# customize
|
||||
flake8.sh
|
||||
log/
|
||||
MUJOCO_LOG.TXT
|
||||
*.pth
|
||||
|
@ -26,7 +26,7 @@ def get_args():
|
||||
parser.add_argument('--noise_std', type=float, default=1.2)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--auto_alpha', type=bool, default=True)
|
||||
parser.add_argument('--auto_alpha', type=int, default=1)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--epoch', type=int, default=20)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=2400)
|
||||
|
117
test/discrete/test_sac.py
Normal file
117
test/discrete/test_sac.py
Normal file
@ -0,0 +1,117 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.policy import DiscreteSACPolicy
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.95)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--alpha', type=float, default=0.05)
|
||||
parser.add_argument('--auto_alpha', type=int, default=0)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=5)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--layer-num', type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.0)
|
||||
parser.add_argument('--rew-norm', type=int, default=0)
|
||||
parser.add_argument('--ignore-done', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_discrete_sac(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||
actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||
critic1 = Critic(net, last_size=args.action_shape).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||
critic2 = Critic(net, last_size=args.action_shape).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
# better not to use auto alpha in CartPole
|
||||
if args.auto_alpha:
|
||||
target_entropy = 0.98 * np.log(np.prod(args.action_shape))
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = DiscreteSACPolicy(
|
||||
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
|
||||
args.tau, args.gamma, args.alpha,
|
||||
reward_normalization=args.rew_norm,
|
||||
ignore_done=args.ignore_done)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.buffer_size)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||
test_in_train=False)
|
||||
assert stop_fn(result['best_reward'])
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_discrete_sac()
|
@ -8,6 +8,7 @@ from tianshou.policy.modelfree.ddpg import DDPGPolicy
|
||||
from tianshou.policy.modelfree.ppo import PPOPolicy
|
||||
from tianshou.policy.modelfree.td3 import TD3Policy
|
||||
from tianshou.policy.modelfree.sac import SACPolicy
|
||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
|
||||
|
||||
|
||||
@ -22,5 +23,6 @@ __all__ = [
|
||||
"PPOPolicy",
|
||||
"TD3Policy",
|
||||
"SACPolicy",
|
||||
"DiscreteSACPolicy",
|
||||
"MultiAgentPolicyManager",
|
||||
]
|
||||
|
@ -140,7 +140,7 @@ class DDPGPolicy(BasePolicy):
|
||||
more detailed explanation.
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
obs = getattr(batch, input)
|
||||
obs = batch[input]
|
||||
actions, h = model(obs, state=state, info=batch.info)
|
||||
actions += self._action_bias
|
||||
if self._noise and self.training and explorating:
|
||||
|
148
tianshou/policy/modelfree/discrete_sac.py
Normal file
148
tianshou/policy/modelfree/discrete_sac.py
Normal file
@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.distributions import Categorical
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
|
||||
|
||||
class DiscreteSACPolicy(SACPolicy):
|
||||
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critic1: the first critic network. (s -> Q(s))
|
||||
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||
critic network.
|
||||
:param torch.nn.Module critic2: the second critic network. (s -> Q(s))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
|
||||
regularization coefficient, default to 0.2.
|
||||
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
|
||||
alpha is automatatically tuned.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to ``False``.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
defaults to ``False``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic1: torch.nn.Module,
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
alpha: Union[
|
||||
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
|
||||
] = 0.2,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
|
||||
critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
|
||||
reward_normalization, ignore_done, estimation_step,
|
||||
**kwargs)
|
||||
self._alpha: Union[float, torch.Tensor]
|
||||
|
||||
def forward( # type: ignore
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
input: str = "obs",
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
obs = batch[input]
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits)
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def _target_q(
|
||||
self, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs: s_{t+n}
|
||||
with torch.no_grad():
|
||||
obs_next_result = self(batch, input="obs_next")
|
||||
dist = obs_next_result.dist
|
||||
target_q = dist.probs * torch.min(
|
||||
self.critic1_old(batch.obs_next),
|
||||
self.critic2_old(batch.obs_next),
|
||||
)
|
||||
target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy()
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
target_q = batch.returns.flatten()
|
||||
act = to_torch(
|
||||
batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
|
||||
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
|
||||
td1 = current_q1 - target_q
|
||||
critic1_loss = (td1.pow(2) * weight).mean()
|
||||
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
|
||||
td2 = current_q2 - target_q
|
||||
critic2_loss = (td2.pow(2) * weight).mean()
|
||||
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
|
||||
# actor
|
||||
dist = self(batch).dist
|
||||
entropy = dist.entropy()
|
||||
with torch.no_grad():
|
||||
current_q1a = self.critic1(batch.obs)
|
||||
current_q2a = self.critic2(batch.obs)
|
||||
q = torch.min(current_q1a, current_q2a)
|
||||
actor_loss = -(self._alpha * entropy
|
||||
+ (dist.probs * q).sum(dim=-1)).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optim.step()
|
||||
|
||||
if self._is_auto_alpha:
|
||||
log_prob = entropy.detach() - self._target_entropy
|
||||
alpha_loss = (self._log_alpha * log_prob).mean()
|
||||
self._alpha_optim.zero_grad()
|
||||
alpha_loss.backward()
|
||||
self._alpha_optim.step()
|
||||
self._alpha = self._log_alpha.detach().exp()
|
||||
|
||||
self.sync_weight()
|
||||
|
||||
result = {
|
||||
"loss/actor": actor_loss.item(),
|
||||
"loss/critic1": critic1_loss.item(),
|
||||
"loss/critic2": critic2_loss.item(),
|
||||
}
|
||||
if self._is_auto_alpha:
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["alpha"] = self._alpha.item() # type: ignore
|
||||
|
||||
return result
|
@ -143,7 +143,7 @@ class DQNPolicy(BasePolicy):
|
||||
more detailed explanation.
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
obs = getattr(batch, input)
|
||||
obs = batch[input]
|
||||
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||
q, h = model(obs_, state=state, info=batch.info)
|
||||
act: np.ndarray = to_numpy(q.max(dim=1)[1])
|
||||
|
@ -28,8 +28,6 @@ class SACPolicy(DDPGPolicy):
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param float exploration_noise: the noise intensity, add to the action,
|
||||
defaults to 0.1.
|
||||
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
|
||||
regularization coefficient, default to 0.2.
|
||||
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
|
||||
@ -38,8 +36,8 @@ class SACPolicy(DDPGPolicy):
|
||||
defaults to False.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
defaults to False.
|
||||
:param BaseNoise exploration_noise: add a noise to action for exploration.
|
||||
This is useful when solving hard-exploration problem.
|
||||
:param BaseNoise exploration_noise: add a noise to action for exploration,
|
||||
defaults to None. This is useful when solving hard-exploration problem.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -115,7 +113,7 @@ class SACPolicy(DDPGPolicy):
|
||||
explorating: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
obs = getattr(batch, input)
|
||||
obs = batch[input]
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
dist = Independent(Normal(*logits), 1)
|
||||
@ -147,6 +145,7 @@ class SACPolicy(DDPGPolicy):
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
@ -156,6 +155,7 @@ class SACPolicy(DDPGPolicy):
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs, batch.act).flatten()
|
||||
td2 = current_q2 - target_q
|
||||
@ -165,6 +165,7 @@ class SACPolicy(DDPGPolicy):
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
|
||||
# actor
|
||||
obs_result = self(batch, explorating=False)
|
||||
a = obs_result.act
|
||||
@ -193,5 +194,6 @@ class SACPolicy(DDPGPolicy):
|
||||
}
|
||||
if self._is_auto_alpha:
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["v/alpha"] = self._alpha.item() # type: ignore
|
||||
result["alpha"] = self._alpha.item() # type: ignore
|
||||
|
||||
return result
|
||||
|
@ -19,10 +19,12 @@ class Actor(nn.Module):
|
||||
preprocess_net: nn.Module,
|
||||
action_shape: Sequence[int],
|
||||
hidden_layer_size: int = 128,
|
||||
softmax_output: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self.softmax_output = softmax_output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -32,7 +34,9 @@ class Actor(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
r"""Mapping: s -> Q(s, \*)."""
|
||||
logits, h = self.preprocess(s, state)
|
||||
logits = F.softmax(self.last(logits), dim=-1)
|
||||
logits = self.last(logits)
|
||||
if self.softmax_output:
|
||||
logits = F.softmax(logits, dim=-1)
|
||||
return logits, h
|
||||
|
||||
|
||||
@ -44,11 +48,14 @@ class Critic(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, preprocess_net: nn.Module, hidden_layer_size: int = 128
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
hidden_layer_size: int = 128,
|
||||
last_size: int = 1
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(hidden_layer_size, 1)
|
||||
self.last = nn.Linear(hidden_layer_size, last_size)
|
||||
|
||||
def forward(
|
||||
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
|
||||
|
Loading…
x
Reference in New Issue
Block a user