sac mujoco result (#246)

Trinkle23897 2020-11-09 16:43:55 +08:00
parent c97aa4065e
commit cd481423dc
25 changed files with 199 additions and 93 deletions

@@ -58,8 +58,8 @@ def test_dqn(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape: ", args.state_shape)
print("Actions shape: ", args.action_shape)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
@@ -79,7 +79,9 @@ def test_dqn(args=get_args()):
target_update_freq=args.target_update_freq)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path))
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM

@@ -1,6 +1,6 @@
# Bipedal-Hardcore-SAC
- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
- Our default choice: remove the done flag penalty; it converges to \~280 reward within 100 epochs (10M env steps, 3~4 hours; see the image below)
- If the done penalty is not removed, convergence is much slower: about 200 epochs (20M env steps) to reach comparable performance (\~200 reward)
![](results/sac/BipedalHardcore.png)
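To make "remove the done flag penalty" concrete, here is a minimal editorial sketch of the idea; the training script below folds the same logic into its action-repeat `Wrapper`, and the class name here is hypothetical:

```python
import gym

class NoDonePenalty(gym.Wrapper):
    """Skip the terminal-step reward (e.g. the -100 fall penalty in BipedalWalker)."""

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if done:
            reward = 0.0  # drop the done-flag penalty
        return obs, reward, done, info
```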

@@ -6,11 +6,11 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -24,8 +24,8 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--auto_alpha', type=int, default=1)
parser.add_argument('--alpha_lr', type=float, default=3e-4)
parser.add_argument('--auto-alpha', type=int, default=1)
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
@@ -35,54 +35,50 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=int, default=0)
parser.add_argument('--ignore-done', type=int, default=0)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--resume-path', type=str, default=None)
return parser.parse_args()
class EnvWrapper(object):
"""Env wrapper for reward scale, action repeat and action noise"""
class Wrapper(gym.Wrapper):
"""Env wrapper for reward scale, action repeat and removing done penalty"""
def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
self._env = gym.make(task)
def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True):
super().__init__(env)
self.action_repeat = action_repeat
self.reward_scale = reward_scale
self.act_noise = act_noise
def __getattr__(self, name):
return getattr(self._env, name)
self.rm_done = rm_done
def step(self, action):
# add action noise
action += self.act_noise * (-2 * np.random.random(4) + 1)
r = 0.0
for _ in range(self.action_repeat):
obs_, reward_, done_, info_ = self._env.step(action)
obs, reward, done, info = self.env.step(action)
# remove done reward penalty
if done_:
if not done or not self.rm_done:
r = r + reward
if done:
break
r = r + reward_
# scale reward
return obs_, self.reward_scale * r, done_, info_
return obs, self.reward_scale * r, done, info
def test_sac_bipedal(args=get_args()):
env = EnvWrapper(args.task)
env = Wrapper(gym.make(args.task))
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
train_envs = SubprocVectorEnv(
[lambda: EnvWrapper(args.task) for _ in range(args.training_num)])
train_envs = SubprocVectorEnv([
lambda: Wrapper(gym.make(args.task))
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1)
for _ in range(args.test_num)])
test_envs = SubprocVectorEnv([
lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@@ -117,8 +113,6 @@ def test_sac_bipedal(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# load a previous policy
if args.resume_path:

@@ -67,11 +67,13 @@ def test_sac(args=get_args()):
args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:

@@ -1,3 +1,27 @@
Result of Ant-v2:
# Mujoco Result
## SAC (single run)
The best reward is computed from the returns of 100 test episodes.
SAC on Swimmer-v3 always plateaus at 47\~48 reward.
| task | best reward (3M) | parameters | time cost (3M) |
| -------------- | ----------------- | ------------------------------------------------------- | -------------- |
| HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h |
| Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h |
| Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h |
| Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h |
| Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h |
![](results/sac/all.png)
### Which parts are important?
0. DO NOT share the same preprocessing network between the two critic networks (see the sketch below).
1. The sigma of the Gaussian policy MUST be conditioned on the input.
2. The network size (hidden layer size) should not be less than 256.
3. The deterministic evaluation helps a lot :)
![](/docs/_static/images/Ant-v2.png)
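A minimal sketch of these four points, mirroring the `mujoco_sac.py` changes below (the Ant-v3 shapes and hyperparameters are illustrative only):

```python
import torch
from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

device = "cuda" if torch.cuda.is_available() else "cpu"
state_shape, action_shape, max_action = (111,), (8,), 1.0  # e.g. Ant-v3

# (1) sigma conditioned on the input, (2) hidden layer size 256
net_a = Net(1, state_shape, device=device, hidden_layer_size=256)
actor = ActorProb(net_a, action_shape, max_action, device, unbounded=True,
                  hidden_layer_size=256, conditioned_sigma=True).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)

# (0) each critic gets its OWN preprocessing network -- do not reuse net_a
net_c1 = Net(1, state_shape, action_shape, concat=True,
             device=device, hidden_layer_size=256)
net_c2 = Net(1, state_shape, action_shape, concat=True,
             device=device, hidden_layer_size=256)
critic1 = Critic(net_c1, device, hidden_layer_size=256).to(device)
critic2 = Critic(net_c2, device, hidden_layer_size=256).to(device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=3e-4)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=3e-4)

# (3) deterministic evaluation (Gaussian mean at test time) is on by default
policy = SACPolicy(actor, actor_optim, critic1, critic1_optim,
                   critic2, critic2_optim, action_range=[-1.0, 1.0],
                   tau=0.005, gamma=0.99, alpha=0.2, estimation_step=2)
```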

@@ -16,27 +16,36 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Ant-v2')
parser.add_argument('--task', type=str, default='Ant-v3')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--buffer-size', type=int, default=1000000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=False, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--n-step', type=int, default=2)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=4)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-layer-size', type=int, default=256)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--log-interval', type=int, default=1000)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
@@ -45,6 +54,10 @@ def test_sac(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
@@ -57,53 +70,84 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.layer_num, args.state_shape, device=args.device,
hidden_layer_size=args.hidden_layer_size)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device, unbounded=True
net, args.action_shape, args.max_action, args.device, unbounded=True,
hidden_layer_size=args.hidden_layer_size, conditioned_sigma=True,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic1 = Critic(
net_c1, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic2 = Critic(
net_c2, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True)
estimation_step=args.n_step)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
def watch():
# watch agent's performance
print("Testing agent ...")
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
pprint.pprint(result)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return False
if args.watch:
watch()
exit(0)
# trainer
train_collector.collect(n_step=args.pre_collect_step, random=True)
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, args.update_per_step,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
log_interval=args.log_interval)
pprint.pprint(result)
watch()
if __name__ == '__main__':
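For reference on the `--auto-alpha` branch above: instead of a fixed temperature, a `(target_entropy, log_alpha, alpha_optim)` tuple is handed to `SACPolicy`, which then adapts α toward the target entropy. This is the standard SAC temperature objective, written out here for clarity (not copied from the library's code):

$$
J(\alpha) = \mathbb{E}_{a_t \sim \pi}\big[-\alpha \log \pi(a_t \mid s_t) - \alpha \bar{\mathcal{H}}\big],
\qquad \bar{\mathcal{H}} = -\dim(\mathcal{A}),
$$

which matches `target_entropy = -np.prod(env.action_space.shape)` in the code.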

@@ -71,6 +71,8 @@ def test_sac(args=get_args()):
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(

@@ -72,6 +72,8 @@ def test_td3(args=get_args()):
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(

@@ -70,11 +70,13 @@ def test_sac_with_il(args=get_args()):
net, args.action_shape, args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

@@ -74,11 +74,13 @@ def test_td3(args=get_args()):
args.max_action, args.device
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

@@ -62,11 +62,11 @@ def test_discrete_sac(args=get_args()):
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape, device=args.device)
critic1 = Critic(net, last_size=args.action_shape).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape, device=args.device)
critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape, device=args.device)
critic2 = Critic(net, last_size=args.action_shape).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape, device=args.device)
critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
# better not to use auto alpha in CartPole

@@ -38,6 +38,9 @@ class SACPolicy(DDPGPolicy):
defaults to False.
:param BaseNoise exploration_noise: add a noise to action for exploration,
defaults to None. This is useful when solving hard-exploration problem.
:param bool deterministic_eval: whether to use the deterministic action
(the mean of the Gaussian policy) instead of a stochastic action sampled
from the policy; defaults to True.
.. seealso::
@@ -63,6 +66,7 @@ class SACPolicy(DDPGPolicy):
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True,
**kwargs: Any,
) -> None:
super().__init__(None, None, None, None, action_range, tau, gamma,
@@ -86,6 +90,7 @@ class SACPolicy(DDPGPolicy):
else:
self._alpha = alpha
self._deterministic_eval = deterministic_eval
self.__eps = np.finfo(np.float32).eps.item()
def train(self, mode: bool = True) -> "SACPolicy":
@@ -116,13 +121,16 @@ class SACPolicy(DDPGPolicy):
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
x = dist.rsample()
if self._deterministic_eval and not self.training:
x = logits[0]
else:
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and not self.updating:
if self._noise is not None and self.training and not self.updating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
return Batch(
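The log-probability lines above are the usual change-of-variables correction for a tanh-squashed Gaussian. With pre-squash sample $x \sim \mathcal{N}(\mu, \sigma)$ and action $a = \text{scale} \cdot \tanh(x) + \text{bias}$:

$$
\log \pi(a \mid s) = \log \mathcal{N}(x; \mu, \sigma)
- \sum_i \log\big(\text{scale}_i\,(1 - \tanh^2(x_i)) + \epsilon\big),
$$

where $\epsilon$ = `np.finfo(np.float32).eps` keeps the logarithm finite; this is exactly what the `y = self._action_scale * (1 - y.pow(2)) + self.__eps` line computes.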

@@ -6,6 +6,10 @@ from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch, to_torch_as
SIGMA_MIN = -20
SIGMA_MAX = 2
class Actor(nn.Module):
"""Simple actor network with MLP.
@@ -89,12 +93,17 @@ class ActorProb(nn.Module):
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
conditioned_sigma: bool = False,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.device = device
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
else:
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._max = max_action
self._unbounded = unbounded
@@ -109,9 +118,14 @@ class ActorProb(nn.Module):
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
if self._c_sigma:
sigma = torch.clamp(
self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX
).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
return (mu, sigma), state
@@ -131,6 +145,7 @@ class RecurrentActorProb(nn.Module):
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
conditioned_sigma: bool = False,
) -> None:
super().__init__()
self.device = device
@@ -141,7 +156,11 @@ class RecurrentActorProb(nn.Module):
batch_first=True,
)
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
else:
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._max = max_action
self._unbounded = unbounded
@@ -170,9 +189,14 @@ class RecurrentActorProb(nn.Module):
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
if self._c_sigma:
sigma = torch.clamp(
self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX
).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), {"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()}