From e0487f82062cd080f8123d428fbace22286c42ec Mon Sep 17 00:00:00 2001 From: NM512 Date: Fri, 5 Jan 2024 10:26:48 +0900 Subject: [PATCH] merged action head into MLP and modified configs --- configs.yaml | 63 ++++++++------------- dreamer.py | 4 +- exploration.py | 2 +- models.py | 145 ++++++++++++++++++++++------------------------- networks.py | 150 ++++++++++++++----------------------------------- 5 files changed, 133 insertions(+), 231 deletions(-) diff --git a/configs.yaml b/configs.yaml index 700712e..9879ac6 100644 --- a/configs.yaml +++ b/configs.yaml @@ -47,26 +47,25 @@ defaults: dyn_temp_post: True grad_heads: ['decoder', 'reward', 'cont'] units: 512 - reward_layers: 2 - cont_layers: 2 - value_layers: 2 - actor_layers: 2 act: 'SiLU' norm: 'LayerNorm' encoder: - {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True} + {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True} decoder: - {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse} - value_head: 'symlog_disc' - reward_head: 'symlog_disc' + {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0} + actor: + {layers: 2, dist: 'normal', entropy: 3e-4, unimix_ratio: 0.01, min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0} + critic: + {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0} + reward_head: + {layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0} + cont_head: + {layers: 2, scale: 1.0, outscale: 1.0} dyn_scale: 0.5 rep_scale: 0.1 kl_free: 1.0 - cont_scale: 1.0 - reward_scale: 1.0 weight_decay: 0.0 unimix_ratio: 0.01 - action_unimix_ratio: 0.01 initial: 'learned' # Training @@ -77,15 +76,7 @@ defaults: model_lr: 1e-4 opt_eps: 1e-8 grad_clip: 1000 - value_lr: 3e-5 - actor_lr: 3e-5 - ac_opt_eps: 1e-5 - value_grad_clip: 100 - actor_grad_clip: 100 dataset_size: 1000000 - slow_value_target: True - slow_target_update: 1 - slow_target_fraction: 0.02 opt: 'adam' # Behavior. @@ -95,18 +86,10 @@ defaults: imag_gradient: 'dynamics' imag_gradient_mix: 0.0 imag_sample: True - actor_dist: 'normal' - actor_entropy: 3e-4 - actor_state_entropy: 0.0 - actor_init_std: 1.0 - actor_min_std: 0.1 - actor_max_std: 1.0 - actor_temp: 0.1 - expl_amount: 0.0 + expl_amount: 0 eval_state_mean: False collect_dyn_sample: True behavior_stop_grad: True - value_decay: 0.0 future_entropy: False # Exploration @@ -150,13 +133,12 @@ crafter: dyn_hidden: 1024 dyn_deter: 4096 units: 1024 - reward_layers: 5 - cont_layers: 5 - value_layers: 5 - actor_layers: 5 encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024} decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024} - actor_dist: 'onehot' + actor: {layers: 5, dist: 'onehot'} + value: {layers: 5} + reward_head: {layers: 5} + cont_head: {layers: 5} imag_gradient: 'reinforce' atari100k: @@ -166,7 +148,7 @@ atari100k: train_ratio: 1024 video_pred_log: true eval_episode_num: 100 - actor_dist: 'onehot' + actor: {dist: 'onehot'} imag_gradient: 'reinforce' stickey: False lives: unused @@ -189,13 +171,12 @@ minecraft: dyn_hidden: 1024 dyn_deter: 4096 units: 1024 - reward_layers: 5 - cont_layers: 5 - value_layers: 5 - actor_layers: 5 - encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024} + encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024} decoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024} - actor_dist: 'onehot' + actor: {layers: 5, dist: 'onehot'} + value: {layers: 5} + reward_head: {layers: 5} + cont_head: {layers: 5} imag_gradient: 'reinforce' break_speed: 100.0 time_limit: 36000 @@ -203,7 +184,7 @@ minecraft: memorymaze: steps: 1e8 action_repeat: 2 - actor_dist: 'onehot' + actor: {dist: 'onehot'} imag_gradient: 'reinforce' task: 'memorymaze_9x9' diff --git a/dreamer.py b/dreamer.py index 218f6fd..3c70d88 100644 --- a/dreamer.py +++ b/dreamer.py @@ -110,7 +110,7 @@ class Dreamer(nn.Module): logprob = actor.log_prob(action) latent = {k: v.detach() for k, v in latent.items()} action = action.detach() - if self._config.actor_dist == "onehot_gumble": + if self._config.actor["dist"] == "onehot_gumble": action = torch.one_hot( torch.argmax(action, dim=-1), self._config.num_actions ) @@ -123,7 +123,7 @@ class Dreamer(nn.Module): amount = self._config.expl_amount if training else self._config.eval_noise if amount == 0: return action - if "onehot" in self._config.actor_dist: + if "onehot" in self._config.actor["dist"]: probs = amount / self._config.num_actions + (1 - amount) * action return tools.OneHotDist(probs=probs).sample() else: diff --git a/exploration.py b/exploration.py index 4938fba..1121c10 100644 --- a/exploration.py +++ b/exploration.py @@ -14,7 +14,7 @@ class Random(nn.Module): self._act_space = act_space def actor(self, feat): - if self._config.actor_dist == "onehot": + if self._config.actor["dist"] == "onehot": return tools.OneHotDist( torch.zeros(self._config.num_actions) .repeat(self._config.envs, 1) diff --git a/models.py b/models.py index 417fa6a..0ab27f1 100644 --- a/models.py +++ b/models.py @@ -67,39 +67,29 @@ class WorldModel(nn.Module): self.heads["decoder"] = networks.MultiDecoder( feat_size, shapes, **config.decoder ) - if config.reward_head == "symlog_disc": - self.heads["reward"] = networks.MLP( - feat_size, # pytorch version - (255,), - config.reward_layers, - config.units, - config.act, - config.norm, - dist=config.reward_head, - outscale=0.0, - device=config.device, - ) - else: - self.heads["reward"] = networks.MLP( - feat_size, # pytorch version - [], - config.reward_layers, - config.units, - config.act, - config.norm, - dist=config.reward_head, - outscale=0.0, - device=config.device, - ) + self.heads["reward"] = networks.MLP( + feat_size, + (255,) if config.reward_head["dist"] == "symlog_disc" else (), + config.reward_head["layers"], + config.units, + config.act, + config.norm, + dist=config.reward_head["dist"], + outscale=config.reward_head["outscale"], + device=config.device, + name="Reward", + ) self.heads["cont"] = networks.MLP( - feat_size, # pytorch version - [], - config.cont_layers, + feat_size, + (), + config.cont_head["layers"], config.units, config.act, config.norm, dist="binary", + outscale=config.cont_head["outscale"], device=config.device, + name="Cont", ) for name in config.grad_heads: assert name in self.heads, name @@ -113,7 +103,14 @@ class WorldModel(nn.Module): opt=config.opt, use_amp=self._use_amp, ) - self._scales = dict(reward=config.reward_scale, cont=config.cont_scale) + print( + f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables." + ) + self._scales = dict( + reward=config.reward_head["scale"], + cont=config.cont_head["scale"], + image=1.0, + ) def _train(self, data): # action (batch_size, batch_length, act_dim) @@ -134,6 +131,7 @@ class WorldModel(nn.Module): kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss( post, prior, kl_free, dyn_scale, rep_scale ) + assert kl_loss.shape == embed.shape[:2], kl_loss.shape preds = {} for name, head in self.heads.items(): grad_head = name in self._config.grad_heads @@ -226,65 +224,60 @@ class ImagBehavior(nn.Module): feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter else: feat_size = config.dyn_stoch + config.dyn_deter - self.actor = networks.ActionHead( + self.actor = networks.MLP( feat_size, - config.num_actions, - config.actor_layers, + (config.num_actions,), + config.actor["layers"], config.units, config.act, config.norm, - config.actor_dist, - config.actor_init_std, - config.actor_min_std, - config.actor_max_std, - config.actor_temp, + config.actor["dist"], + "learned", + config.actor["min_std"], + config.actor["max_std"], + config.actor["temp"], + unimix_ratio=config.actor["unimix_ratio"], outscale=1.0, - unimix_ratio=config.action_unimix_ratio, + name="Actor", ) - if config.value_head == "symlog_disc": - self.value = networks.MLP( - feat_size, - (255,), - config.value_layers, - config.units, - config.act, - config.norm, - config.value_head, - outscale=0.0, - device=config.device, - ) - else: - self.value = networks.MLP( - feat_size, - [], - config.value_layers, - config.units, - config.act, - config.norm, - config.value_head, - outscale=0.0, - device=config.device, - ) - if config.slow_value_target: + self.value = networks.MLP( + feat_size, + (255,) if config.critic["dist"] == "symlog_disc" else (), + config.critic["layers"], + config.units, + config.act, + config.norm, + config.critic["dist"], + outscale=config.critic["outscale"], + device=config.device, + name="Value", + ) + if config.critic["slow_target"]: self._slow_value = copy.deepcopy(self.value) self._updates = 0 kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) self._actor_opt = tools.Optimizer( "actor", self.actor.parameters(), - config.actor_lr, - config.ac_opt_eps, - config.actor_grad_clip, + config.actor["lr"], + config.actor["eps"], + config.actor["grad_clip"], **kw, ) + print( + f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables." + ) self._value_opt = tools.Optimizer( "value", self.value.parameters(), - config.value_lr, - config.ac_opt_eps, - config.value_grad_clip, + config.critic["lr"], + config.critic["eps"], + config.critic["grad_clip"], **kw, ) + print( + f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables." + ) if self._config.reward_EMA: self.reward_ema = RewardEMA(device=self._config.device) @@ -335,19 +328,15 @@ class ImagBehavior(nn.Module): # (time, batch, 1), (time, batch, 1) -> (time, batch) value_loss = -value.log_prob(target.detach()) slow_target = self._slow_value(value_input[:-1].detach()) - if self._config.slow_value_target: - value_loss = value_loss - value.log_prob( - slow_target.mode().detach() - ) - if self._config.value_decay: - value_loss += self._config.value_decay * value.mode() + if self._config.critic["slow_target"]: + value_loss -= value.log_prob(slow_target.mode().detach()) # (time, batch, 1), (time, batch, 1) -> (1,) value_loss = torch.mean(weights[:-1] * value_loss[:, :, None]) metrics.update(tools.tensorstats(value.mode(), "value")) metrics.update(tools.tensorstats(target, "target")) metrics.update(tools.tensorstats(reward, "imag_reward")) - if self._config.actor_dist in ["onehot"]: + if self._config.actor["dist"] in ["onehot"]: metrics.update( tools.tensorstats( torch.argmax(imag_action, dim=-1).float(), "imag_action" @@ -466,9 +455,9 @@ class ImagBehavior(nn.Module): return actor_loss, metrics def _update_slow_target(self): - if self._config.slow_value_target: - if self._updates % self._config.slow_target_update == 0: - mix = self._config.slow_target_fraction + if self._config.critic["slow_target"]: + if self._updates % self._config.critic["slow_target_update"] == 0: + mix = self._config.critic["slow_target_fraction"] for s, d in zip(self.value.parameters(), self._slow_value.parameters()): d.data = mix * s.data + (1 - mix) * d.data self._updates += 1 diff --git a/networks.py b/networks.py index 38769b7..5118ada 100644 --- a/networks.py +++ b/networks.py @@ -632,9 +632,14 @@ class MLP(nn.Module): norm="LayerNorm", dist="normal", std=1.0, + min_std=0.1, + max_std=1.0, + temp=0.1, + unimix_ratio=0.01, outscale=1.0, symlog_inputs=False, device="cuda", + name="NoName", ): super(MLP, self).__init__() self._shape = (shape,) if isinstance(shape, int) else shape @@ -647,15 +652,20 @@ class MLP(nn.Module): self._std = std self._symlog_inputs = symlog_inputs self._device = device + self._min_std = min_std + self._max_std = max_std + self._temp = temp + self._unimix_ratio = unimix_ratio - layers = [] + self.layers = nn.Sequential() for index in range(self._layers): - layers.append(nn.Linear(inp_dim, units, bias=False)) - layers.append(norm(units, eps=1e-03)) - layers.append(act()) + self.layers.add_module( + f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False) + ) + self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03)) + self.layers.add_module(f"{name}_act{index}", act()) if index == 0: inp_dim = units - self.layers = nn.Sequential(*layers) self.layers.apply(tools.weight_init) if isinstance(self._shape, dict): @@ -664,6 +674,7 @@ class MLP(nn.Module): self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape)) self.mean_layer.apply(tools.uniform_weight_init(outscale)) if self._std == "learned": + assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist self.std_layer = nn.ModuleDict() for name, shape in self._shape.items(): self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape)) @@ -672,6 +683,7 @@ class MLP(nn.Module): self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape)) self.mean_layer.apply(tools.uniform_weight_init(outscale)) if self._std == "learned": + assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist self.std_layer = nn.Linear(units, np.prod(self._shape)) self.std_layer.apply(tools.uniform_weight_init(outscale)) @@ -680,6 +692,7 @@ class MLP(nn.Module): if self._symlog_inputs: x = tools.symlog(x) out = self.layers(x) + # Used for encoder output if self._shape is None: return out if isinstance(self._shape, dict): @@ -701,98 +714,9 @@ class MLP(nn.Module): return self.dist(self._dist, mean, std, self._shape) def dist(self, dist, mean, std, shape): - if dist == "normal": - return tools.ContDist( - torchd.independent.Independent( - torchd.normal.Normal(mean, std), len(shape) - ) - ) - if dist == "huber": - return tools.ContDist( - torchd.independent.Independent( - tools.UnnormalizedHuber(mean, std, 1.0), len(shape) - ) - ) - if dist == "binary": - return tools.Bernoulli( - torchd.independent.Independent( - torchd.bernoulli.Bernoulli(logits=mean), len(shape) - ) - ) - if dist == "symlog_disc": - return tools.DiscDist(logits=mean, device=self._device) - if dist == "symlog_mse": - return tools.SymlogDist(mean) - raise NotImplementedError(dist) - - -class ActionHead(nn.Module): - def __init__( - self, - inp_dim, - size, - layers, - units, - act=nn.ELU, - norm=nn.LayerNorm, - dist="trunc_normal", - init_std=0.0, - min_std=0.1, - max_std=1.0, - temp=0.1, - outscale=1.0, - unimix_ratio=0.01, - ): - super(ActionHead, self).__init__() - self._size = size - self._layers = layers - self._units = units - self._dist = dist - act = getattr(torch.nn, act) - norm = getattr(torch.nn, norm) - self._min_std = min_std - self._max_std = max_std - self._init_std = init_std - self._unimix_ratio = unimix_ratio - self._temp = temp() if callable(temp) else temp - - pre_layers = [] - for index in range(self._layers): - pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) - pre_layers.append(norm(self._units, eps=1e-03)) - pre_layers.append(act()) - if index == 0: - inp_dim = self._units - self._pre_layers = nn.Sequential(*pre_layers) - self._pre_layers.apply(tools.weight_init) - - if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: - self._dist_layer = nn.Linear(self._units, 2 * self._size) - self._dist_layer.apply(tools.uniform_weight_init(outscale)) - - elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: - self._dist_layer = nn.Linear(self._units, self._size) - self._dist_layer.apply(tools.uniform_weight_init(outscale)) - - def forward(self, features, dtype=None): - x = features - x = self._pre_layers(x) if self._dist == "tanh_normal": - x = self._dist_layer(x) - mean, std = torch.split(x, 2, -1) mean = torch.tanh(mean) - std = F.softplus(std + self._init_std) + self._min_std - dist = torchd.normal.Normal(mean, std) - dist = torchd.transformed_distribution.TransformedDistribution( - dist, tools.TanhBijector() - ) - dist = torchd.independent.Independent(dist, 1) - dist = tools.SampleDist(dist) - elif self._dist == "tanh_normal_5": - x = self._dist_layer(x) - mean, std = torch.split(x, 2, -1) - mean = 5 * torch.tanh(mean / 5) - std = F.softplus(std + 5) + 5 + std = F.softplus(std) + self._min_std dist = torchd.normal.Normal(mean, std) dist = torchd.transformed_distribution.TransformedDistribution( dist, tools.TanhBijector() @@ -800,33 +724,41 @@ class ActionHead(nn.Module): dist = torchd.independent.Independent(dist, 1) dist = tools.SampleDist(dist) elif self._dist == "normal": - x = self._dist_layer(x) - mean, std = torch.split(x, [self._size] * 2, -1) std = (self._max_std - self._min_std) * torch.sigmoid( std + 2.0 ) + self._min_std dist = torchd.normal.Normal(torch.tanh(mean), std) - dist = tools.ContDist(torchd.independent.Independent(dist, 1)) - elif self._dist == "normal_1": - mean = self._dist_layer(x) - dist = torchd.normal.Normal(mean, 1) + dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0) + elif self._dist == "normal_std_fixed": + dist = torchd.normal.Normal(mean, self._std) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "trunc_normal": - x = self._dist_layer(x) - mean, std = torch.split(x, [self._size] * 2, -1) mean = torch.tanh(mean) std = 2 * torch.sigmoid(std / 2) + self._min_std dist = tools.SafeTruncatedNormal(mean, std, -1, 1) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "onehot": - x = self._dist_layer(x) - dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio) + dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio) elif self._dist == "onehot_gumble": - x = self._dist_layer(x) - temp = self._temp - dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) + dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp)) + elif dist == "huber": + dist = tools.ContDist( + torchd.independent.Independent( + tools.UnnormalizedHuber(mean, std, 1.0), len(shape) + ) + ) + elif dist == "binary": + dist = tools.Bernoulli( + torchd.independent.Independent( + torchd.bernoulli.Bernoulli(logits=mean), len(shape) + ) + ) + elif dist == "symlog_disc": + dist = tools.DiscDist(logits=mean, device=self._device) + elif dist == "symlog_mse": + dist = tools.SymlogDist(mean) else: - raise NotImplementedError(self._dist) + raise NotImplementedError(dist) return dist