modified based on author's implementation
parent a678a509b9
commit 6273444394
@@ -17,8 +17,17 @@ Monitor results:
tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG
```

## ToDo

- [ ] Prototyping
- [ ] Modify implementation details based on the author's implementation
- [ ] Evaluate on the visual DMC suite (~10 tasks)
- [ ] Add other tasks and the corresponding model-size implementations
- [ ] Continuous implementation improvement


## Acknowledgments

This code is heavily inspired by the following works:

- danijar's Dreamer-v3 jax implementation: https://github.com/danijar/dreamerv3
- danijar's Dreamer-v2 tensorflow implementation: https://github.com/danijar/dreamerv2
- jsikyoon's Dreamer-v2 pytorch implementation: https://github.com/jsikyoon/dreamer-torch
- RajGhugare19's Dreamer-v2 pytorch implementation: https://github.com/RajGhugare19/dreamerv2
43 configs.yaml
@@ -8,9 +8,9 @@ defaults:
  seed: 0
  steps: 1e6
  eval_every: 1e4
  eval_episode_num: 10
  log_every: 1e4
  reset_every: 0
  #gpu_growth: True
  device: 'cuda:0'
  precision: 16
  debug: False
@@ -25,9 +25,6 @@ defaults:
  grayscale: False
  prefill: 2500
  eval_noise: 0.0
  reward_trans: 'symlog'
  obs_trans: 'normalize'
  critic_trans: 'symlog'
  reward_EMA: True

  # Model
@@ -36,8 +33,8 @@ defaults:
  dyn_deter: 512
  dyn_stoch: 32
  dyn_discrete: 32
  dyn_input_layers: 2
  dyn_output_layers: 2
  dyn_input_layers: 1
  dyn_output_layers: 1
  dyn_rec_depth: 1
  dyn_shared: False
  dyn_mean_act: 'none'
@@ -53,11 +50,10 @@ defaults:
  act: 'SiLU'
  norm: 'LayerNorm'
  cnn_depth: 32
  encoder_kernels: [3, 3, 3, 3]
  decoder_kernels: [3, 3, 3, 3]
  # changed here
  value_head: 'twohot'
  reward_head: 'twohot'
  encoder_kernels: [4, 4, 4, 4]
  decoder_kernels: [4, 4, 4, 4]
  value_head: 'twohot_symlog'
  reward_head: 'twohot_symlog'
  kl_lscale: '0.1'
  kl_rscale: '0.5'
  kl_free: '1.0'
@@ -71,7 +67,7 @@ defaults:
  # Training
  batch_size: 16
  batch_length: 64
  train_every: 5
  train_ratio: 512
  train_steps: 1
  pretrain: 100
  model_lr: 1e-4
@@ -85,9 +81,8 @@ defaults:
  dataset_size: 0
  oversample_ends: False
  slow_value_target: True
  slow_actor_target: True
  slow_target_update: 100
  slow_target_fraction: 0.01
  slow_target_update: 1
  slow_target_fraction: 0.02
  opt: 'adam'

  # Behavior.
@@ -95,16 +90,15 @@ defaults:
  discount_lambda: 0.95
  imag_horizon: 15
  imag_gradient: 'dynamics'
  imag_gradient_mix: '0.1'
  imag_gradient_mix: '0.0'
  imag_sample: True
  actor_dist: 'trunc_normal'
  actor_dist: 'normal'
  actor_entropy: '3e-4'
  actor_state_entropy: 0.0
  actor_init_std: 1.0
  actor_min_std: 0.1
  actor_disc: 5
  actor_max_std: 1.0
  actor_temp: 0.1
  actor_outscale: 0.0
  expl_amount: 0.0
  eval_state_mean: False
  collect_dyn_sample: True
@@ -134,3 +128,14 @@ debug:
  batch_size: 10
  batch_length: 20

cheetah:
  task: 'dmc_cheetah_run'

pendulum:
  task: 'dmc_pendulum_swingup'

cup:
  task: 'dmc_cup_catch'

acrobot:
  task: 'dmc_acrobot_swingup'
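The `'symlog'` settings above (`reward_trans`, `critic_trans`) and the `twohot_symlog` heads rely on the symlog squashing this commit introduces (defined in models.py and tools.py below). A minimal self-contained sketch of the transform and its inverse:

```python
import torch

def symlog(x):
    # compresses large magnitudes while staying near-identity around zero
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    # exact inverse of symlog
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
assert torch.allclose(symexp(symlog(x)), x, atol=1e-4)
print(symlog(x))  # tensor([-4.6151, -0.6931,  0.0000,  0.6931,  4.6151])
```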
48 dreamer.py
@@ -22,6 +22,7 @@ import torch
from torch import nn
from torch import distributions as torchd


to_np = lambda x: x.detach().cpu().numpy()


@@ -31,7 +32,8 @@ class Dreamer(nn.Module):
        self._config = config
        self._logger = logger
        self._should_log = tools.Every(config.log_every)
        self._should_train = tools.Every(config.train_every)
        batch_steps = config.batch_size * config.batch_length
        self._should_train = tools.Every(batch_steps / config.train_ratio)
        self._should_pretrain = tools.Once()
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
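A note on the new schedule: with the defaults in this commit (batch_size 16, batch_length 64, train_ratio 512), `batch_steps / train_ratio` works out to training once every 2 policy steps, replacing the old fixed `train_every: 5`:

```python
batch_size, batch_length = 16, 64   # from configs.yaml
train_ratio = 512                   # replaces the old train_every: 5

batch_steps = batch_size * batch_length  # 1024 replayed steps per batch
train_every = batch_steps / train_ratio  # 2.0 -> one train step per 2 env steps
print(train_every)
```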
@@ -146,16 +148,17 @@ class Dreamer(nn.Module):
        post, context, mets = self._wm._train(data)
        metrics.update(mets)
        start = post
        # start['deter'] (16, 64, 512)
        if self._config.pred_discount:  # Last step could be terminal.
            start = {k: v[:, :-1] for k, v in post.items()}
            context = {k: v[:, :-1] for k, v in context.items()}
            start = {k: v[:-1] for k, v in post.items()}
            context = {k: v[:-1] for k, v in context.items()}
        reward = lambda f, s, a: self._wm.heads["reward"](
            self._wm.dynamics.get_feat(s)
        ).mode()
        metrics.update(self._task_behavior._train(start, reward)[-1])
        if self._config.expl_behavior != "greedy":
            if self._config.pred_discount:
                data = {k: v[:, :-1] for k, v in data.items()}
                data = {k: v[:-1] for k, v in data.items()}
            mets = self._expl_behavior.train(start, context, data)[-1]
            metrics.update({"expl_" + key: value for key, value in mets.items()})
        for name, value in metrics.items():
@@ -205,7 +208,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
    if (mode == "train") or (mode == "eval"):
        callbacks = [
            functools.partial(
                process_episode, config, logger, mode, train_eps, eval_eps
                ProcessEpisodeWrap.process_episode,
                config,
                logger,
                mode,
                train_eps,
                eval_eps,
            )
        ]
        env = wrappers.CollectDataset(env, callbacks)
@@ -213,15 +221,33 @@ def make_env(config, logger, mode, train_eps, eval_eps):
    return env


def process_episode(config, logger, mode, train_eps, eval_eps, episode):
class ProcessEpisodeWrap:
    eval_scores = []
    eval_lengths = []

    @classmethod
    def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
        directory = dict(train=config.traindir, eval=config.evaldir)[mode]
        cache = dict(train=train_eps, eval=eval_eps)[mode]
        # the saved episode is passed back in as train_eps or eval_eps on the next call
        filename = tools.save_episodes(directory, [episode])[0]
        length = len(episode["reward"]) - 1
        score = float(episode["reward"].astype(np.float64).sum())
        video = episode["image"]
        cache[str(filename)] = episode
        if mode == "eval":
            cache.clear()
            cls.eval_scores.append(score)
            cls.eval_lengths.append(length)
            # only log once enough episodes have been stored
            if len(cls.eval_scores) < config.eval_episode_num:
                return
            else:
                score = sum(cls.eval_scores) / len(cls.eval_scores)
                length = sum(cls.eval_lengths) / len(cls.eval_lengths)
                episode_num = len(cls.eval_scores)
                cls.eval_scores = []
                cls.eval_lengths = []

        if mode == "train" and config.dataset_size:
            total = 0
            for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
@@ -230,12 +256,14 @@ def process_episode(config, logger, mode, train_eps, eval_eps, episode):
            else:
                del cache[key]
            logger.scalar("dataset_size", total + length)
        cache[str(filename)] = episode
        print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
        logger.scalar(f"{mode}_return", score)
        logger.scalar(f"{mode}_length", length)
        logger.scalar(f"{mode}_episodes", len(cache))
        logger.scalar(
            f"{mode}_episodes", len(cache) if mode == "train" else episode_num
        )
        if mode == "eval" or config.expl_gifs:
            # only the last eval video is preserved
            logger.video(f"{mode}_policy", video[None])
        logger.write()

@@ -315,7 +343,7 @@ def main(config):
            video_pred = agent._wm.video_pred(next(eval_dataset))
            logger.video("eval_openl", to_np(video_pred))
        eval_policy = functools.partial(agent, training=False)
        tools.simulate(eval_policy, eval_envs, episodes=1)
        tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
        print("Start training.")
        state = tools.simulate(agent, train_envs, config.eval_every, state=state)
        torch.save(agent.state_dict(), logdir / "latest_model.pt")
148 models.py
@@ -10,30 +10,22 @@ import tools
to_np = lambda x: x.detach().cpu().numpy()


def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)


def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)


class RewardEMA(object):
    """running mean and std"""

    def __init__(self, device, alpha=1e-2):
        self.device = device
        self.scale = torch.zeros((1,)).to(device)
        self.values = torch.zeros((2,)).to(device)
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95]).to(device)

    def __call__(self, x):
        flat_x = torch.flatten(x.detach())
        x_quantile = torch.quantile(input=flat_x, q=self.range)
        scale = x_quantile[1] - x_quantile[0]
        new_scale = self.alpha * scale + (1 - self.alpha) * self.scale
        self.scale = new_scale
        return x / torch.clip(self.scale, min=1.0)
        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
        offset = self.values[0]
        return offset.detach(), scale.detach()
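A standalone sketch of how the rewritten RewardEMA is meant to be consumed (mirroring its call site in `_compute_actor_loss` further down): it tracks EMAs of the 5th/95th return percentiles, and the caller normalizes with the returned offset and scale. The demo tensor and CPU device are illustrative, not from the repo:

```python
import torch

class RewardEMA:
    """EMA of the 5th/95th percentiles of the return batch (as in this commit, on CPU)."""
    def __init__(self, device="cpu", alpha=1e-2):
        self.values = torch.zeros((2,), device=device)
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95], device=device)

    def __call__(self, x):
        flat_x = torch.flatten(x.detach())
        x_quantile = torch.quantile(input=flat_x, q=self.range)
        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
        return self.values[0].detach(), scale.detach()

ema = RewardEMA()
returns = torch.randn(15, 960, 1) * 50 + 100      # fake lambda-returns
offset, scale = ema(returns)
normed = (returns - offset) / scale               # what _compute_actor_loss does with target and base
print(offset.item(), scale.item())
```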
@@ -93,7 +85,7 @@ class WorldModel(nn.Module):
                shape,
                config.decoder_kernels,
            )
        if config.reward_head == "twohot":
        if config.reward_head == "twohot_symlog":
            self.heads["reward"] = networks.DenseHead(
                feat_size,  # pytorch version
                (255,),
@@ -102,6 +94,7 @@ class WorldModel(nn.Module):
                config.act,
                config.norm,
                dist=config.reward_head,
                outscale=0.0,
            )
        else:
            self.heads["reward"] = networks.DenseHead(
@@ -112,9 +105,8 @@ class WorldModel(nn.Module):
                config.act,
                config.norm,
                dist=config.reward_head,
                outscale=0.0,
            )
        # added this
        self.heads["reward"].apply(tools.weight_init)
        if config.pred_discount:
            self.heads["discount"] = networks.DenseHead(
                feat_size,  # pytorch version
@@ -163,8 +155,6 @@ class WorldModel(nn.Module):
                feat = self.dynamics.get_feat(post)
                feat = feat if grad_head else feat.detach()
                pred = head(feat)
                # if name == 'image':
                #     losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum')
                like = pred.log_prob(data[name])
                likes[name] = like
                losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
@@ -196,24 +186,9 @@ class WorldModel(nn.Module):

    def preprocess(self, obs):
        obs = obs.copy()
        if self._config.obs_trans == "normalize":
            obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
        elif self._config.obs_trans == "identity":
            obs["image"] = torch.Tensor(obs["image"])
        elif self._config.obs_trans == "symlog":
            obs["image"] = symlog(torch.Tensor(obs["image"]))
        else:
            raise NotImplementedError(f"{self._config.obs_trans} is not implemented")
        if self._config.reward_trans == "tanh":
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
            obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1)
        elif self._config.reward_trans == "identity":
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
            obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
        elif self._config.reward_trans == "symlog":
            obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1)
        else:
            raise NotImplementedError(f"{self._config.reward_trans} is not implemented")
        if "discount" in obs:
            obs["discount"] *= self._config.discount
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
@@ -234,13 +209,9 @@ class WorldModel(nn.Module):
        reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
        # observed images are given for the first 5 steps
        model = torch.cat([recon[:, :5], openl], 1)
        if self._config.obs_trans == "normalize":
            truth = data["image"][:6] + 0.5
            model += 0.5
        elif self._config.obs_trans == "symlog":
            truth = symexp(data["image"][:6]) / 255.0
            model = symexp(model) / 255.0
        error = (model - truth + 1) / 2
        model = model + 0.5
        error = (model - truth + 1.0) / 2.0

        return torch.cat([truth, model, error], 2)

@@ -267,11 +238,11 @@ class ImagBehavior(nn.Module):
            config.actor_dist,
            config.actor_init_std,
            config.actor_min_std,
            config.actor_dist,
            config.actor_max_std,
            config.actor_temp,
            config.actor_outscale,
            outscale=1.0,
        )  # action_dist -> action_disc?
        if config.value_head == "twohot":
        if config.value_head == "twohot_symlog":
            self.value = networks.DenseHead(
                feat_size,  # pytorch version
                (255,),
@@ -280,6 +251,7 @@ class ImagBehavior(nn.Module):
                config.act,
                config.norm,
                config.value_head,
                outscale=0.0,
            )
        else:
            self.value = networks.DenseHead(
@@ -290,9 +262,9 @@ class ImagBehavior(nn.Module):
                config.act,
                config.norm,
                config.value_head,
                outscale=0.0,
            )
        self.value.apply(tools.weight_init)
        if config.slow_value_target or config.slow_actor_target:
        if config.slow_value_target:
            self._slow_value = copy.deepcopy(self.value)
            self._updates = 0
        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
@@ -335,21 +307,12 @@ class ImagBehavior(nn.Module):
            start, self.actor, self._config.imag_horizon, repeats
        )
        reward = objective(imag_feat, imag_state, imag_action)
        if self._config.reward_trans == "symlog":
            # rescale the reward predicted by heads['reward']
            reward = symexp(reward)
        actor_ent = self.actor(imag_feat).entropy()
        state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
        # this target is not scaled
        # slow is a flag indicating whether the slow target is used for the lambda-return
        target, weights = self._compute_target(
            imag_feat,
            imag_state,
            imag_action,
            reward,
            actor_ent,
            state_ent,
            self._config.slow_actor_target,
        target, weights, base = self._compute_target(
            imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
        )
        actor_loss, mets = self._compute_actor_loss(
            imag_feat,
@@ -359,42 +322,31 @@ class ImagBehavior(nn.Module):
            actor_ent,
            state_ent,
            weights,
            base,
        )
        metrics.update(mets)
        if self._config.slow_value_target != self._config.slow_actor_target:
            target, weights = self._compute_target(
                imag_feat,
                imag_state,
                imag_action,
                reward,
                actor_ent,
                state_ent,
                self._config.slow_value_target,
            )
        value_input = imag_feat

        with tools.RequiresGrad(self.value):
            with torch.cuda.amp.autocast(self._use_amp):
                value = self.value(value_input[:-1].detach())
                target = torch.stack(target, dim=1)
                # only the critic target is symlog-transformed (not the actor's)
                if self._config.critic_trans == "symlog":
                    metrics["unscaled_target_mean"] = to_np(torch.mean(target))
                    target = symlog(target)
                # (time, batch, 1), (time, batch, 1) -> (time, batch)
                value_loss = -value.log_prob(target.detach())
                slow_target = self._slow_value(value_input[:-1].detach())
                if self._config.slow_value_target:
                    value_loss = value_loss - value.log_prob(
                        slow_target.mode().detach()
                    )
                if self._config.value_decay:
                    value_loss += self._config.value_decay * value.mode()
                # (time, batch, 1), (time, batch, 1) -> (1,)
                value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

        metrics["value_mean"] = to_np(torch.mean(value.mode()))
        metrics["value_max"] = to_np(torch.max(value.mode()))
        metrics["value_min"] = to_np(torch.min(value.mode()))
        metrics["value_std"] = to_np(torch.std(value.mode()))
        metrics["target_mean"] = to_np(torch.mean(target))
        metrics["reward_mean"] = to_np(torch.mean(reward))
        metrics["reward_std"] = to_np(torch.std(reward))
        metrics.update(tools.tensorstats(value.mode(), "value"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(reward, "imag_reward"))
        metrics.update(tools.tensorstats(imag_action, "imag_action"))
        metrics["actor_ent"] = to_np(torch.mean(actor_ent))
        with tools.RequiresGrad(self):
            metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
@@ -402,6 +354,11 @@ class ImagBehavior(nn.Module):
        return imag_feat, imag_state, imag_action, weights, metrics

    def _imagine(self, start, policy, horizon, repeats=None):
        # horizon: 15
        # start = dict(stoch, deter, logit)
        # start["stoch"] (16, 63, 32, 32)
        # start["deter"] (16, 63, 512)
        # start["logit"] (16, 63, 32, 32)
        dynamics = self._world_model.dynamics
        if repeats:
            raise NotImplementedError("repeats is not implemented in this version")
@@ -418,6 +375,8 @@ class ImagBehavior(nn.Module):

        feat = 0 * dynamics.get_feat(start)
        action = policy(feat).mode()
        # Is this action deterministic or stochastic?
        # action = policy(feat).sample()
        succ, feats, actions = tools.static_scan(
            step, [torch.arange(horizon)], (start, feat, action)
        )
@@ -428,7 +387,7 @@ class ImagBehavior(nn.Module):
        return feats, states, actions

    def _compute_target(
        self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow
        self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
    ):
        if "discount" in self._world_model.heads:
            inp = self._world_model.dynamics.get_feat(imag_state)
@@ -439,13 +398,10 @@ class ImagBehavior(nn.Module):
            reward += self._config.actor_entropy() * actor_ent
        if self._config.future_entropy and self._config.actor_state_entropy() > 0:
            reward += self._config.actor_state_entropy() * state_ent
        if slow:
            value = self._slow_value(imag_feat).mode()
        else:
            value = self.value(imag_feat).mode()
        if self._config.critic_trans == "symlog":
            # after adding this line there was an issue
            value = symexp(value)
        # value (15, 960, ch)
        # action (15, 960, ch)
        # discount (15, 960, ch)
        target = tools.lambda_return(
            reward[:-1],
            value[:-1],
@@ -457,10 +413,18 @@ class ImagBehavior(nn.Module):
        weights = torch.cumprod(
            torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
        ).detach()
        return target, weights
        return target, weights, value[:-1]
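For reference, `tools.lambda_return` (called above) computes the TD(λ) target R_t = r_t + γ_t((1−λ)V(s_{t+1}) + λR_{t+1}). A hypothetical standalone sketch of the recursion, not the repo's `static_scan`-based implementation:

```python
import torch

def lambda_return_sketch(reward, value, discount, bootstrap, lambda_=0.95):
    # reward, value, discount: (horizon, batch); bootstrap: (batch,)
    # R_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1})
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    inputs = reward + discount * next_values * (1 - lambda_)
    returns, last = [], bootstrap
    for t in reversed(range(reward.shape[0])):
        last = inputs[t] + discount[t] * lambda_ * last
        returns.append(last)
    return torch.stack(list(reversed(returns)), dim=0)

H, B = 15, 4
r, v = torch.rand(H, B), torch.rand(H, B)
d = torch.full((H, B), 0.997)
target = lambda_return_sketch(r, v, d, bootstrap=v[-1])
print(target.shape)  # torch.Size([15, 4])
```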

    def _compute_actor_loss(
        self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights
        self,
        imag_feat,
        imag_state,
        imag_action,
        target,
        actor_ent,
        state_ent,
        weights,
        base,
    ):
        metrics = {}
        inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
@@ -469,11 +433,17 @@ class ImagBehavior(nn.Module):
        # the Q-value for the actor is not symlog-transformed
        target = torch.stack(target, dim=1)
        if self._config.reward_EMA:
            target = self.reward_ema(target)
            metrics["EMA_scale"] = to_np(self.reward_ema.scale)
            offset, scale = self.reward_ema(target)
            normed_target = (target - offset) / scale
            normed_base = (base - offset) / scale
            adv = normed_target - normed_base
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            values = self.reward_ema.values
            metrics["EMA_005"] = to_np(values[0])
            metrics["EMA_095"] = to_np(values[1])

        if self._config.imag_gradient == "dynamics":
            actor_target = target
            actor_target = adv
        elif self._config.imag_gradient == "reinforce":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
@@ -501,7 +471,7 @@ class ImagBehavior(nn.Module):
        return actor_loss, metrics

    def _update_slow_target(self):
        if self._config.slow_value_target or self._config.slow_actor_target:
        if self._config.slow_value_target:
            if self._updates % self._config.slow_target_update == 0:
                mix = self._config.slow_target_fraction
                for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
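The loop above (truncated by the diff) performs a Polyak-style mix; this commit changes `slow_target_update` from 100 to 1 and `slow_target_fraction` from 0.01 to 0.02, i.e. small updates every train step instead of hard copies every 100. A minimal sketch of the rule, with stand-in modules:

```python
import copy
import torch
from torch import nn

value = nn.Linear(8, 1)             # stand-in for the critic
slow_value = copy.deepcopy(value)   # slow copy used for the value regularizer

def update_slow_target(mix=0.02):   # slow_target_fraction in configs.yaml
    # d <- mix * s + (1 - mix) * d, applied every slow_target_update (= 1) train step
    for s, d in zip(value.parameters(), slow_value.parameters()):
        d.data = mix * s.data + (1 - mix) * d.data

update_slow_target()
```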
146 networks.py
@@ -59,29 +59,33 @@ class RSSM(nn.Module):
        if self._shared:
            inp_dim += self._embed
        for i in range(self._layers_input):
            inp_layers.append(nn.Linear(inp_dim, self._hidden))
            inp_layers.append(self._norm(self._hidden))
            inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            inp_layers.append(self._norm(self._hidden, eps=1e-03))
            inp_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._inp_layers = nn.Sequential(*inp_layers)
        self._inp_layers.apply(tools.weight_init)

        if cell == "gru":
            self._cell = GRUCell(self._hidden, self._deter)
            self._cell.apply(tools.weight_init)
        elif cell == "gru_layer_norm":
            self._cell = GRUCell(self._hidden, self._deter, norm=True)
            self._cell.apply(tools.weight_init)
        else:
            raise NotImplementedError(cell)

        img_out_layers = []
        inp_dim = self._deter
        for i in range(self._layers_output):
            img_out_layers.append(nn.Linear(inp_dim, self._hidden))
            img_out_layers.append(self._norm(self._hidden))
            img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            img_out_layers.append(self._norm(self._hidden, eps=1e-03))
            img_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._img_out_layers = nn.Sequential(*img_out_layers)
        self._img_out_layers.apply(tools.weight_init)

        obs_out_layers = []
        if self._temp_post:
@@ -89,19 +93,24 @@ class RSSM(nn.Module):
        else:
            inp_dim = self._embed
        for i in range(self._layers_output):
            obs_out_layers.append(nn.Linear(inp_dim, self._hidden))
            obs_out_layers.append(self._norm(self._hidden))
            obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
            obs_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._obs_out_layers = nn.Sequential(*obs_out_layers)
        self._obs_out_layers.apply(tools.weight_init)

        if self._discrete:
            self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._ims_stat_layer.apply(tools.weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._obs_stat_layer.apply(tools.weight_init)
        else:
            self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._ims_stat_layer.apply(tools.weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._obs_stat_layer.apply(tools.weight_init)

    def initial(self, batch_size):
        deter = torch.zeros(batch_size, self._deter).to(self._device)
@@ -178,6 +187,7 @@ class RSSM(nn.Module):
    def obs_step(self, prev_state, prev_action, embed, sample=True):
        # if shared is True, prior and post both use the same networks (inp_layers, _img_out_layers, _ims_stat_layer)
        # otherwise, post uses a separate network (_obs_out_layers) with prior[deter] and embed as inputs
        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prior = self.img_step(prev_state, prev_action, None, sample)
        if self._shared:
            post = self.img_step(prev_state, prev_action, embed, sample)
@@ -200,6 +210,7 @@ class RSSM(nn.Module):
    # this is used to imagine future frames
    def img_step(self, prev_state, prev_action, embed=None, sample=True):
        # (batch, stoch, discrete_num)
        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prev_stoch = prev_state["stoch"]
        if self._discrete:
            shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -317,12 +328,15 @@ class ConvEncoder(nn.Module):
                    out_channels=depth,
                    kernel_size=(kernel, kernel),
                    stride=(2, 2),
                    bias=False,
                )
            )
            h, w = h // 2, w // 2
            # layers.append(norm([depth, h, w]))
            layers.append(ChLayerNorm(depth))
            layers.append(act())
            h, w = h // 2, w // 2

        self.layers = nn.Sequential(*layers)
        self.layers.apply(tools.weight_init)

    def __call__(self, obs):
        x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
@@ -343,6 +357,7 @@ class ConvDecoder(nn.Module):
        norm=nn.LayerNorm,
        shape=(3, 64, 64),
        kernels=(3, 3, 3, 3),
        outscale=1.0,
    ):
        super(ConvDecoder, self).__init__()
        self._inp_depth = inp_depth
@@ -358,19 +373,25 @@ class ConvDecoder(nn.Module):
        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
        inp_dim = self._embed_size // 16

        cnnt_layers = []
        layers = []
        h, w = 4, 4
        for i, kernel in enumerate(self._kernels):
            depth = self._embed_size // 16 // (2 ** (i + 1))
            act = self._act
            bias = False
            initializer = tools.weight_init
            if i == len(self._kernels) - 1:
                depth = self._shape[0]
                act = None
                act = False
                bias = True
                norm = False
                initializer = tools.uniform_weight_init(outscale)

            if i != 0:
                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
            pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1)
            pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1)
            cnnt_layers.append(
            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
            layers.append(
                nn.ConvTranspose2d(
                    inp_dim,
                    depth,
@@ -378,26 +399,32 @@ class ConvDecoder(nn.Module):
                    2,
                    padding=(pad_h, pad_w),
                    output_padding=(outpad_h, outpad_w),
                    bias=bias,
                )
            )
            if norm:
                layers.append(ChLayerNorm(depth))
            if act:
                layers.append(act())
            [m.apply(initializer) for m in layers[-3:]]
            h, w = h * 2, w * 2
            # cnnt_layers.append(norm([depth, h, w]))
            if act is not None:
                cnnt_layers.append(act())
        self._cnnt_layers = nn.Sequential(*cnnt_layers)

        self.layers = nn.Sequential(*layers)

    def calc_same_pad(self, k, s, d):
        val = d * (k - 1) - s + 1
        pad = math.ceil(val / 2)
        outpad = pad * 2 - val
        return pad, outpad

    def __call__(self, features, dtype=None):
        x = self._linear_layer(features)
        x = x.reshape([-1, 4, 4, self._embed_size // 16])
        x = x.permute(0, 3, 1, 2)
        x = self._cnnt_layers(x)
        x = self.layers(x)
        mean = x.reshape(features.shape[:-1] + self._shape)
        mean = mean.permute(0, 1, 3, 4, 2)
        return tools.ContDist(
            torchd.independent.Independent(
                torchd.normal.Normal(mean, 1), len(self._shape)
            )
        )
        return tools.SymlogDist(mean)
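A quick sanity check (not from the repo) that `calc_same_pad` makes each ConvTranspose2d exactly double the spatial size, for both the new 4x4 kernels and the old 3x3 ones:

```python
import math
import torch
from torch import nn

def calc_same_pad(k, s, d):
    val = d * (k - 1) - s + 1
    pad = math.ceil(val / 2)
    outpad = pad * 2 - val
    return pad, outpad

for k in (3, 4):
    pad, outpad = calc_same_pad(k=k, s=2, d=1)
    deconv = nn.ConvTranspose2d(8, 8, k, 2, padding=pad, output_padding=outpad)
    y = deconv(torch.zeros(1, 8, 4, 4))
    print(k, pad, outpad, tuple(y.shape))  # spatial dims double: 4 -> 8
```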


class DenseHead(nn.Module):
@@ -411,7 +438,7 @@ class DenseHead(nn.Module):
        norm=nn.LayerNorm,
        dist="normal",
        std=1.0,
        unimix_ratio=0.0,
        outscale=1.0,
    ):
        super(DenseHead, self).__init__()
        self._shape = (shape,) if isinstance(shape, int) else shape
@@ -423,27 +450,30 @@ class DenseHead(nn.Module):
        self._norm = norm
        self._dist = dist
        self._std = std
        self._unimix_ratio = unimix_ratio

        mean_layers = []
        layers = []
        for index in range(self._layers):
            mean_layers.append(nn.Linear(inp_dim, self._units))
            mean_layers.append(norm(self._units))
            mean_layers.append(act())
            layers.append(nn.Linear(inp_dim, self._units, bias=False))
            layers.append(norm(self._units, eps=1e-03))
            layers.append(act())
            if index == 0:
                inp_dim = self._units
        mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape)))
        self._mean_layers = nn.Sequential(*mean_layers)
        self.layers = nn.Sequential(*layers)
        self.layers.apply(tools.weight_init)

        self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
        self.mean_layer.apply(tools.uniform_weight_init(outscale))

        if self._std == "learned":
            self._std_layer = nn.Linear(self._units, np.prod(self._shape))
            self.std_layer = nn.Linear(self._units, np.prod(self._shape))
            self.std_layer.apply(tools.uniform_weight_init(outscale))

    def __call__(self, features, dtype=None):
        x = features
        mean = self._mean_layers(x)
        out = self.layers(x)
        mean = self.mean_layer(out)
        if self._std == "learned":
            std = self._std_layer(x)
            std = torch.softplus(std) + 0.01
            std = self.std_layer(out)
        else:
            std = self._std
        if self._dist == "normal":
@@ -464,8 +494,8 @@ class DenseHead(nn.Module):
                    torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
                )
            )
        if self._dist == "twohot":
            return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio)
        if self._dist == "twohot_symlog":
            return tools.TwoHotDistSymlog(logits=mean)
        raise NotImplementedError(self._dist)


@@ -481,9 +511,9 @@ class ActionHead(nn.Module):
        dist="trunc_normal",
        init_std=0.0,
        min_std=0.1,
        action_disc=5,
        max_std=1.0,
        temp=0.1,
        outscale=0,
        outscale=1.0,
    ):
        super(ActionHead, self).__init__()
        self._size = size
@@ -493,24 +523,27 @@ class ActionHead(nn.Module):
        self._act = act
        self._norm = norm
        self._min_std = min_std
        self._max_std = max_std
        self._init_std = init_std
        self._action_disc = action_disc
        self._temp = temp() if callable(temp) else temp
        self._outscale = outscale

        pre_layers = []
        for index in range(self._layers):
            pre_layers.append(nn.Linear(inp_dim, self._units))
            pre_layers.append(norm(self._units))
            pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
            pre_layers.append(norm(self._units, eps=1e-03))
            pre_layers.append(act())
            if index == 0:
                inp_dim = self._units
        self._pre_layers = nn.Sequential(*pre_layers)
        self._pre_layers.apply(tools.weight_init)

        if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
            self._dist_layer = nn.Linear(self._units, 2 * self._size)
            self._dist_layer.apply(tools.uniform_weight_init(outscale))

        elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
            self._dist_layer = nn.Linear(self._units, self._size)
            self._dist_layer.apply(tools.uniform_weight_init(outscale))

    def __call__(self, features, dtype=None):
        x = features
@@ -539,9 +572,11 @@ class ActionHead(nn.Module):
            dist = tools.SampleDist(dist)
        elif self._dist == "normal":
            x = self._dist_layer(x)
            mean, std = torch.split(x, 2, -1)
            std = F.softplus(std + self._init_std) + self._min_std
            dist = torchd.normal.Normal(mean, std)
            mean, std = torch.split(x, [self._size] * 2, -1)
            std = (self._max_std - self._min_std) * torch.sigmoid(
                std + 2.0
            ) + self._min_std
            dist = torchd.normal.Normal(torch.tanh(mean), std)
            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "normal_1":
            x = self._dist_layer(x)
@@ -574,9 +609,9 @@ class GRUCell(nn.Module):
        self._act = act
        self._norm = norm
        self._update_bias = update_bias
        self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None)
        self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
        if norm:
            self._norm = nn.LayerNorm(3 * size)
            self._norm = nn.LayerNorm(3 * size, eps=1e-03)

    @property
    def state_size(self):
@@ -625,8 +660,13 @@ class Conv2dSame(torch.nn.Conv2d):
        return ret


def calc_same_pad(k, s, d):
    val = d * (k - 1) - s + 1
    pad = math.ceil(val / 2)
    outpad = pad * 2 - val
    return pad, outpad
class ChLayerNorm(nn.Module):
    def __init__(self, ch, eps=1e-03):
        super(ChLayerNorm, self).__init__()
        self.norm = torch.nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x
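ChLayerNorm normalizes over the channel dimension of NCHW feature maps by permuting to channels-last; a small equivalence check against functional `layer_norm` (illustrative only, at default init where weight=1 and bias=0):

```python
import torch
from torch import nn

class ChLayerNorm(nn.Module):
    def __init__(self, ch, eps=1e-03):
        super().__init__()
        self.norm = nn.LayerNorm(ch, eps=eps)

    def forward(self, x):          # x: (N, C, H, W)
        x = x.permute(0, 2, 3, 1)  # to channels-last
        x = self.norm(x)           # normalize over C only
        return x.permute(0, 3, 1, 2)

x = torch.randn(2, 32, 8, 8)
ref = nn.functional.layer_norm(x.permute(0, 2, 3, 1), (32,), eps=1e-03).permute(0, 3, 1, 2)
assert torch.allclose(ChLayerNorm(32)(x), ref, atol=1e-6)
```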
152 tools.py
@@ -17,6 +17,14 @@ from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter


to_np = lambda x: x.detach().cpu().numpy()


def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)


def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)


class RequiresGrad:

    def __init__(self, model):
@@ -269,10 +277,12 @@ class SampleDist:
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        if logits is not None and probs is None and unimix_ratio > 0.0:
        if logits is not None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = None
            logits = torch.log(probs)
            super().__init__(logits=logits, probs=None)
        else:
            super().__init__(logits=logits, probs=probs)

    def mode(self):
@@ -290,42 +300,81 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
        return sample
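The unimix change above mixes the softmax with a uniform distribution so no category's probability can collapse to zero before logits are re-derived; a tiny numeric sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([10.0, 0.0, -10.0])
unimix_ratio = 0.01
probs = F.softmax(logits, dim=-1)
probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
print(probs, probs.min().item() >= unimix_ratio / 3)  # min prob floored near 0.0033
logits = torch.log(probs)  # re-derive logits, as OneHotDist now does
```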


class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical):
class TwoHotDistSymlog():

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'):
        if logits is not None and probs is None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = None
        super().__init__(logits=logits, probs=probs)

        self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device)
    def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255).to(device)
        self.width = (self.buckets[-1] - self.buckets[0]) / 255

    def mean(self):
        print("mean called")
        _mode = self.probs * self.buckets
        return symexp(torch.sum(_mode, dim=-1, keepdim=True))

    def mode(self):
        _mode = super().probs * self.buckets
        return torch.sum(_mode, dim=-1, keepdim=True)
        _mode = self.probs * self.buckets
        return symexp(torch.sum(_mode, dim=-1, keepdim=True))

    # inside OneHotCategorical, log_prob is computed using only the max element of the target
    def log_prob(self, x):
        x = symlog(x)
        # x (time, batch, 1)
        x = (x - self.buckets[0]) / self.width
        lower_indices = (x).to(torch.int64)
        # lower_indices is inside 0 ~ len(buckets)-2
        lower_indices = torch.clip(lower_indices, max=len(self.buckets) - 2)
        # upper_indices is inside 1 ~ len(buckets)-1
        upper_indices = lower_indices + 1
        lower_weight = torch.abs(x - upper_indices).squeeze(-1)
        upper_weight = torch.abs(x - lower_indices).squeeze(-1)
        # (time, batch, 1) -> (time, batch, bucket_class)
        lower_log_prob = super().log_prob(F.one_hot(lower_indices.squeeze(-1), num_classes=len(self.buckets)))
        upper_log_prob = super().log_prob(F.one_hot(upper_indices.squeeze(-1), num_classes=len(self.buckets)))
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
        below = torch.clip(below, 0, len(self.buckets) - 1)
        above = torch.clip(above, 0, len(self.buckets) - 1)
        equal = (below == above)

        # label = lower_log_prob * lower_weight + upper_log_prob * upper_weight
        # # (time, batch, bucket_class) -> (time, batch)
        # cross_entropy = torch.sum(torch.log(super().probs) * label, axis=-1)
        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
            F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        target = target.squeeze(-2)

        return lower_weight * lower_log_prob + upper_weight * upper_log_prob
        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)
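What the new `log_prob` builds as its cross-entropy target: the scalar is symlog-squashed and split across its two neighboring buckets, weighted by the distance to the opposite bucket, so the expectation decodes back exactly. A standalone sketch, simplified to a single scalar (mirroring the code above, not the repo's batched version):

```python
import torch
import torch.nn.functional as F

def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

buckets = torch.linspace(-20.0, 20.0, steps=255)

def twohot(y):
    x = symlog(torch.as_tensor(y))
    below = torch.clip((buckets <= x).sum() - 1, 0, 254)
    above = torch.clip(255 - (buckets > x).sum(), 0, 254)
    if below == above:  # x sits exactly on a bucket
        return F.one_hot(below, 255).float()
    d_below, d_above = x - buckets[below], buckets[above] - x
    total = d_below + d_above
    return (F.one_hot(below, 255) * (d_above / total)
            + F.one_hot(above, 255) * (d_below / total)).float()

target = twohot(5.0)
decoded = symexp((target * buckets).sum())
print(decoded)  # ~5.0, recovered from the two active buckets
```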


class SymlogDist():
    def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
        self._mode = mode
        self._dist = dist
        self._agg = agg
        self._tol = tol
        self._dim_to_reduce = dim_to_reduce

    def mode(self):
        return symexp(self._mode)

    def mean(self):
        return symexp(self._mode)

    def log_prob(self, value):
        assert self._mode.shape == value.shape
        if self._dist == 'mse':
            distance = (self._mode - symlog(value)) ** 2.0
            distance = torch.where(distance < self._tol, 0, distance)
        elif self._dist == 'abs':
            distance = torch.abs(self._mode - symlog(value))
            distance = torch.where(distance < self._tol, 0, distance)
        else:
            raise NotImplementedError(self._dist)
        if self._agg == 'mean':
            loss = distance.mean(self._dim_to_reduce)
        elif self._agg == 'sum':
            loss = distance.sum(self._dim_to_reduce)
        else:
            raise NotImplementedError(self._agg)
        return -loss


class ContDist:

@@ -438,6 +487,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
    indices = reversed(indices)
    flag = True
    for index in indices:
        # (inputs, pcont) -> (inputs[index], pcont[index])
        inp = lambda x: (_input[x] for _input in inputs)
        last = fn(last, *inp(index))
        if flag:
@@ -446,6 +496,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
        else:
            outputs = torch.cat([outputs, last], dim=-1)
    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
    outputs = torch.flip(outputs, [1])
    outputs = torch.unbind(outputs, dim=0)
    return outputs

@@ -687,14 +738,53 @@ def schedule(string, step):

def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        in_num = m.in_features
        out_num = m.out_features
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = np.sqrt(scale) / 0.87962566103423978
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, gain)
        space = m.kernel_size[0] * m.kernel_size[1]
        in_num = space * m.in_channels
        out_num = space * m.out_channels
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = np.sqrt(scale) / 0.87962566103423978
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.LayerNorm):
        m.weight.data.fill_(1.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
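The constant 0.87962566103423978 above is the standard deviation of a unit normal truncated to [-2, 2]; dividing by it gives the truncated draw the intended variance-scaling std. A quick check, assuming scipy is available:

```python
import numpy as np
from scipy.stats import truncnorm

# std of N(0, 1) truncated to [-2, 2]
print(truncnorm(-2, 2).std())  # ~0.8796256610342398

# so trunc_normal_(std=np.sqrt(scale) / 0.8796...) yields samples with std ~ sqrt(scale)
fan_avg = 512.0  # (in_features + out_features) / 2, for example
std = np.sqrt(1.0 / fan_avg) / 0.87962566103423978
```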

def uniform_weight_init(given_scale):
    def f(m):
        if isinstance(m, nn.Linear):
            in_num = m.in_features
            out_num = m.out_features
            denoms = (in_num + out_num) / 2.0
            scale = given_scale / denoms
            limit = np.sqrt(3 * scale)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)
        elif isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1.0)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)
    return f


def tensorstats(tensor, prefix=None):
    metrics = {
        'mean': to_np(torch.mean(tensor)),
        'std': to_np(torch.std(tensor)),
        'min': to_np(torch.min(tensor)),
        'max': to_np(torch.max(tensor)),
    }
    if prefix:
        metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
    return metrics