modified based on author's implementation

This commit is contained in:
NM512 2023-03-18 08:38:23 +09:00
parent a678a509b9
commit 6273444394
6 changed files with 371 additions and 229 deletions

View File

@ -17,8 +17,17 @@ Monitor results:
tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG
```
## ToDo
- [ ] Prototyping
- [ ] Modify implementation details based on the author's implementation
- [ ] Evaluate on visual DMC suite(~10 tasks)
- [ ] Add other tasks and corresponding model sizes implementation
- [ ] Continuous implementation improvement
## Acknowledgments
This code is heavily inspired by the following works:
- danijar's Dreamer-v3 jax implementation: https://github.com/danijar/dreamerv3
- danijar's Dreamer-v2 tensorflow implementation: https://github.com/danijar/dreamerv2
- jsikyoon's Dreamer-v2 pytorch implementation: https://github.com/jsikyoon/dreamer-torch
- RajGhugare19's Dreamer-v2 pytorch implementation: https://github.com/RajGhugare19/dreamerv2

View File

@ -8,9 +8,9 @@ defaults:
seed: 0
steps: 1e6
eval_every: 1e4
eval_episode_num: 10
log_every: 1e4
reset_every: 0
#gpu_growth: True
device: 'cuda:0'
precision: 16
debug: False
@ -25,9 +25,6 @@ defaults:
grayscale: False
prefill: 2500
eval_noise: 0.0
reward_trans: 'symlog'
obs_trans: 'normalize'
critic_trans: 'symlog'
reward_EMA: True
# Model
@ -36,8 +33,8 @@ defaults:
dyn_deter: 512
dyn_stoch: 32
dyn_discrete: 32
dyn_input_layers: 2
dyn_output_layers: 2
dyn_input_layers: 1
dyn_output_layers: 1
dyn_rec_depth: 1
dyn_shared: False
dyn_mean_act: 'none'
@ -53,11 +50,10 @@ defaults:
act: 'SiLU'
norm: 'LayerNorm'
cnn_depth: 32
encoder_kernels: [3, 3, 3, 3]
decoder_kernels: [3, 3, 3, 3]
# changed here
value_head: 'twohot'
reward_head: 'twohot'
encoder_kernels: [4, 4, 4, 4]
decoder_kernels: [4, 4, 4, 4]
value_head: 'twohot_symlog'
reward_head: 'twohot_symlog'
kl_lscale: '0.1'
kl_rscale: '0.5'
kl_free: '1.0'
@ -71,7 +67,7 @@ defaults:
# Training
batch_size: 16
batch_length: 64
train_every: 5
train_ratio: 512
train_steps: 1
pretrain: 100
model_lr: 1e-4
@ -85,9 +81,8 @@ defaults:
dataset_size: 0
oversample_ends: False
slow_value_target: True
slow_actor_target: True
slow_target_update: 100
slow_target_fraction: 0.01
slow_target_update: 1
slow_target_fraction: 0.02
opt: 'adam'
# Behavior.
@ -95,16 +90,15 @@ defaults:
discount_lambda: 0.95
imag_horizon: 15
imag_gradient: 'dynamics'
imag_gradient_mix: '0.1'
imag_gradient_mix: '0.0'
imag_sample: True
actor_dist: 'trunc_normal'
actor_dist: 'normal'
actor_entropy: '3e-4'
actor_state_entropy: 0.0
actor_init_std: 1.0
actor_min_std: 0.1
actor_disc: 5
actor_max_std: 1.0
actor_temp: 0.1
actor_outscale: 0.0
expl_amount: 0.0
eval_state_mean: False
collect_dyn_sample: True
@ -134,3 +128,14 @@ debug:
batch_size: 10
batch_length: 20
cheetah:
task: 'dmc_cheetah_run'
pendulum:
task: 'dmc_pendulum_swingup'
cup:
task: 'dmc_cup_catch'
acrobot:
task: 'dmc_acrobot_swingup'

View File

@ -22,6 +22,7 @@ import torch
from torch import nn
from torch import distributions as torchd
to_np = lambda x: x.detach().cpu().numpy()
@ -31,7 +32,8 @@ class Dreamer(nn.Module):
self._config = config
self._logger = logger
self._should_log = tools.Every(config.log_every)
self._should_train = tools.Every(config.train_every)
batch_steps = config.batch_size * config.batch_length
self._should_train = tools.Every(batch_steps / config.train_ratio)
self._should_pretrain = tools.Once()
self._should_reset = tools.Every(config.reset_every)
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
@ -146,16 +148,17 @@ class Dreamer(nn.Module):
post, context, mets = self._wm._train(data)
metrics.update(mets)
start = post
# start['deter'] (16, 64, 512)
if self._config.pred_discount: # Last step could be terminal.
start = {k: v[:, :-1] for k, v in post.items()}
context = {k: v[:, :-1] for k, v in context.items()}
start = {k: v[:-1] for k, v in post.items()}
context = {k: v[:-1] for k, v in context.items()}
reward = lambda f, s, a: self._wm.heads["reward"](
self._wm.dynamics.get_feat(s)
).mode()
metrics.update(self._task_behavior._train(start, reward)[-1])
if self._config.expl_behavior != "greedy":
if self._config.pred_discount:
data = {k: v[:, :-1] for k, v in data.items()}
data = {k: v[:-1] for k, v in data.items()}
mets = self._expl_behavior.train(start, context, data)[-1]
metrics.update({"expl_" + key: value for key, value in mets.items()})
for name, value in metrics.items():
@ -205,7 +208,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
if (mode == "train") or (mode == "eval"):
callbacks = [
functools.partial(
process_episode, config, logger, mode, train_eps, eval_eps
ProcessEpisodeWrap.process_episode,
config,
logger,
mode,
train_eps,
eval_eps,
)
]
env = wrappers.CollectDataset(env, callbacks)
@ -213,15 +221,33 @@ def make_env(config, logger, mode, train_eps, eval_eps):
return env
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
class ProcessEpisodeWrap:
eval_scores = []
eval_lengths = []
@classmethod
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
# this saved episodes is given as train_eps or eval_eps from next call
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode["reward"]) - 1
score = float(episode["reward"].astype(np.float64).sum())
video = episode["image"]
cache[str(filename)] = episode
if mode == "eval":
cache.clear()
cls.eval_scores.append(score)
cls.eval_lengths.append(length)
# save when enought number of episodes are stored
if len(cls.eval_scores) < config.eval_episode_num:
return
else:
score = sum(cls.eval_scores) / len(cls.eval_scores)
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
episode_num = len(cls.eval_scores)
cls.eval_scores = []
cls.eval_lengths = []
if mode == "train" and config.dataset_size:
total = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
@ -230,12 +256,14 @@ def process_episode(config, logger, mode, train_eps, eval_eps, episode):
else:
del cache[key]
logger.scalar("dataset_size", total + length)
cache[str(filename)] = episode
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
logger.scalar(f"{mode}_return", score)
logger.scalar(f"{mode}_length", length)
logger.scalar(f"{mode}_episodes", len(cache))
logger.scalar(
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
)
if mode == "eval" or config.expl_gifs:
# only last video in eval videos is preservad
logger.video(f"{mode}_policy", video[None])
logger.write()
@ -315,7 +343,7 @@ def main(config):
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=1)
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
print("Start training.")
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
torch.save(agent.state_dict(), logdir / "latest_model.pt")

148
models.py
View File

@ -10,30 +10,22 @@ import tools
to_np = lambda x: x.detach().cpu().numpy()
def symlog(x):
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
def symexp(x):
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
class RewardEMA(object):
"""running mean and std"""
def __init__(self, device, alpha=1e-2):
self.device = device
self.scale = torch.zeros((1,)).to(device)
self.values = torch.zeros((2,)).to(device)
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
def __call__(self, x):
flat_x = torch.flatten(x.detach())
x_quantile = torch.quantile(input=flat_x, q=self.range)
scale = x_quantile[1] - x_quantile[0]
new_scale = self.alpha * scale + (1 - self.alpha) * self.scale
self.scale = new_scale
return x / torch.clip(self.scale, min=1.0)
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
offset = self.values[0]
return offset.detach(), scale.detach()
class WorldModel(nn.Module):
@ -93,7 +85,7 @@ class WorldModel(nn.Module):
shape,
config.decoder_kernels,
)
if config.reward_head == "twohot":
if config.reward_head == "twohot_symlog":
self.heads["reward"] = networks.DenseHead(
feat_size, # pytorch version
(255,),
@ -102,6 +94,7 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
)
else:
self.heads["reward"] = networks.DenseHead(
@ -112,9 +105,8 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
)
# added this
self.heads["reward"].apply(tools.weight_init)
if config.pred_discount:
self.heads["discount"] = networks.DenseHead(
feat_size, # pytorch version
@ -163,8 +155,6 @@ class WorldModel(nn.Module):
feat = self.dynamics.get_feat(post)
feat = feat if grad_head else feat.detach()
pred = head(feat)
# if name == 'image':
# losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum')
like = pred.log_prob(data[name])
likes[name] = like
losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
@ -196,24 +186,9 @@ class WorldModel(nn.Module):
def preprocess(self, obs):
obs = obs.copy()
if self._config.obs_trans == "normalize":
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
elif self._config.obs_trans == "identity":
obs["image"] = torch.Tensor(obs["image"])
elif self._config.obs_trans == "symlog":
obs["image"] = symlog(torch.Tensor(obs["image"]))
else:
raise NotImplemented(f"{self._config.reward_trans} is not implemented")
if self._config.reward_trans == "tanh":
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1)
elif self._config.reward_trans == "identity":
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
elif self._config.reward_trans == "symlog":
obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1)
else:
raise NotImplemented(f"{self._config.reward_trans} is not implemented")
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
@ -234,13 +209,9 @@ class WorldModel(nn.Module):
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed image is given until 5 steps
model = torch.cat([recon[:, :5], openl], 1)
if self._config.obs_trans == "normalize":
truth = data["image"][:6] + 0.5
model += 0.5
elif self._config.obs_trans == "symlog":
truth = symexp(data["image"][:6]) / 255.0
model = symexp(model) / 255.0
error = (model - truth + 1) / 2
model = model + 0.5
error = (model - truth + 1.0) / 2.0
return torch.cat([truth, model, error], 2)
@ -267,11 +238,11 @@ class ImagBehavior(nn.Module):
config.actor_dist,
config.actor_init_std,
config.actor_min_std,
config.actor_dist,
config.actor_max_std,
config.actor_temp,
config.actor_outscale,
outscale=1.0,
) # action_dist -> action_disc?
if config.value_head == "twohot":
if config.value_head == "twohot_symlog":
self.value = networks.DenseHead(
feat_size, # pytorch version
(255,),
@ -280,6 +251,7 @@ class ImagBehavior(nn.Module):
config.act,
config.norm,
config.value_head,
outscale=0.0,
)
else:
self.value = networks.DenseHead(
@ -290,9 +262,9 @@ class ImagBehavior(nn.Module):
config.act,
config.norm,
config.value_head,
outscale=0.0,
)
self.value.apply(tools.weight_init)
if config.slow_value_target or config.slow_actor_target:
if config.slow_value_target:
self._slow_value = copy.deepcopy(self.value)
self._updates = 0
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
@ -335,21 +307,12 @@ class ImagBehavior(nn.Module):
start, self.actor, self._config.imag_horizon, repeats
)
reward = objective(imag_feat, imag_state, imag_action)
if self._config.reward_trans == "symlog":
# rescale predicted reward by head['reward']
reward = symexp(reward)
actor_ent = self.actor(imag_feat).entropy()
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
# this target is not scaled
# slow is flag to indicate whether slow_target is used for lambda-return
target, weights = self._compute_target(
imag_feat,
imag_state,
imag_action,
reward,
actor_ent,
state_ent,
self._config.slow_actor_target,
target, weights, base = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
)
actor_loss, mets = self._compute_actor_loss(
imag_feat,
@ -359,42 +322,31 @@ class ImagBehavior(nn.Module):
actor_ent,
state_ent,
weights,
base,
)
metrics.update(mets)
if self._config.slow_value_target != self._config.slow_actor_target:
target, weights = self._compute_target(
imag_feat,
imag_state,
imag_action,
reward,
actor_ent,
state_ent,
self._config.slow_value_target,
)
value_input = imag_feat
with tools.RequiresGrad(self.value):
with torch.cuda.amp.autocast(self._use_amp):
value = self.value(value_input[:-1].detach())
target = torch.stack(target, dim=1)
# only critic target is processed using symlog(not actor)
if self._config.critic_trans == "symlog":
metrics["unscaled_target_mean"] = to_np(torch.mean(target))
target = symlog(target)
# (time, batch, 1), (time, batch, 1) -> (time, batch)
value_loss = -value.log_prob(target.detach())
slow_target = self._slow_value(value_input[:-1].detach())
if self._config.slow_value_target:
value_loss = value_loss - value.log_prob(
slow_target.mode().detach()
)
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
# (time, batch, 1), (time, batch, 1) -> (1,)
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
metrics["value_mean"] = to_np(torch.mean(value.mode()))
metrics["value_max"] = to_np(torch.max(value.mode()))
metrics["value_min"] = to_np(torch.min(value.mode()))
metrics["value_std"] = to_np(torch.std(value.mode()))
metrics["target_mean"] = to_np(torch.mean(target))
metrics["reward_mean"] = to_np(torch.mean(reward))
metrics["reward_std"] = to_np(torch.std(reward))
metrics.update(tools.tensorstats(value.mode(), "value"))
metrics.update(tools.tensorstats(target, "target"))
metrics.update(tools.tensorstats(reward, "imag_reward"))
metrics.update(tools.tensorstats(imag_action, "imag_action"))
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
with tools.RequiresGrad(self):
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
@ -402,6 +354,11 @@ class ImagBehavior(nn.Module):
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
# horizon: 15
# start = dict(stoch, deter, logit)
# start["stoch"] (16, 63, 32, 32)
# start["deter"] (16, 63, 512)
# start["logit"] (16, 63, 32, 32)
dynamics = self._world_model.dynamics
if repeats:
raise NotImplemented("repeats is not implemented in this version")
@ -418,6 +375,8 @@ class ImagBehavior(nn.Module):
feat = 0 * dynamics.get_feat(start)
action = policy(feat).mode()
# Is this action deterministic or stochastic?
# action = policy(feat).sample()
succ, feats, actions = tools.static_scan(
step, [torch.arange(horizon)], (start, feat, action)
)
@ -428,7 +387,7 @@ class ImagBehavior(nn.Module):
return feats, states, actions
def _compute_target(
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
):
if "discount" in self._world_model.heads:
inp = self._world_model.dynamics.get_feat(imag_state)
@ -439,13 +398,10 @@ class ImagBehavior(nn.Module):
reward += self._config.actor_entropy() * actor_ent
if self._config.future_entropy and self._config.actor_state_entropy() > 0:
reward += self._config.actor_state_entropy() * state_ent
if slow:
value = self._slow_value(imag_feat).mode()
else:
value = self.value(imag_feat).mode()
if self._config.critic_trans == "symlog":
# After adding this line there is issue
value = symexp(value)
# value(15, 960, ch)
# action(15, 960, ch)
# discount(15, 960, ch)
target = tools.lambda_return(
reward[:-1],
value[:-1],
@ -457,10 +413,18 @@ class ImagBehavior(nn.Module):
weights = torch.cumprod(
torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
).detach()
return target, weights
return target, weights, value[:-1]
def _compute_actor_loss(
self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights
self,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
):
metrics = {}
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
@ -469,11 +433,17 @@ class ImagBehavior(nn.Module):
# Q-val for actor is not transformed using symlog
target = torch.stack(target, dim=1)
if self._config.reward_EMA:
target = self.reward_ema(target)
metrics["EMA_scale"] = to_np(self.reward_ema.scale)
offset, scale = self.reward_ema(target)
normed_target = (target - offset) / scale
normed_base = (base - offset) / scale
adv = normed_target - normed_base
metrics.update(tools.tensorstats(normed_target, "normed_target"))
values = self.reward_ema.values
metrics["EMA_005"] = to_np(values[0])
metrics["EMA_095"] = to_np(values[1])
if self._config.imag_gradient == "dynamics":
actor_target = target
actor_target = adv
elif self._config.imag_gradient == "reinforce":
actor_target = (
policy.log_prob(imag_action)[:-1][:, :, None]
@ -501,7 +471,7 @@ class ImagBehavior(nn.Module):
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target or self._config.slow_actor_target:
if self._config.slow_value_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
for s, d in zip(self.value.parameters(), self._slow_value.parameters()):

View File

@ -59,29 +59,33 @@ class RSSM(nn.Module):
if self._shared:
inp_dim += self._embed
for i in range(self._layers_input):
inp_layers.append(nn.Linear(inp_dim, self._hidden))
inp_layers.append(self._norm(self._hidden))
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
inp_layers.append(self._norm(self._hidden, eps=1e-03))
inp_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._inp_layers = nn.Sequential(*inp_layers)
self._inp_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
elif cell == "gru_layer_norm":
self._cell = GRUCell(self._hidden, self._deter, norm=True)
self._cell.apply(tools.weight_init)
else:
raise NotImplementedError(cell)
img_out_layers = []
inp_dim = self._deter
for i in range(self._layers_output):
img_out_layers.append(nn.Linear(inp_dim, self._hidden))
img_out_layers.append(self._norm(self._hidden))
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
img_out_layers.append(self._norm(self._hidden, eps=1e-03))
img_out_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._img_out_layers = nn.Sequential(*img_out_layers)
self._img_out_layers.apply(tools.weight_init)
obs_out_layers = []
if self._temp_post:
@ -89,19 +93,24 @@ class RSSM(nn.Module):
else:
inp_dim = self._embed
for i in range(self._layers_output):
obs_out_layers.append(nn.Linear(inp_dim, self._hidden))
obs_out_layers.append(self._norm(self._hidden))
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
obs_out_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._obs_out_layers = nn.Sequential(*obs_out_layers)
self._obs_out_layers.apply(tools.weight_init)
if self._discrete:
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._ims_stat_layer.apply(tools.weight_init)
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._obs_stat_layer.apply(tools.weight_init)
else:
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._ims_stat_layer.apply(tools.weight_init)
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
def initial(self, batch_size):
deter = torch.zeros(batch_size, self._deter).to(self._device)
@ -178,6 +187,7 @@ class RSSM(nn.Module):
def obs_step(self, prev_state, prev_action, embed, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@ -200,6 +210,7 @@ class RSSM(nn.Module):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@ -317,12 +328,15 @@ class ConvEncoder(nn.Module):
out_channels=depth,
kernel_size=(kernel, kernel),
stride=(2, 2),
bias=False,
)
)
h, w = h // 2, w // 2
# layers.append(norm([depth, h, w]))
layers.append(ChLayerNorm(depth))
layers.append(act())
h, w = h // 2, w // 2
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
def __call__(self, obs):
x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
@ -343,6 +357,7 @@ class ConvDecoder(nn.Module):
norm=nn.LayerNorm,
shape=(3, 64, 64),
kernels=(3, 3, 3, 3),
outscale=1.0,
):
super(ConvDecoder, self).__init__()
self._inp_depth = inp_depth
@ -358,19 +373,25 @@ class ConvDecoder(nn.Module):
self._linear_layer = nn.Linear(inp_depth, self._embed_size)
inp_dim = self._embed_size // 16
cnnt_layers = []
layers = []
h, w = 4, 4
for i, kernel in enumerate(self._kernels):
depth = self._embed_size // 16 // (2 ** (i + 1))
act = self._act
bias = False
initializer = tools.weight_init
if i == len(self._kernels) - 1:
depth = self._shape[0]
act = None
act = False
bias = True
norm = False
initializer = tools.uniform_weight_init(outscale)
if i != 0:
inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1)
pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1)
cnnt_layers.append(
pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
layers.append(
nn.ConvTranspose2d(
inp_dim,
depth,
@ -378,26 +399,32 @@ class ConvDecoder(nn.Module):
2,
padding=(pad_h, pad_w),
output_padding=(outpad_h, outpad_w),
bias=bias,
)
)
if norm:
layers.append(ChLayerNorm(depth))
if act:
layers.append(act())
[m.apply(initializer) for m in layers[-3:]]
h, w = h * 2, w * 2
# cnnt_layers.append(norm([depth, h, w]))
if act is not None:
cnnt_layers.append(act())
self._cnnt_layers = nn.Sequential(*cnnt_layers)
self.layers = nn.Sequential(*layers)
def calc_same_pad(self, k, s, d):
val = d * (k - 1) - s + 1
pad = math.ceil(val / 2)
outpad = pad * 2 - val
return pad, outpad
def __call__(self, features, dtype=None):
x = self._linear_layer(features)
x = x.reshape([-1, 4, 4, self._embed_size // 16])
x = x.permute(0, 3, 1, 2)
x = self._cnnt_layers(x)
x = self.layers(x)
mean = x.reshape(features.shape[:-1] + self._shape)
mean = mean.permute(0, 1, 3, 4, 2)
return tools.ContDist(
torchd.independent.Independent(
torchd.normal.Normal(mean, 1), len(self._shape)
)
)
return tools.SymlogDist(mean)
class DenseHead(nn.Module):
@ -411,7 +438,7 @@ class DenseHead(nn.Module):
norm=nn.LayerNorm,
dist="normal",
std=1.0,
unimix_ratio=0.0,
outscale=1.0,
):
super(DenseHead, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@ -423,27 +450,30 @@ class DenseHead(nn.Module):
self._norm = norm
self._dist = dist
self._std = std
self._unimix_ratio = unimix_ratio
mean_layers = []
layers = []
for index in range(self._layers):
mean_layers.append(nn.Linear(inp_dim, self._units))
mean_layers.append(norm(self._units))
mean_layers.append(act())
layers.append(nn.Linear(inp_dim, self._units, bias=False))
layers.append(norm(self._units, eps=1e-03))
layers.append(act())
if index == 0:
inp_dim = self._units
mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape)))
self._mean_layers = nn.Sequential(*mean_layers)
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
self._std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer.apply(tools.uniform_weight_init(outscale))
def __call__(self, features, dtype=None):
x = features
mean = self._mean_layers(x)
out = self.layers(x)
mean = self.mean_layer(out)
if self._std == "learned":
std = self._std_layer(x)
std = torch.softplus(std) + 0.01
std = self.std_layer(out)
else:
std = self._std
if self._dist == "normal":
@ -464,8 +494,8 @@ class DenseHead(nn.Module):
torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
)
)
if self._dist == "twohot":
return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio)
if self._dist == "twohot_symlog":
return tools.TwoHotDistSymlog(logits=mean)
raise NotImplementedError(self._dist)
@ -481,9 +511,9 @@ class ActionHead(nn.Module):
dist="trunc_normal",
init_std=0.0,
min_std=0.1,
action_disc=5,
max_std=1.0,
temp=0.1,
outscale=0,
outscale=1.0,
):
super(ActionHead, self).__init__()
self._size = size
@ -493,24 +523,27 @@ class ActionHead(nn.Module):
self._act = act
self._norm = norm
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._action_disc = action_disc
self._temp = temp() if callable(temp) else temp
self._outscale = outscale
pre_layers = []
for index in range(self._layers):
pre_layers.append(nn.Linear(inp_dim, self._units))
pre_layers.append(norm(self._units))
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
pre_layers.append(norm(self._units, eps=1e-03))
pre_layers.append(act())
if index == 0:
inp_dim = self._units
self._pre_layers = nn.Sequential(*pre_layers)
self._pre_layers.apply(tools.weight_init)
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
self._dist_layer = nn.Linear(self._units, 2 * self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
self._dist_layer = nn.Linear(self._units, self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
def __call__(self, features, dtype=None):
x = features
@ -539,9 +572,11 @@ class ActionHead(nn.Module):
dist = tools.SampleDist(dist)
elif self._dist == "normal":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
std = F.softplus(std + self._init_std) + self._min_std
dist = torchd.normal.Normal(mean, std)
mean, std = torch.split(x, [self._size] * 2, -1)
std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
x = self._dist_layer(x)
@ -574,9 +609,9 @@ class GRUCell(nn.Module):
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None)
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
if norm:
self._norm = nn.LayerNorm(3 * size)
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
@property
def state_size(self):
@ -625,8 +660,13 @@ class Conv2dSame(torch.nn.Conv2d):
return ret
def calc_same_pad(k, s, d):
val = d * (k - 1) - s + 1
pad = math.ceil(val / 2)
outpad = pad * 2 - val
return pad, outpad
class ChLayerNorm(nn.Module):
def __init__(self, ch, eps=1e-03):
super(ChLayerNorm, self).__init__()
self.norm = torch.nn.LayerNorm(ch, eps=eps)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x

152
tools.py
View File

@ -17,6 +17,14 @@ from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
to_np = lambda x: x.detach().cpu().numpy()
def symlog(x):
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
def symexp(x):
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
class RequiresGrad:
def __init__(self, model):
@ -269,10 +277,12 @@ class SampleDist:
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
if logits is not None and probs is None and unimix_ratio > 0.0:
if logits is not None and unimix_ratio > 0.0:
probs = F.softmax(logits, dim=-1)
probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
logits = None
logits = torch.log(probs)
super().__init__(logits=logits, probs=None)
else:
super().__init__(logits=logits, probs=probs)
def mode(self):
@ -290,42 +300,81 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
return sample
class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical):
class TwoHotDistSymlog():
def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'):
if logits is not None and probs is None and unimix_ratio > 0.0:
probs = F.softmax(logits, dim=-1)
probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
logits = None
super().__init__(logits=logits, probs=probs)
self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device)
def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
self.logits = logits
self.probs = torch.softmax(logits, -1)
self.buckets = torch.linspace(low, high, steps=255).to(device)
self.width = (self.buckets[-1] - self.buckets[0]) / 255
def mean(self):
print("mean called")
_mode = self.probs * self.buckets
return symexp(torch.sum(_mode, dim=-1, keepdim=True))
def mode(self):
_mode = super().probs * self.buckets
return torch.sum(_mode, dim=-1, keepdim=True)
_mode = self.probs * self.buckets
return symexp(torch.sum(_mode, dim=-1, keepdim=True))
# Inside OneHotCategorical, log_prob is calculated using only max element in targets
def log_prob(self, x):
x = symlog(x)
# x(time, batch, 1)
x = (x - self.buckets[0]) / self.width
lower_indices = (x).to(torch.int64)
# lower_indices is idnside 0 ~ len(buckets)-2
lower_indices = torch.clip(lower_indices, max=len(self.buckets)-2)
# upper_indices is inside 1 ~ len(buckets)-1
upper_indices = lower_indices + 1
lower_weight = torch.abs(x - upper_indices).squeeze(-1)
upper_weight = torch.abs(x - lower_indices).squeeze(-1)
# (time, batch, 1) -> (time, batch, bucket_class)
lower_log_prob = super().log_prob(F.one_hot(lower_indices.squeeze(-1), num_classes=len(self.buckets)))
upper_log_prob = super().log_prob(F.one_hot(upper_indices.squeeze(-1), num_classes=len(self.buckets)))
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) -1
above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
below = torch.clip(below, 0, len(self.buckets)-1)
above = torch.clip(above, 0, len(self.buckets)-1)
equal = (below == above)
# label = lower_log_prob * lower_weight + upper_log_prob * upper_weight
# # (time, batch, bucket_class) -> (time, batch)
# cross_entropy = torch.sum(torch.log(super().probs) * label, axis=-1)
dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
total = dist_to_below + dist_to_above
weight_below = dist_to_above / total
weight_above = dist_to_below / total
target = (
F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
target = target.squeeze(-2)
return lower_weight * lower_log_prob + upper_weight * upper_log_prob
return (target * log_pred).sum(-1)
def log_prob_target(self, target):
log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
return (target * log_pred).sum(-1)
class SymlogDist():
def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
self._mode = mode
self._dist = dist
self._agg = agg
self._tol = tol
self._dim_to_reduce = dim_to_reduce
def mode(self):
return symexp(self._mode)
def mean(self):
return symexp(self._mode)
def log_prob(self, value):
assert self._mode.shape == value.shape
if self._dist == 'mse':
distance = (self._mode - symlog(value)) ** 2.0
distance = torch.where(distance < self._tol, 0, distance)
elif self._dist == 'abs':
distance = torch.abs(self._mode - symlog(value))
distance = torch.where(distance < self._tol, 0, distance)
else:
raise NotImplementedError(self._dist)
if self._agg == 'mean':
loss = distance.mean(self._dim_to_reduce)
elif self._agg == 'sum':
loss = distance.sum(self._dim_to_reduce)
else:
raise NotImplementedError(self._agg)
return -loss
class ContDist:
@ -438,6 +487,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
indices = reversed(indices)
flag = True
for index in indices:
# (inputs, pcont) -> (inputs[index], pcont[index])
inp = lambda x: (_input[x] for _input in inputs)
last = fn(last, *inp(index))
if flag:
@ -446,6 +496,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
else:
outputs = torch.cat([outputs, last], dim=-1)
outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
outputs = torch.flip(outputs, [1])
outputs = torch.unbind(outputs, dim=0)
return outputs
@ -687,14 +738,53 @@ def schedule(string, step):
def weight_init(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
gain = nn.init.calculate_gain('relu')
nn.init.orthogonal_(m.weight.data, gain)
space = m.kernel_size[0] * m.kernel_size[1]
in_num = space * m.in_channels
out_num = space * m.out_channels
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
def uniform_weight_init(given_scale):
def f(m):
if isinstance(m, nn.Linear):
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = given_scale / denoms
limit = np.sqrt(3 * scale)
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0.0)
return f
def tensorstats(tensor, prefix=None):
metrics = {
'mean': to_np(torch.mean(tensor)),
'std': to_np(torch.std(tensor)),
'min': to_np(torch.min(tensor)),
'max': to_np(torch.max(tensor)),
}
if prefix:
metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
return metrics