added state input capability

NM512 committed on 2023-05-14 23:38:46 +09:00
parent 3ebb8ad617
commit b984e69b6e
8 changed files with 369 additions and 142 deletions

View File

@@ -7,20 +7,24 @@ Get dependencies:
 ```
 pip install -r requirements.txt
 ```
-Train the agent on Walker Walk in Vision DMC:
+Train the agent on Walker Walk in DMC Vision:
 ```
-python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk
+python3 dreamer.py --configs dmc_vision --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk
+```
+Train the agent on Walker Walk in DMC Proprio:
+```
+python3 dreamer.py --configs dmc_proprio --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk
 ```
 Train the agent on Alien in Atari 100K:
 ```
-python3 dreamer.py --configs defaults atari100k --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien
+python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
 ```
 Monitor results:
 ```
 tensorboard --logdir ~/dreamerv3-torch/logdir
 ```
-## Evaluation Results
+## Results
 More results will be added in the future.
 ![dmc_vision](https://user-images.githubusercontent.com/70328564/236276650-ae706f29-4c14-4ed3-9b61-1829a1fdedae.png)
@@ -30,7 +34,7 @@ More results will be added in the future.
 - [x] Modify implementation details based on the author's implementation
 - [x] Evaluate on DMC vision
 - [x] Evaluate on Atari 100K
-- [ ] Add state input capability
+- [x] Add state input capability
 - [ ] Evaluate on DMC Proprio
 - [ ] etc.

View File

@@ -1,4 +1,3 @@
-# defaults is for Vision DMC
 defaults:
   logdir: null
@@ -17,6 +16,7 @@ defaults:
   precision: 16
   debug: False
   expl_gifs: False
+  video_pred_log: True

   # Environment
   task: 'dmc_walker_walk'
@@ -43,7 +43,7 @@ defaults:
   dyn_std_act: 'sigmoid2'
   dyn_min_std: 0.1
   dyn_temp_post: True
-  grad_heads: ['image', 'reward', 'cont']
+  grad_heads: ['decoder', 'reward', 'cont']
   units: 512
   reward_layers: 2
   cont_layers: 2
@@ -51,11 +51,12 @@ defaults:
   actor_layers: 2
   act: 'SiLU'
   norm: 'LayerNorm'
-  cnn_depth: 32
-  encoder_kernels: [4, 4, 4, 4]
-  decoder_kernels: [4, 4, 4, 4]
-  value_head: 'twohot_symlog'
-  reward_head: 'twohot_symlog'
+  encoder:
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+  decoder:
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
+  value_head: 'symlog_disc'
+  reward_head: 'symlog_disc'
   dyn_scale: '0.5'
   rep_scale: '0.1'
   kl_free: '1.0'
@@ -119,6 +120,20 @@ defaults:
   disag_units: 400
   disag_action_cond: False

+dmc_vision:
+  steps: 1e6
+  train_ratio: 512
+  video_pred_log: true
+  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
+
+dmc_proprio:
+  steps: 5e5
+  train_ratio: 512
+  video_pred_log: false
+  encoder: {mlp_keys: '.*', cnn_keys: '$^'}
+  decoder: {mlp_keys: '.*', cnn_keys: '$^'}
+
 atari100k:
   steps: 4e5
   action_repeat: 4
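
The new `encoder` and `decoder` options route observation keys by regular expression: keys with 3-D shapes that match `cnn_keys` go to the CNN branch, keys with 1-D or 2-D shapes that match `mlp_keys` go to the MLP branch, `'$^'` matches nothing, and `'.*'` matches everything. A minimal sketch of that routing, with made-up observation shapes:

```
import re

# Hypothetical observation shapes, as WorldModel derives them from the
# environment's observation space; the key names are illustrative.
shapes = {"image": (64, 64, 3), "position": (8,), "velocity": (9,)}

def select_keys(shapes, mlp_keys, cnn_keys):
    # Mirrors MultiEncoder/MultiDecoder: 3-D shapes matching cnn_keys go to
    # the CNN, 1-D/2-D shapes matching mlp_keys go to the MLP.
    cnn = {k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)}
    mlp = {k: v for k, v in shapes.items() if len(v) in (1, 2) and re.match(mlp_keys, k)}
    return cnn, mlp

print(select_keys(shapes, mlp_keys="$^", cnn_keys="image"))  # dmc_vision: image only
print(select_keys(shapes, mlp_keys=".*", cnn_keys="$^"))     # dmc_proprio: vectors only
```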

View File

@@ -27,7 +27,7 @@ to_np = lambda x: x.detach().cpu().numpy()

 class Dreamer(nn.Module):
-    def __init__(self, config, logger, dataset):
+    def __init__(self, obs_space, act_space, config, logger, dataset):
         super(Dreamer, self).__init__()
         self._config = config
         self._logger = logger
@@ -51,7 +51,7 @@ class Dreamer(nn.Module):
             x, self._step
         )
         self._dataset = dataset
-        self._wm = models.WorldModel(self._step, config)
+        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
@@ -90,8 +90,9 @@ class Dreamer(nn.Module):
                 for name, values in self._metrics.items():
                     self._logger.scalar(name, float(np.mean(values)))
                     self._metrics[name] = []
-                openl = self._wm.video_pred(next(self._dataset))
-                self._logger.video("train_openl", to_np(openl))
+                if self._config.video_pred_log:
+                    openl = self._wm.video_pred(next(self._dataset))
+                    self._logger.video("train_openl", to_np(openl))
                 self._logger.write(fps=True)

         policy_output, state = self._policy(obs, state, training)
@@ -296,8 +297,6 @@ def main(config):
     config.eval_every //= config.action_repeat
     config.log_every //= config.action_repeat
     config.time_limit //= config.action_repeat
-    config.act = getattr(torch.nn, config.act)
-    config.norm = getattr(torch.nn, config.norm)

     print("Logdir", logdir)
     logdir.mkdir(parents=True, exist_ok=True)
@@ -350,7 +349,13 @@ def main(config):
     print("Simulate agent.")
     train_dataset = make_dataset(train_eps, config)
     eval_dataset = make_dataset(eval_eps, config)
-    agent = Dreamer(config, logger, train_dataset).to(config.device)
+    agent = Dreamer(
+        train_envs[0].observation_space,
+        train_envs[0].action_space,
+        config,
+        logger,
+        train_dataset,
+    ).to(config.device)
     agent.requires_grad_(requires_grad=False)
     if (logdir / "latest_model.pt").exists():
         agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
@@ -362,8 +367,9 @@ def main(config):
         print("Start evaluation.")
         eval_policy = functools.partial(agent, training=False)
         tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
-        video_pred = agent._wm.video_pred(next(eval_dataset))
-        logger.video("eval_openl", to_np(video_pred))
+        if config.video_pred_log:
+            video_pred = agent._wm.video_pred(next(eval_dataset))
+            logger.video("eval_openl", to_np(video_pred))
         print("Start training.")
         state = tools.simulate(agent, train_envs, config.eval_every, state=state)
         torch.save(agent.state_dict(), logdir / "latest_model.pt")
@@ -376,14 +382,23 @@ def main(config):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--configs", nargs="+", required=True)
+    parser.add_argument("--configs", nargs="+")
     args, remaining = parser.parse_known_args()
     configs = yaml.safe_load(
         (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
     )
+
+    def recursive_update(base, update):
+        for key, value in update.items():
+            if isinstance(value, dict) and key in base:
+                recursive_update(base[key], value)
+            else:
+                base[key] = value
+
+    name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
     defaults = {}
-    for name in args.configs:
-        defaults.update(configs[name])
+    for name in name_list:
+        recursive_update(defaults, configs[name])
     parser = argparse.ArgumentParser()
     for key, value in sorted(defaults.items(), key=lambda x: x[0]):
         arg_type = tools.args_type(value)
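
Plain `dict.update` would replace the whole nested `encoder` dict from a named config and silently drop the remaining defaults; `recursive_update` merges one level at a time. A quick check of that behavior with a trimmed-down config:

```
def recursive_update(base, update):
    # Same helper as in dreamer.py: descend into nested dicts, overwrite leaves.
    for key, value in update.items():
        if isinstance(value, dict) and key in base:
            recursive_update(base[key], value)
        else:
            base[key] = value

defaults = {"encoder": {"mlp_keys": "$^", "cnn_keys": "image", "cnn_depth": 32}}
override = {"encoder": {"mlp_keys": ".*", "cnn_keys": "$^"}}  # e.g. dmc_proprio
recursive_update(defaults, override)
print(defaults)
# {'encoder': {'mlp_keys': '.*', 'cnn_keys': '$^', 'cnn_depth': 32}}
```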

View File

@@ -24,7 +24,11 @@ class DeepMindControl:
     def observation_space(self):
         spaces = {}
         for key, value in self._env.observation_spec().items():
-            spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
+            if len(value.shape) == 0:
+                shape = (1,)
+            else:
+                shape = value.shape
+            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
         spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
         return gym.spaces.Dict(spaces)
@@ -42,6 +46,7 @@ class DeepMindControl:
             if time_step.last():
                 break
         obs = dict(time_step.observation)
+        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
         obs["image"] = self.render()
         # There is no terminal state in DMC
         obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
@@ -53,6 +58,7 @@ class DeepMindControl:
     def reset(self):
         time_step = self._env.reset()
         obs = dict(time_step.observation)
+        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
         obs["image"] = self.render()
         obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
         obs["is_first"] = time_step.first()

View File

@@ -52,7 +52,7 @@ class Plan2Explore(nn.Module):
             act=config.act,
         )
         self._networks = nn.ModuleList(
-            [networks.DenseHead(**kw) for _ in range(config.disag_models)]
+            [networks.MLP(**kw) for _ in range(config.disag_models)]
         )
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(

View File

@@ -29,26 +29,14 @@ class RewardEMA(object):

 class WorldModel(nn.Module):
-    def __init__(self, step, config):
+    def __init__(self, obs_space, act_space, step, config):
         super(WorldModel, self).__init__()
         self._step = step
         self._use_amp = True if config.precision == 16 else False
         self._config = config
-        self.encoder = networks.ConvEncoder(
-            config.grayscale,
-            config.cnn_depth,
-            config.act,
-            config.norm,
-            config.encoder_kernels,
-        )
-        if config.size[0] == 64 and config.size[1] == 64:
-            embed_size = (
-                (64 // 2 ** (len(config.encoder_kernels))) ** 2
-                * config.cnn_depth
-                * 2 ** (len(config.encoder_kernels) - 1)
-            )
-        else:
-            raise NotImplemented(f"{config.size} is not applicable now")
+        shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
+        self.encoder = networks.MultiEncoder(shapes, **config.encoder)
+        embed_size = self.encoder.outdim
         self.dynamics = networks.RSSM(
             config.dyn_stoch,
             config.dyn_deter,
@@ -72,22 +60,15 @@ class WorldModel(nn.Module):
             config.device,
         )
         self.heads = nn.ModuleDict()
-        channels = 1 if config.grayscale else 3
-        shape = (channels,) + config.size
         if config.dyn_discrete:
             feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
         else:
             feat_size = config.dyn_stoch + config.dyn_deter
-        self.heads["image"] = networks.ConvDecoder(
-            feat_size,  # pytorch version
-            config.cnn_depth,
-            config.act,
-            config.norm,
-            shape,
-            config.decoder_kernels,
+        self.heads["decoder"] = networks.MultiDecoder(
+            feat_size, shapes, **config.decoder
         )
-        if config.reward_head == "twohot_symlog":
-            self.heads["reward"] = networks.DenseHead(
+        if config.reward_head == "symlog_disc":
+            self.heads["reward"] = networks.MLP(
                 feat_size,  # pytorch version
                 (255,),
                 config.reward_layers,
@@ -99,7 +80,7 @@ class WorldModel(nn.Module):
                 device=config.device,
             )
         else:
-            self.heads["reward"] = networks.DenseHead(
+            self.heads["reward"] = networks.MLP(
                 feat_size,  # pytorch version
                 [],
                 config.reward_layers,
@@ -110,7 +91,7 @@ class WorldModel(nn.Module):
                 outscale=0.0,
                 device=config.device,
             )
-        self.heads["cont"] = networks.DenseHead(
+        self.heads["cont"] = networks.MLP(
             feat_size,  # pytorch version
             [],
             config.cont_layers,
@@ -153,15 +134,19 @@ class WorldModel(nn.Module):
                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                     post, prior, kl_free, dyn_scale, rep_scale
                 )
-                losses = {}
-                likes = {}
+                preds = {}
                 for name, head in self.heads.items():
                     grad_head = name in self._config.grad_heads
                     feat = self.dynamics.get_feat(post)
                     feat = feat if grad_head else feat.detach()
                     pred = head(feat)
+                    if type(pred) is dict:
+                        preds.update(pred)
+                    else:
+                        preds[name] = pred
+                losses = {}
+                for name, pred in preds.items():
                     like = pred.log_prob(data[name])
-                    likes[name] = like
                     losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
                 model_loss = sum(losses.values()) + kl_loss
             metrics = self._model_opt(model_loss, self.parameters())
@@ -213,11 +198,13 @@ class WorldModel(nn.Module):
         states, _ = self.dynamics.observe(
             embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
         )
-        recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
+        recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
+            :6
+        ]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
         prior = self.dynamics.imagine(data["action"][:6, 5:], init)
-        openl = self.heads["image"](self.dynamics.get_feat(prior)).mode()
+        openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
         model = torch.cat([recon[:, :5], openl], 1)
@@ -254,9 +241,9 @@ class ImagBehavior(nn.Module):
             config.actor_temp,
             outscale=1.0,
             unimix_ratio=config.action_unimix_ratio,
-        )  # action_dist -> action_disc?
-        if config.value_head == "twohot_symlog":
-            self.value = networks.DenseHead(
+        )
+        if config.value_head == "symlog_disc":
+            self.value = networks.MLP(
                 feat_size,  # pytorch version
                 (255,),
                 config.value_layers,
@@ -268,7 +255,7 @@ class ImagBehavior(nn.Module):
                 device=config.device,
             )
         else:
-            self.value = networks.DenseHead(
+            self.value = networks.MLP(
                 feat_size,  # pytorch version
                 [],
                 config.value_layers,
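
With `MultiDecoder` in place, a head can now return either a single distribution (reward, cont) or a dict of per-key distributions (decoder), and the training loop flattens both into one `preds` dict so that every data key gets its own loss term. A sketch of that bookkeeping with mock distributions (names and values are illustrative):

```
import torch

class MockDist:
    # Stands in for the distribution objects; only log_prob matters here.
    def __init__(self, like):
        self._like = like
    def log_prob(self, target):
        return self._like

heads_out = {
    "decoder": {"image": MockDist(torch.tensor(-1.0)),
                "position": MockDist(torch.tensor(-0.5))},
    "reward": MockDist(torch.tensor(-0.1)),
}
preds = {}
for name, pred in heads_out.items():
    if type(pred) is dict:
        preds.update(pred)
    else:
        preds[name] = pred
losses = {name: -torch.mean(pred.log_prob(None)) for name, pred in preds.items()}
print(sorted(losses))  # ['image', 'position', 'reward'] — one loss per data key
```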

View File

@@ -1,5 +1,6 @@
 import math
 import numpy as np
+import re

 import torch
 from torch import nn
@@ -20,8 +21,8 @@ class RSSM(nn.Module):
         rec_depth=1,
         shared=False,
         discrete=False,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
+        act="SiLU",
+        norm="LayerNorm",
         mean_act="none",
         std_act="softplus",
         temp_post=True,
@@ -43,8 +44,8 @@
         self._rec_depth = rec_depth
         self._shared = shared
         self._discrete = discrete
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._mean_act = mean_act
         self._std_act = std_act
         self._temp_post = temp_post
@@ -62,8 +63,8 @@
             inp_dim += self._embed
         for i in range(self._layers_input):
             inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            inp_layers.append(self._norm(self._hidden, eps=1e-03))
-            inp_layers.append(self._act())
+            inp_layers.append(norm(self._hidden, eps=1e-03))
+            inp_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._inp_layers = nn.Sequential(*inp_layers)
@@ -82,8 +83,8 @@
             inp_dim = self._deter
         for i in range(self._layers_output):
             img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            img_out_layers.append(self._norm(self._hidden, eps=1e-03))
-            img_out_layers.append(self._act())
+            img_out_layers.append(norm(self._hidden, eps=1e-03))
+            img_out_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._img_out_layers = nn.Sequential(*img_out_layers)
@@ -96,8 +97,8 @@
             inp_dim = self._embed
         for i in range(self._layers_output):
             obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
-            obs_out_layers.append(self._act())
+            obs_out_layers.append(norm(self._hidden, eps=1e-03))
+            obs_out_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._obs_out_layers = nn.Sequential(*obs_out_layers)
@@ -327,28 +328,156 @@ class RSSM(nn.Module):
         return loss, value, dyn_loss, rep_loss


-class ConvEncoder(nn.Module):
+class MultiEncoder(nn.Module):
     def __init__(
         self,
-        grayscale=False,
-        depth=32,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
-        kernels=(3, 3, 3, 3),
+        shapes,
+        mlp_keys,
+        cnn_keys,
+        act,
+        norm,
+        cnn_depth,
+        cnn_kernels,
+        mlp_layers,
+        mlp_units,
+        symlog_inputs,
+    ):
+        super(MultiEncoder, self).__init__()
+        self.cnn_shapes = {
+            k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
+        }
+        self.mlp_shapes = {
+            k: v
+            for k, v in shapes.items()
+            if len(v) in (1, 2) and re.match(mlp_keys, k)
+        }
+        print("Encoder CNN shapes:", self.cnn_shapes)
+        print("Encoder MLP shapes:", self.mlp_shapes)
+        self.outdim = 0
+        if self.cnn_shapes:
+            input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
+            self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
+            self.outdim += self._cnn.outdim
+        if self.mlp_shapes:
+            input_size = sum([sum(v) for v in self.mlp_shapes.values()])
+            self._mlp = MLP(
+                input_size,
+                None,
+                mlp_layers,
+                mlp_units,
+                act,
+                norm,
+                symlog_inputs=symlog_inputs,
+            )
+            self.outdim += mlp_units
+
+    def forward(self, obs):
+        outputs = []
+        if self.cnn_shapes:
+            inputs = torch.cat([obs[k] for k in self.cnn_shapes], -1)
+            outputs.append(self._cnn(inputs))
+        if self.mlp_shapes:
+            inputs = torch.cat([obs[k] for k in self.mlp_shapes], -1)
+            outputs.append(self._mlp(inputs))
+        outputs = torch.cat(outputs, -1)
+        return outputs
+
+
+class MultiDecoder(nn.Module):
+    def __init__(
+        self,
+        feat_size,
+        shapes,
+        mlp_keys,
+        cnn_keys,
+        act,
+        norm,
+        cnn_depth,
+        cnn_kernels,
+        mlp_layers,
+        mlp_units,
+        cnn_sigmoid,
+        image_dist,
+        vector_dist,
+    ):
+        super(MultiDecoder, self).__init__()
+        self.cnn_shapes = {
+            k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
+        }
+        self.mlp_shapes = {
+            k: v
+            for k, v in shapes.items()
+            if len(v) in (1, 2) and re.match(mlp_keys, k)
+        }
+        print("Decoder CNN shapes:", self.cnn_shapes)
+        print("Decoder MLP shapes:", self.mlp_shapes)
+        if self.cnn_shapes:
+            some_shape = list(self.cnn_shapes.values())[0]
+            shape = (sum(x[-1] for x in self.cnn_shapes.values()),) + some_shape[:-1]
+            self._cnn = ConvDecoder(
+                feat_size,
+                shape,
+                cnn_depth,
+                act,
+                norm,
+                cnn_kernels,
+                cnn_sigmoid=cnn_sigmoid,
+            )
+        if self.mlp_shapes:
+            self._mlp = MLP(
+                feat_size,
+                self.mlp_shapes,
+                mlp_layers,
+                mlp_units,
+                act,
+                norm,
+                vector_dist,
+            )
+        self._image_dist = image_dist
+
+    def forward(self, features):
+        dists = {}
+        if self.cnn_shapes:
+            feat = features
+            outputs = self._cnn(feat)
+            split_sizes = [v[-1] for v in self.cnn_shapes.values()]
+            outputs = torch.split(outputs, split_sizes, -1)
+            dists.update(
+                {
+                    key: self._make_image_dist(output)
+                    for key, output in zip(self.cnn_shapes.keys(), outputs)
+                }
+            )
+        if self.mlp_shapes:
+            dists.update(self._mlp(features))
+        return dists
+
+    def _make_image_dist(self, mean):
+        if self._image_dist == "normal":
+            return tools.ContDist(
+                torchd.independent.Independent(torchd.normal.Normal(mean, 1), 3)
+            )
+        if self._image_dist == "mse":
+            return tools.MSEDist(mean)
+        raise NotImplementedError(self._image_dist)
+
+
+class ConvEncoder(nn.Module):
+    def __init__(
+        self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
     ):
         super(ConvEncoder, self).__init__()
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._depth = depth
         self._kernels = kernels
         h, w = 64, 64

         layers = []
         for i, kernel in enumerate(self._kernels):
             if i == 0:
-                if grayscale:
-                    inp_dim = 1
-                else:
-                    inp_dim = 3
+                inp_dim = input_ch
             else:
                 inp_dim = 2 ** (i - 1) * self._depth
             depth = 2**i * self._depth
@@ -365,37 +494,42 @@ class ConvEncoder(nn.Module):
             layers.append(act())
             h, w = h // 2, w // 2

+        self.outdim = depth * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)

-    def __call__(self, obs):
-        x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
+    def forward(self, obs):
+        # (batch, time, h, w, ch) -> (batch * time, h, w, ch)
+        x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
+        # (batch * time, h, w, ch) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
-        # prod: product of all elements
+        # (batch * time, ...) -> (batch * time, -1)
         x = x.reshape([x.shape[0], np.prod(x.shape[1:])])
-        shape = list(obs["image"].shape[:-3]) + [x.shape[-1]]
-        return x.reshape(shape)
+        # (batch * time, -1) -> (batch, time, -1)
+        return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]])


 class ConvDecoder(nn.Module):
     def __init__(
         self,
         inp_depth,
+        shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
         norm=nn.LayerNorm,
-        shape=(3, 64, 64),
         kernels=(3, 3, 3, 3),
         outscale=1.0,
+        cnn_sigmoid=False,
     ):
         super(ConvDecoder, self).__init__()
         self._inp_depth = inp_depth
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._depth = depth
         self._shape = shape
         self._kernels = kernels
+        self._cnn_sigmoid = cnn_sigmoid
         self._embed_size = (
             (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
         )
@@ -407,7 +541,6 @@ class ConvDecoder(nn.Module):
         h, w = 4, 4
         for i, kernel in enumerate(self._kernels):
             depth = self._embed_size // 16 // (2 ** (i + 1))
-            act = self._act
             bias = False
             initializer = tools.weight_init
             if i == len(self._kernels) - 1:
@@ -447,88 +580,125 @@ class ConvDecoder(nn.Module):
             outpad = pad * 2 - val
             return pad, outpad

-    def __call__(self, features, dtype=None):
+    def forward(self, features, dtype=None):
         x = self._linear_layer(features)
+        # (batch, time, -1) -> (batch * time, h, w, ch)
         x = x.reshape([-1, 4, 4, self._embed_size // 16])
+        # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
+        # (batch, time, -1) -> (batch * time, ch, h, w) necessary???
         mean = x.reshape(features.shape[:-1] + self._shape)
+        # (batch * time, ch, h, w) -> (batch * time, h, w, ch)
         mean = mean.permute(0, 1, 3, 4, 2)
-        return tools.SymlogDist(mean)
+        if self._cnn_sigmoid:
+            mean = F.sigmoid(mean) - 0.5
+        return mean


-class DenseHead(nn.Module):
+class MLP(nn.Module):
     def __init__(
         self,
         inp_dim,
         shape,
         layers,
         units,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
+        act="SiLU",
+        norm="LayerNorm",
         dist="normal",
         std=1.0,
         outscale=1.0,
+        symlog_inputs=False,
         device="cuda",
     ):
-        super(DenseHead, self).__init__()
+        super(MLP, self).__init__()
         self._shape = (shape,) if isinstance(shape, int) else shape
-        if len(self._shape) == 0:
+        if self._shape is not None and len(self._shape) == 0:
             self._shape = (1,)
         self._layers = layers
-        self._units = units
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._dist = dist
         self._std = std
+        self._symlog_inputs = symlog_inputs
        self._device = device

         layers = []
         for index in range(self._layers):
-            layers.append(nn.Linear(inp_dim, self._units, bias=False))
-            layers.append(norm(self._units, eps=1e-03))
+            layers.append(nn.Linear(inp_dim, units, bias=False))
+            layers.append(norm(units, eps=1e-03))
             layers.append(act())
             if index == 0:
-                inp_dim = self._units
+                inp_dim = units
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)

-        self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
-        self.mean_layer.apply(tools.uniform_weight_init(outscale))
-
-        if self._std == "learned":
-            self.std_layer = nn.Linear(self._units, np.prod(self._shape))
-            self.std_layer.apply(tools.uniform_weight_init(outscale))
+        if isinstance(self._shape, dict):
+            self.mean_layer = nn.ModuleDict()
+            for name, shape in self._shape.items():
+                self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
+            self.mean_layer.apply(tools.uniform_weight_init(outscale))
+            if self._std == "learned":
+                self.std_layer = nn.ModuleDict()
+                for name, shape in self._shape.items():
+                    self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
+                self.std_layer.apply(tools.uniform_weight_init(outscale))
+        elif self._shape is not None:
+            self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
+            self.mean_layer.apply(tools.uniform_weight_init(outscale))
+            if self._std == "learned":
+                self.std_layer = nn.Linear(units, np.prod(self._shape))
+                self.std_layer.apply(tools.uniform_weight_init(outscale))

-    def __call__(self, features, dtype=None):
+    def forward(self, features, dtype=None):
         x = features
+        if self._symlog_inputs:
+            x = tools.symlog(x)
         out = self.layers(x)
-        mean = self.mean_layer(out)
-        if self._std == "learned":
-            std = self.std_layer(out)
-        else:
-            std = self._std
+        if self._shape is None:
+            return out
+        if isinstance(self._shape, dict):
+            dists = {}
+            for name, shape in self._shape.items():
+                mean = self.mean_layer[name](out)
+                if self._std == "learned":
+                    std = self.std_layer[name](out)
+                else:
+                    std = self._std
+                dists.update({name: self.dist(self._dist, mean, std, shape)})
+            return dists
+        else:
+            mean = self.mean_layer(out)
+            if self._std == "learned":
+                std = self.std_layer(out)
+            else:
+                std = self._std
+            return self.dist(self._dist, mean, std, self._shape)
+
+    def dist(self, dist, mean, std, shape):
-        if self._dist == "normal":
+        if dist == "normal":
             return tools.ContDist(
                 torchd.independent.Independent(
-                    torchd.normal.Normal(mean, std), len(self._shape)
+                    torchd.normal.Normal(mean, std), len(shape)
                 )
             )
-        if self._dist == "huber":
+        if dist == "huber":
             return tools.ContDist(
                 torchd.independent.Independent(
-                    tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape)
+                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
                 )
             )
-        if self._dist == "binary":
+        if dist == "binary":
             return tools.Bernoulli(
                 torchd.independent.Independent(
-                    torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
+                    torchd.bernoulli.Bernoulli(logits=mean), len(shape)
                 )
             )
-        if self._dist == "twohot_symlog":
-            return tools.TwoHotDistSymlog(logits=mean, device=self._device)
-        raise NotImplementedError(self._dist)
+        if dist == "symlog_disc":
+            return tools.DiscDist(logits=mean, device=self._device)
+        if dist == "symlog_mse":
+            return tools.SymlogDist(mean)
+        raise NotImplementedError(dist)


 class ActionHead(nn.Module):
@@ -553,8 +723,8 @@ class ActionHead(nn.Module):
         self._layers = layers
         self._units = units
         self._dist = dist
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._min_std = min_std
         self._max_std = max_std
         self._init_std = init_std
@@ -579,7 +749,7 @@ class ActionHead(nn.Module):
         self._dist_layer = nn.Linear(self._units, self._size)
         self._dist_layer.apply(tools.uniform_weight_init(outscale))

-    def __call__(self, features, dtype=None):
+    def forward(self, features, dtype=None):
         x = features
         x = self._pre_layers(x)
         if self._dist == "tanh_normal":
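
The renamed `MLP` head now does triple duty: with `shape=None` it is a plain feature extractor (the MultiEncoder's vector branch), with a tuple shape it behaves like the old DenseHead, and with a dict of shapes it emits one distribution per key, which is how MultiDecoder reconstructs proprioceptive observations. A usage sketch, assuming `networks.MLP` from this commit is importable; the feature size and key names are made up:

```
import torch
from networks import MLP

mlp = MLP(
    inp_dim=1536,                               # e.g. RSSM feature size
    shape={"position": (8,), "velocity": (9,)}, # dict -> one dist per key
    layers=2,
    units=512,
    dist="symlog_mse",
    device="cpu",
)
dists = mlp(torch.zeros(16, 64, 1536))          # (batch, time, feat)
print(sorted(dists))                            # ['position', 'velocity']
print(dists["position"].mode().shape)           # torch.Size([16, 64, 8])
```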

View File

@@ -320,24 +320,34 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
         return sample


-class TwoHotDistSymlog:
-    def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
+class DiscDist:
+    def __init__(
+        self,
+        logits,
+        low=-20.0,
+        high=20.0,
+        transfwd=symlog,
+        transbwd=symexp,
+        device="cuda",
+    ):
         self.logits = logits
         self.probs = torch.softmax(logits, -1)
         self.buckets = torch.linspace(low, high, steps=255).to(device)
         self.width = (self.buckets[-1] - self.buckets[0]) / 255
+        self.transfwd = transfwd
+        self.transbwd = transbwd

     def mean(self):
         _mean = self.probs * self.buckets
-        return symexp(torch.sum(_mean, dim=-1, keepdim=True))
+        return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))

     def mode(self):
         _mode = self.probs * self.buckets
-        return symexp(torch.sum(_mode, dim=-1, keepdim=True))
+        return self.transbwd(torch.sum(_mode, dim=-1, keepdim=True))

     # Inside OneHotCategorical, log_prob is calculated using only max element in targets
     def log_prob(self, x):
-        x = symlog(x)
+        x = self.transfwd(x)
         # x(time, batch, 1)
         below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
         above = len(self.buckets) - torch.sum(
@@ -366,15 +376,35 @@ class TwoHotDistSymlog:
         return (target * log_pred).sum(-1)


+class MSEDist:
+    def __init__(self, mode, agg="sum"):
+        self._mode = mode
+        self._agg = agg
+
+    def mode(self):
+        return self._mode
+
+    def mean(self):
+        return self._mode
+
+    def log_prob(self, value):
+        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
+        distance = (self._mode - value) ** 2
+        if self._agg == "mean":
+            loss = distance.mean(list(range(len(distance.shape)))[2:])
+        elif self._agg == "sum":
+            loss = distance.sum(list(range(len(distance.shape)))[2:])
+        else:
+            raise NotImplementedError(self._agg)
+        return -loss
+
+
 class SymlogDist:
-    def __init__(
-        self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
-    ):
+    def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
         self._mode = mode
         self._dist = dist
         self._agg = agg
         self._tol = tol
-        self._dim_to_reduce = dim_to_reduce

     def mode(self):
         return symexp(self._mode)
@@ -393,9 +423,9 @@ class SymlogDist:
         else:
             raise NotImplementedError(self._dist)
         if self._agg == "mean":
-            loss = distance.mean(self._dim_to_reduce)
+            loss = distance.mean(list(range(len(distance.shape)))[2:])
         elif self._agg == "sum":
-            loss = distance.sum(self._dim_to_reduce)
+            loss = distance.sum(list(range(len(distance.shape)))[2:])
         else:
             raise NotImplementedError(self._agg)
         return -loss
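
`DiscDist` keeps DreamerV3's two-hot regression but makes the squashing transform injectable (`symlog` forward, `symexp` backward). A target scalar is symlog-transformed and then split between its two nearest of the 255 buckets in [-20, 20], in proportion to proximity. A worked sketch of that two-hot target, simplified from `log_prob` above (edge clipping omitted):

```
import torch

def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

buckets = torch.linspace(-20.0, 20.0, steps=255)
x = symlog(torch.tensor(3.0))                           # ~1.386
below = torch.sum((buckets <= x).to(torch.int32)) - 1   # bucket at or below x
above = below + 1
total = buckets[above] - buckets[below]
weight_below = (buckets[above] - x) / total
weight_above = (x - buckets[below]) / total
print(int(below), int(above))                           # two adjacent bucket indices
print(float(weight_below + weight_above))               # 1.0 — a valid two-hot target
```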