added state input capability

commit b984e69b6e
parent 3ebb8ad617
README.md | 14

@@ -7,20 +7,24 @@ Get dependencies:
 ```
 pip install -r requirements.txt
 ```
-Train the agent on Walker Walk in Vision DMC:
+Train the agent on Walker Walk in DMC Vision:
 ```
-python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk
+python3 dreamer.py --configs dmc_vision --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk
+```
+Train the agent on Walker Walk in DMC Proprio:
+```
+python3 dreamer.py --configs dmc_proprio --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk
 ```
 Train the agent on Alien in Atari 100K:
 ```
-python3 dreamer.py --configs defaults atari100k --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien
+python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
 ```
 Monitor results:
 ```
 tensorboard --logdir ~/dreamerv3-torch/logdir
 ```

-## Evaluation Results
+## Results
 More results will be added in the future.

 

@@ -30,7 +34,7 @@ More results will be added in the future.
 - [x] Modify implementation details based on the author's implementation
 - [x] Evaluate on DMC vision
 - [x] Evaluate on Atari 100K
-- [ ] Add state input capability
+- [x] Add state input capability
 - [ ] Evaluate on DMC Proprio
 - [ ] etc.
configs.yaml | 29

@@ -1,4 +1,3 @@
-# defaults is for Vision DMC
 defaults:

   logdir: null
@@ -17,6 +16,7 @@ defaults:
   precision: 16
   debug: False
   expl_gifs: False
+  video_pred_log: True

   # Environment
   task: 'dmc_walker_walk'
@@ -43,7 +43,7 @@ defaults:
   dyn_std_act: 'sigmoid2'
   dyn_min_std: 0.1
   dyn_temp_post: True
-  grad_heads: ['image', 'reward', 'cont']
+  grad_heads: ['decoder', 'reward', 'cont']
   units: 512
   reward_layers: 2
   cont_layers: 2
@@ -51,11 +51,12 @@ defaults:
   actor_layers: 2
   act: 'SiLU'
   norm: 'LayerNorm'
-  cnn_depth: 32
-  encoder_kernels: [4, 4, 4, 4]
-  decoder_kernels: [4, 4, 4, 4]
-  value_head: 'twohot_symlog'
-  reward_head: 'twohot_symlog'
+  encoder:
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+  decoder:
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,}
+  value_head: 'symlog_disc'
+  reward_head: 'symlog_disc'
   dyn_scale: '0.5'
   rep_scale: '0.1'
   kl_free: '1.0'
@@ -119,6 +120,20 @@ defaults:
   disag_units: 400
   disag_action_cond: False

+dmc_vision:
+  steps: 1e6
+  train_ratio: 512
+  video_pred_log: true
+  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
+
+dmc_proprio:
+  steps: 5e5
+  train_ratio: 512
+  video_pred_log: false
+  encoder: {mlp_keys: '.*', cnn_keys: '$^'}
+  decoder: {mlp_keys: '.*', cnn_keys: '$^'}
+
 atari100k:
   steps: 4e5
   action_repeat: 4
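The `mlp_keys` and `cnn_keys` entries above are regular expressions matched against observation keys: `'$^'` matches nothing and `'.*'` matches everything, which is how `dmc_vision` routes only `image` through the CNN while `dmc_proprio` routes every low-dimensional key through the MLP. A minimal standalone sketch of that selection (the key names here are hypothetical, not taken from a real task):

```python
import re

# Hypothetical observation shapes: two proprioceptive vectors and one image.
shapes = {"position": (8,), "velocity": (9,), "image": (64, 64, 3)}

def split_keys(shapes, mlp_keys, cnn_keys):
    # 3-D (H, W, C) entries whose key matches cnn_keys go to the CNN branch;
    # 1-D/2-D entries whose key matches mlp_keys go to the MLP branch.
    cnn = {k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)}
    mlp = {k: v for k, v in shapes.items() if len(v) in (1, 2) and re.match(mlp_keys, k)}
    return cnn, mlp

print(split_keys(shapes, mlp_keys="$^", cnn_keys="image"))  # dmc_vision: image only
print(split_keys(shapes, mlp_keys=".*", cnn_keys="$^"))     # dmc_proprio: vectors only
```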
dreamer.py | 39

@@ -27,7 +27,7 @@ to_np = lambda x: x.detach().cpu().numpy()


 class Dreamer(nn.Module):
-    def __init__(self, config, logger, dataset):
+    def __init__(self, obs_space, act_space, config, logger, dataset):
         super(Dreamer, self).__init__()
         self._config = config
         self._logger = logger
@@ -51,7 +51,7 @@ class Dreamer(nn.Module):
             x, self._step
         )
         self._dataset = dataset
-        self._wm = models.WorldModel(self._step, config)
+        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
@@ -90,8 +90,9 @@ class Dreamer(nn.Module):
             for name, values in self._metrics.items():
                 self._logger.scalar(name, float(np.mean(values)))
                 self._metrics[name] = []
-            openl = self._wm.video_pred(next(self._dataset))
-            self._logger.video("train_openl", to_np(openl))
+            if self._config.video_pred_log:
+                openl = self._wm.video_pred(next(self._dataset))
+                self._logger.video("train_openl", to_np(openl))
             self._logger.write(fps=True)

         policy_output, state = self._policy(obs, state, training)
@@ -296,8 +297,6 @@ def main(config):
     config.eval_every //= config.action_repeat
     config.log_every //= config.action_repeat
     config.time_limit //= config.action_repeat
-    config.act = getattr(torch.nn, config.act)
-    config.norm = getattr(torch.nn, config.norm)

     print("Logdir", logdir)
     logdir.mkdir(parents=True, exist_ok=True)
@@ -350,7 +349,13 @@ def main(config):
     print("Simulate agent.")
     train_dataset = make_dataset(train_eps, config)
     eval_dataset = make_dataset(eval_eps, config)
-    agent = Dreamer(config, logger, train_dataset).to(config.device)
+    agent = Dreamer(
+        train_envs[0].observation_space,
+        train_envs[0].action_space,
+        config,
+        logger,
+        train_dataset,
+    ).to(config.device)
     agent.requires_grad_(requires_grad=False)
     if (logdir / "latest_model.pt").exists():
         agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
@@ -362,8 +367,9 @@ def main(config):
         print("Start evaluation.")
         eval_policy = functools.partial(agent, training=False)
         tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
-        video_pred = agent._wm.video_pred(next(eval_dataset))
-        logger.video("eval_openl", to_np(video_pred))
+        if config.video_pred_log:
+            video_pred = agent._wm.video_pred(next(eval_dataset))
+            logger.video("eval_openl", to_np(video_pred))
         print("Start training.")
         state = tools.simulate(agent, train_envs, config.eval_every, state=state)
         torch.save(agent.state_dict(), logdir / "latest_model.pt")
@@ -376,14 +382,23 @@ def main(config):


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--configs", nargs="+", required=True)
+    parser.add_argument("--configs", nargs="+")
     args, remaining = parser.parse_known_args()
     configs = yaml.safe_load(
         (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
     )

+    def recursive_update(base, update):
+        for key, value in update.items():
+            if isinstance(value, dict) and key in base:
+                recursive_update(base[key], value)
+            else:
+                base[key] = value
+
+    name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
     defaults = {}
-    for name in args.configs:
-        defaults.update(configs[name])
+    for name in name_list:
+        recursive_update(defaults, configs[name])
     parser = argparse.ArgumentParser()
     for key, value in sorted(defaults.items(), key=lambda x: x[0]):
         arg_type = tools.args_type(value)
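`recursive_update` merges the named configs into `defaults` key by key, descending into nested dicts, so a preset such as `dmc_proprio` can override just the `mlp_keys`/`cnn_keys` entries of the `encoder` dict without dropping its other settings (a flat `dict.update` would replace the whole sub-dict). A standalone sketch of the behavior with illustrative values:

```python
def recursive_update(base, update):
    # Merge `update` into `base`, recursing into nested dicts instead of replacing them.
    for key, value in update.items():
        if isinstance(value, dict) and key in base:
            recursive_update(base[key], value)
        else:
            base[key] = value

defaults = {"encoder": {"mlp_keys": "$^", "cnn_keys": "image", "cnn_depth": 32}}
preset = {"encoder": {"mlp_keys": ".*", "cnn_keys": "$^"}}

recursive_update(defaults, preset)
# cnn_depth is preserved; only the key patterns are overridden.
print(defaults)  # {'encoder': {'mlp_keys': '.*', 'cnn_keys': '$^', 'cnn_depth': 32}}
```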
@@ -24,7 +24,11 @@ class DeepMindControl:
     def observation_space(self):
         spaces = {}
         for key, value in self._env.observation_spec().items():
-            spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
+            if len(value.shape) == 0:
+                shape = (1,)
+            else:
+                shape = value.shape
+            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
         spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
         return gym.spaces.Dict(spaces)

@@ -42,6 +46,7 @@ class DeepMindControl:
             if time_step.last():
                 break
         obs = dict(time_step.observation)
+        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
         obs["image"] = self.render()
         # There is no terminal state in DMC
         obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
@@ -53,6 +58,7 @@ class DeepMindControl:
     def reset(self):
         time_step = self._env.reset()
         obs = dict(time_step.observation)
+        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
         obs["image"] = self.render()
         obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
         obs["is_first"] = time_step.first()
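The new dictionary comprehension promotes 0-d scalar observations to shape `(1,)` so that every entry the MLP encoder sees is at least one-dimensional and can be concatenated. A small illustration (the key names are made up):

```python
import numpy as np

obs = {"height": np.float64(1.2), "velocity": np.zeros(9)}
# Wrap 0-d entries in a list so every value has at least one dimension.
obs = {k: [v] if len(np.asarray(v).shape) == 0 else v for k, v in obs.items()}
print(np.asarray(obs["height"]).shape, obs["velocity"].shape)  # (1,) (9,)
```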
@@ -52,7 +52,7 @@ class Plan2Explore(nn.Module):
             act=config.act,
         )
         self._networks = nn.ModuleList(
-            [networks.DenseHead(**kw) for _ in range(config.disag_models)]
+            [networks.MLP(**kw) for _ in range(config.disag_models)]
         )
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(
models.py | 63

@@ -29,26 +29,14 @@ class RewardEMA(object):


 class WorldModel(nn.Module):
-    def __init__(self, step, config):
+    def __init__(self, obs_space, act_space, step, config):
         super(WorldModel, self).__init__()
         self._step = step
         self._use_amp = True if config.precision == 16 else False
         self._config = config
-        self.encoder = networks.ConvEncoder(
-            config.grayscale,
-            config.cnn_depth,
-            config.act,
-            config.norm,
-            config.encoder_kernels,
-        )
-        if config.size[0] == 64 and config.size[1] == 64:
-            embed_size = (
-                (64 // 2 ** (len(config.encoder_kernels))) ** 2
-                * config.cnn_depth
-                * 2 ** (len(config.encoder_kernels) - 1)
-            )
-        else:
-            raise NotImplemented(f"{config.size} is not applicable now")
+        shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
+        self.encoder = networks.MultiEncoder(shapes, **config.encoder)
+        embed_size = self.encoder.outdim
         self.dynamics = networks.RSSM(
             config.dyn_stoch,
             config.dyn_deter,
@@ -72,22 +60,15 @@ class WorldModel(nn.Module):
             config.device,
         )
         self.heads = nn.ModuleDict()
-        channels = 1 if config.grayscale else 3
-        shape = (channels,) + config.size
         if config.dyn_discrete:
             feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
         else:
             feat_size = config.dyn_stoch + config.dyn_deter
-        self.heads["image"] = networks.ConvDecoder(
-            feat_size,  # pytorch version
-            config.cnn_depth,
-            config.act,
-            config.norm,
-            shape,
-            config.decoder_kernels,
+        self.heads["decoder"] = networks.MultiDecoder(
+            feat_size, shapes, **config.decoder
         )
-        if config.reward_head == "twohot_symlog":
-            self.heads["reward"] = networks.DenseHead(
+        if config.reward_head == "symlog_disc":
+            self.heads["reward"] = networks.MLP(
                 feat_size,  # pytorch version
                 (255,),
                 config.reward_layers,
@@ -99,7 +80,7 @@ class WorldModel(nn.Module):
                 device=config.device,
             )
         else:
-            self.heads["reward"] = networks.DenseHead(
+            self.heads["reward"] = networks.MLP(
                 feat_size,  # pytorch version
                 [],
                 config.reward_layers,
@@ -110,7 +91,7 @@ class WorldModel(nn.Module):
                 outscale=0.0,
                 device=config.device,
             )
-        self.heads["cont"] = networks.DenseHead(
+        self.heads["cont"] = networks.MLP(
             feat_size,  # pytorch version
             [],
             config.cont_layers,
@@ -153,15 +134,19 @@ class WorldModel(nn.Module):
                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                     post, prior, kl_free, dyn_scale, rep_scale
                 )
-                losses = {}
-                likes = {}
+                preds = {}
                 for name, head in self.heads.items():
                     grad_head = name in self._config.grad_heads
                     feat = self.dynamics.get_feat(post)
                     feat = feat if grad_head else feat.detach()
                     pred = head(feat)
+                    if type(pred) is dict:
+                        preds.update(pred)
+                    else:
+                        preds[name] = pred
+                losses = {}
+                for name, pred in preds.items():
                     like = pred.log_prob(data[name])
-                    likes[name] = like
                     losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
                 model_loss = sum(losses.values()) + kl_loss
             metrics = self._model_opt(model_loss, self.parameters())
@@ -213,11 +198,13 @@ class WorldModel(nn.Module):
         states, _ = self.dynamics.observe(
             embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
         )
-        recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
+        recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[
+            :6
+        ]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
         prior = self.dynamics.imagine(data["action"][:6, 5:], init)
-        openl = self.heads["image"](self.dynamics.get_feat(prior)).mode()
+        openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
         model = torch.cat([recon[:, :5], openl], 1)
@@ -254,9 +241,9 @@ class ImagBehavior(nn.Module):
             config.actor_temp,
             outscale=1.0,
             unimix_ratio=config.action_unimix_ratio,
-        )  # action_dist -> action_disc?
-        if config.value_head == "twohot_symlog":
-            self.value = networks.DenseHead(
+        )
+        if config.value_head == "symlog_disc":
+            self.value = networks.MLP(
                 feat_size,  # pytorch version
                 (255,),
                 config.value_layers,
@@ -268,7 +255,7 @@ class ImagBehavior(nn.Module):
                 device=config.device,
             )
         else:
-            self.value = networks.DenseHead(
+            self.value = networks.MLP(
                 feat_size,  # pytorch version
                 [],
                 config.value_layers,
networks.py | 304

@@ -1,5 +1,6 @@
 import math
 import numpy as np
+import re

 import torch
 from torch import nn
@@ -20,8 +21,8 @@ class RSSM(nn.Module):
         rec_depth=1,
         shared=False,
         discrete=False,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
+        act="SiLU",
+        norm="LayerNorm",
         mean_act="none",
         std_act="softplus",
         temp_post=True,
@@ -43,8 +44,8 @@ class RSSM(nn.Module):
         self._rec_depth = rec_depth
         self._shared = shared
         self._discrete = discrete
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._mean_act = mean_act
         self._std_act = std_act
         self._temp_post = temp_post
@@ -62,8 +63,8 @@ class RSSM(nn.Module):
             inp_dim += self._embed
         for i in range(self._layers_input):
             inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            inp_layers.append(self._norm(self._hidden, eps=1e-03))
-            inp_layers.append(self._act())
+            inp_layers.append(norm(self._hidden, eps=1e-03))
+            inp_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._inp_layers = nn.Sequential(*inp_layers)
@@ -82,8 +83,8 @@ class RSSM(nn.Module):
         inp_dim = self._deter
         for i in range(self._layers_output):
             img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            img_out_layers.append(self._norm(self._hidden, eps=1e-03))
-            img_out_layers.append(self._act())
+            img_out_layers.append(norm(self._hidden, eps=1e-03))
+            img_out_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._img_out_layers = nn.Sequential(*img_out_layers)
@@ -96,8 +97,8 @@ class RSSM(nn.Module):
         inp_dim = self._embed
         for i in range(self._layers_output):
             obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
-            obs_out_layers.append(self._act())
+            obs_out_layers.append(norm(self._hidden, eps=1e-03))
+            obs_out_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
         self._obs_out_layers = nn.Sequential(*obs_out_layers)
@@ -327,28 +328,156 @@ class RSSM(nn.Module):
         return loss, value, dyn_loss, rep_loss


-class ConvEncoder(nn.Module):
+class MultiEncoder(nn.Module):
     def __init__(
         self,
-        grayscale=False,
-        depth=32,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
-        kernels=(3, 3, 3, 3),
+        shapes,
+        mlp_keys,
+        cnn_keys,
+        act,
+        norm,
+        cnn_depth,
+        cnn_kernels,
+        mlp_layers,
+        mlp_units,
+        symlog_inputs,
+    ):
+        super(MultiEncoder, self).__init__()
+        self.cnn_shapes = {
+            k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
+        }
+        self.mlp_shapes = {
+            k: v
+            for k, v in shapes.items()
+            if len(v) in (1, 2) and re.match(mlp_keys, k)
+        }
+        print("Encoder CNN shapes:", self.cnn_shapes)
+        print("Encoder MLP shapes:", self.mlp_shapes)
+
+        self.outdim = 0
+        if self.cnn_shapes:
+            input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
+            self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
+            self.outdim += self._cnn.outdim
+        if self.mlp_shapes:
+            input_size = sum([sum(v) for v in self.mlp_shapes.values()])
+            self._mlp = MLP(
+                input_size,
+                None,
+                mlp_layers,
+                mlp_units,
+                act,
+                norm,
+                symlog_inputs=symlog_inputs,
+            )
+            self.outdim += mlp_units
+
+    def forward(self, obs):
+        outputs = []
+        if self.cnn_shapes:
+            inputs = torch.cat([obs[k] for k in self.cnn_shapes], -1)
+            outputs.append(self._cnn(inputs))
+        if self.mlp_shapes:
+            inputs = torch.cat([obs[k] for k in self.mlp_shapes], -1)
+            outputs.append(self._mlp(inputs))
+        outputs = torch.cat(outputs, -1)
+        return outputs
+
+
+class MultiDecoder(nn.Module):
+    def __init__(
+        self,
+        feat_size,
+        shapes,
+        mlp_keys,
+        cnn_keys,
+        act,
+        norm,
+        cnn_depth,
+        cnn_kernels,
+        mlp_layers,
+        mlp_units,
+        cnn_sigmoid,
+        image_dist,
+        vector_dist,
+    ):
+        super(MultiDecoder, self).__init__()
+        self.cnn_shapes = {
+            k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
+        }
+        self.mlp_shapes = {
+            k: v
+            for k, v in shapes.items()
+            if len(v) in (1, 2) and re.match(mlp_keys, k)
+        }
+        print("Decoder CNN shapes:", self.cnn_shapes)
+        print("Decoder MLP shapes:", self.mlp_shapes)
+
+        if self.cnn_shapes:
+            some_shape = list(self.cnn_shapes.values())[0]
+            shape = (sum(x[-1] for x in self.cnn_shapes.values()),) + some_shape[:-1]
+            self._cnn = ConvDecoder(
+                feat_size,
+                shape,
+                cnn_depth,
+                act,
+                norm,
+                cnn_kernels,
+                cnn_sigmoid=cnn_sigmoid,
+            )
+        if self.mlp_shapes:
+            self._mlp = MLP(
+                feat_size,
+                self.mlp_shapes,
+                mlp_layers,
+                mlp_units,
+                act,
+                norm,
+                vector_dist,
+            )
+        self._image_dist = image_dist
+
+    def forward(self, features):
+        dists = {}
+        if self.cnn_shapes:
+            feat = features
+            outputs = self._cnn(feat)
+            split_sizes = [v[-1] for v in self.cnn_shapes.values()]
+            outputs = torch.split(outputs, split_sizes, -1)
+            dists.update(
+                {
+                    key: self._make_image_dist(output)
+                    for key, output in zip(self.cnn_shapes.keys(), outputs)
+                }
+            )
+        if self.mlp_shapes:
+            dists.update(self._mlp(features))
+        return dists
+
+    def _make_image_dist(self, mean):
+        if self._image_dist == "normal":
+            return tools.ContDist(
+                torchd.independent.Independent(torchd.normal.Normal(mean, 1), 3)
+            )
+        if self._image_dist == "mse":
+            return tools.MSEDist(mean)
+        raise NotImplementedError(self._image_dist)
+
+
+class ConvEncoder(nn.Module):
+    def __init__(
+        self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
     ):
         super(ConvEncoder, self).__init__()
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._depth = depth
         self._kernels = kernels
         h, w = 64, 64
         layers = []
         for i, kernel in enumerate(self._kernels):
             if i == 0:
-                if grayscale:
-                    inp_dim = 1
-                else:
-                    inp_dim = 3
+                inp_dim = input_ch
             else:
                 inp_dim = 2 ** (i - 1) * self._depth
             depth = 2**i * self._depth
@@ -365,37 +494,42 @@ class ConvEncoder(nn.Module):
             layers.append(act())
             h, w = h // 2, w // 2

+        self.outdim = depth * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)

-    def __call__(self, obs):
-        x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
+    def forward(self, obs):
+        # (batch, time, h, w, ch) -> (batch * time, h, w, ch)
+        x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
+        # (batch * time, h, w, ch) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
-        # prod: product of all elements
+        # (batch * time, ...) -> (batch * time, -1)
         x = x.reshape([x.shape[0], np.prod(x.shape[1:])])
-        shape = list(obs["image"].shape[:-3]) + [x.shape[-1]]
-        return x.reshape(shape)
+        # (batch * time, -1) -> (batch, time, -1)
+        return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]])


 class ConvDecoder(nn.Module):
     def __init__(
         self,
         inp_depth,
+        shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
         norm=nn.LayerNorm,
-        shape=(3, 64, 64),
         kernels=(3, 3, 3, 3),
         outscale=1.0,
+        cnn_sigmoid=False,
     ):
         super(ConvDecoder, self).__init__()
         self._inp_depth = inp_depth
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._depth = depth
         self._shape = shape
         self._kernels = kernels
+        self._cnn_sigmoid = cnn_sigmoid
         self._embed_size = (
             (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
         )
@@ -407,7 +541,6 @@ class ConvDecoder(nn.Module):
         h, w = 4, 4
         for i, kernel in enumerate(self._kernels):
             depth = self._embed_size // 16 // (2 ** (i + 1))
-            act = self._act
             bias = False
             initializer = tools.weight_init
             if i == len(self._kernels) - 1:
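`ConvEncoder` now records `self.outdim` (final channels times final spatial size), and `MultiEncoder` adds `mlp_units` on top of it, which is why `WorldModel` can simply read `self.encoder.outdim` instead of recomputing the embedding size. A quick standalone check of the arithmetic for the default 64x64, depth-32, four-kernel setup (assuming each stride-2 convolution halves the resolution, as in the encoder above):

```python
# Rough sketch, not repo code: feature width of the default ConvEncoder configuration.
cnn_depth, kernels = 32, (4, 4, 4, 4)
h, w, depth = 64, 64, cnn_depth
for i, _ in enumerate(kernels):
    depth = 2**i * cnn_depth  # channels: 32, 64, 128, 256
    h, w = h // 2, w // 2     # resolution: 32, 16, 8, 4
outdim = depth * h * w
print(outdim)  # 4096, the same value the old hard-coded embed_size formula produced
```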
@@ -447,88 +580,125 @@ class ConvDecoder(nn.Module):
             outpad = pad * 2 - val
             return pad, outpad

-    def __call__(self, features, dtype=None):
+    def forward(self, features, dtype=None):
         x = self._linear_layer(features)
+        # (batch, time, -1) -> (batch * time, h, w, ch)
         x = x.reshape([-1, 4, 4, self._embed_size // 16])
+        # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
+        # (batch, time, -1) -> (batch * time, ch, h, w) necessary???
         mean = x.reshape(features.shape[:-1] + self._shape)
+        # (batch * time, ch, h, w) -> (batch * time, h, w, ch)
         mean = mean.permute(0, 1, 3, 4, 2)
-        return tools.SymlogDist(mean)
+        if self._cnn_sigmoid:
+            mean = F.sigmoid(mean) - 0.5
+        return mean


-class DenseHead(nn.Module):
+class MLP(nn.Module):
     def __init__(
         self,
         inp_dim,
         shape,
         layers,
         units,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
+        act="SiLU",
+        norm="LayerNorm",
         dist="normal",
         std=1.0,
         outscale=1.0,
+        symlog_inputs=False,
         device="cuda",
     ):
-        super(DenseHead, self).__init__()
+        super(MLP, self).__init__()
         self._shape = (shape,) if isinstance(shape, int) else shape
-        if len(self._shape) == 0:
+        if self._shape is not None and len(self._shape) == 0:
             self._shape = (1,)
         self._layers = layers
-        self._units = units
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._dist = dist
         self._std = std
+        self._symlog_inputs = symlog_inputs
         self._device = device

         layers = []
         for index in range(self._layers):
-            layers.append(nn.Linear(inp_dim, self._units, bias=False))
-            layers.append(norm(self._units, eps=1e-03))
+            layers.append(nn.Linear(inp_dim, units, bias=False))
+            layers.append(norm(units, eps=1e-03))
             layers.append(act())
             if index == 0:
-                inp_dim = self._units
+                inp_dim = units
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)

-        self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
-        self.mean_layer.apply(tools.uniform_weight_init(outscale))
-
-        if self._std == "learned":
-            self.std_layer = nn.Linear(self._units, np.prod(self._shape))
-            self.std_layer.apply(tools.uniform_weight_init(outscale))
-
-    def __call__(self, features, dtype=None):
+        if isinstance(self._shape, dict):
+            self.mean_layer = nn.ModuleDict()
+            for name, shape in self._shape.items():
+                self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
+            self.mean_layer.apply(tools.uniform_weight_init(outscale))
+            if self._std == "learned":
+                self.std_layer = nn.ModuleDict()
+                for name, shape in self._shape.items():
+                    self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
+                self.std_layer.apply(tools.uniform_weight_init(outscale))
+        elif self._shape is not None:
+            self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
+            self.mean_layer.apply(tools.uniform_weight_init(outscale))
+            if self._std == "learned":
+                self.std_layer = nn.Linear(units, np.prod(self._shape))
+                self.std_layer.apply(tools.uniform_weight_init(outscale))
+
+    def forward(self, features, dtype=None):
         x = features
+        if self._symlog_inputs:
+            x = tools.symlog(x)
         out = self.layers(x)
-        mean = self.mean_layer(out)
-        if self._std == "learned":
-            std = self.std_layer(out)
+        if self._shape is None:
+            return out
+        if isinstance(self._shape, dict):
+            dists = {}
+            for name, shape in self._shape.items():
+                mean = self.mean_layer[name](out)
+                if self._std == "learned":
+                    std = self.std_layer[name](out)
+                else:
+                    std = self._std
+                dists.update({name: self.dist(self._dist, mean, std, shape)})
+            return dists
         else:
-            std = self._std
-        if self._dist == "normal":
+            mean = self.mean_layer(out)
+            if self._std == "learned":
+                std = self.std_layer(out)
+            else:
+                std = self._std
+            return self.dist(self._dist, mean, std, self._shape)
+
+    def dist(self, dist, mean, std, shape):
+        if dist == "normal":
             return tools.ContDist(
                 torchd.independent.Independent(
-                    torchd.normal.Normal(mean, std), len(self._shape)
+                    torchd.normal.Normal(mean, std), len(shape)
                 )
             )
-        if self._dist == "huber":
+        if dist == "huber":
             return tools.ContDist(
                 torchd.independent.Independent(
-                    tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape)
+                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
                 )
             )
-        if self._dist == "binary":
+        if dist == "binary":
             return tools.Bernoulli(
                 torchd.independent.Independent(
-                    torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
+                    torchd.bernoulli.Bernoulli(logits=mean), len(shape)
                 )
             )
-        if self._dist == "twohot_symlog":
-            return tools.TwoHotDistSymlog(logits=mean, device=self._device)
-        raise NotImplementedError(self._dist)
+        if dist == "symlog_disc":
+            return tools.DiscDist(logits=mean, device=self._device)
+        if dist == "symlog_mse":
+            return tools.SymlogDist(mean)
+        raise NotImplementedError(dist)


 class ActionHead(nn.Module):
@@ -553,8 +723,8 @@ class ActionHead(nn.Module):
         self._layers = layers
         self._units = units
         self._dist = dist
-        self._act = act
-        self._norm = norm
+        act = getattr(torch.nn, act)
+        norm = getattr(torch.nn, norm)
         self._min_std = min_std
         self._max_std = max_std
         self._init_std = init_std
@@ -579,7 +749,7 @@ class ActionHead(nn.Module):
         self._dist_layer = nn.Linear(self._units, self._size)
         self._dist_layer.apply(tools.uniform_weight_init(outscale))

-    def __call__(self, features, dtype=None):
+    def forward(self, features, dtype=None):
         x = features
         x = self._pre_layers(x)
         if self._dist == "tanh_normal":
tools.py | 52

@@ -320,24 +320,34 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
         return sample


-class TwoHotDistSymlog:
-    def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
+class DiscDist:
+    def __init__(
+        self,
+        logits,
+        low=-20.0,
+        high=20.0,
+        transfwd=symlog,
+        transbwd=symexp,
+        device="cuda",
+    ):
         self.logits = logits
         self.probs = torch.softmax(logits, -1)
         self.buckets = torch.linspace(low, high, steps=255).to(device)
         self.width = (self.buckets[-1] - self.buckets[0]) / 255
+        self.transfwd = transfwd
+        self.transbwd = transbwd

     def mean(self):
         _mean = self.probs * self.buckets
-        return symexp(torch.sum(_mean, dim=-1, keepdim=True))
+        return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))

     def mode(self):
         _mode = self.probs * self.buckets
-        return symexp(torch.sum(_mode, dim=-1, keepdim=True))
+        return self.transbwd(torch.sum(_mode, dim=-1, keepdim=True))

     # Inside OneHotCategorical, log_prob is calculated using only max element in targets
     def log_prob(self, x):
-        x = symlog(x)
+        x = self.transfwd(x)
         # x(time, batch, 1)
         below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
         above = len(self.buckets) - torch.sum(
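`DiscDist` keeps the two-hot discrete regression of the old `TwoHotDistSymlog` but makes the squashing transform pluggable; the defaults `transfwd=symlog` and `transbwd=symexp` refer to helpers already defined in tools.py. A standalone sketch of that transform pair (the standard symlog/symexp from DreamerV3, written out here only for illustration):

```python
import torch

def symlog(x):
    # Compresses large magnitudes while staying roughly identity near zero.
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    # Inverse of symlog.
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(symlog(x))          # approx. [-4.62, -0.69, 0.00, 0.69, 4.62]
print(symexp(symlog(x)))  # recovers x up to floating-point error
```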
@@ -366,15 +376,35 @@ class TwoHotDistSymlog:
         return (target * log_pred).sum(-1)


+class MSEDist:
+    def __init__(self, mode, agg="sum"):
+        self._mode = mode
+        self._agg = agg
+
+    def mode(self):
+        return self._mode
+
+    def mean(self):
+        return self._mode
+
+    def log_prob(self, value):
+        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
+        distance = (self._mode - value) ** 2
+        if self._agg == "mean":
+            loss = distance.mean(list(range(len(distance.shape)))[2:])
+        elif self._agg == "sum":
+            loss = distance.sum(list(range(len(distance.shape)))[2:])
+        else:
+            raise NotImplementedError(self._agg)
+        return -loss
+
+
 class SymlogDist:
-    def __init__(
-        self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
-    ):
+    def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
         self._mode = mode
         self._dist = dist
         self._agg = agg
         self._tol = tol
-        self._dim_to_reduce = dim_to_reduce

     def mode(self):
         return symexp(self._mode)
@@ -393,9 +423,9 @@ class SymlogDist:
         else:
             raise NotImplementedError(self._dist)
         if self._agg == "mean":
-            loss = distance.mean(self._dim_to_reduce)
+            loss = distance.mean(list(range(len(distance.shape)))[2:])
         elif self._agg == "sum":
-            loss = distance.sum(self._dim_to_reduce)
+            loss = distance.sum(list(range(len(distance.shape)))[2:])
         else:
             raise NotImplementedError(self._agg)
         return -loss
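`MSEDist` and the updated `SymlogDist` both reduce the squared error over every axis except the first two, replacing the old hard-coded `dim_to_reduce=[-1, -2, -3]`; that is what lets the same loss code handle image targets `(batch, time, H, W, C)` and vector targets `(batch, time, D)` alike. A small sketch of that reduction (standalone, with made-up shapes):

```python
import torch

def mse_log_prob(mode, value):
    # Reduce over every axis except (batch, time): [2, 3, 4] for images, [2] for vectors.
    distance = (mode - value) ** 2
    dims = list(range(len(distance.shape)))[2:]
    return -distance.sum(dims)

img = torch.zeros(8, 10, 64, 64, 3)
vec = torch.zeros(8, 10, 9)
print(mse_log_prob(img, img).shape, mse_log_prob(vec, vec).shape)  # both (8, 10)
```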