dreamerv3-torch/exploration.py

import torch
from torch import nn
from torch import distributions as torchd

import models
import networks
import tools
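

# Baseline exploration policy: ignores the latent features and returns a
# uniform-random action distribution (one-hot categorical for discrete action
# spaces, uniform over the box bounds otherwise).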
class Random(nn.Module):
    def __init__(self, config, act_space):
        super(Random, self).__init__()
        self._config = config
        self._act_space = act_space

    def actor(self, feat):
        if self._config.actor["dist"] == "onehot":
            return tools.OneHotDist(
                torch.zeros(self._config.num_actions)
                .repeat(self._config.envs, 1)
                .to(self._config.device)
            )
        else:
            return torchd.independent.Independent(
                torchd.uniform.Uniform(
                    torch.Tensor(self._act_space.low)
                    .repeat(self._config.envs, 1)
                    .to(self._config.device),
                    torch.Tensor(self._act_space.high)
                    .repeat(self._config.envs, 1)
                    .to(self._config.device),
                ),
                1,
            )

    def train(self, start, context, data):
        return None, {}
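

# Plan2Explore-style exploration: an ensemble of MLP heads predicts a target
# taken from the world model (embedding, stochastic or deterministic state, or
# features); the disagreement across heads provides an intrinsic reward that
# the imagination behavior is trained to maximize.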
class Plan2Explore(nn.Module):
    def __init__(self, config, world_model, reward):
        super(Plan2Explore, self).__init__()
        self._config = config
        self._use_amp = True if config.precision == 16 else False
        self._reward = reward
        self._behavior = models.ImagBehavior(config, world_model)
        self.actor = self._behavior.actor
        if config.dyn_discrete:
            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
            stoch = config.dyn_stoch * config.dyn_discrete
        else:
            feat_size = config.dyn_stoch + config.dyn_deter
            stoch = config.dyn_stoch
        size = {
            "embed": world_model.embed_size,
            "stoch": stoch,
            "deter": config.dyn_deter,
            "feat": config.dyn_stoch + config.dyn_deter,
        }[self._config.disag_target]
        kw = dict(
            inp_dim=feat_size
            + (
                config.num_actions if config.disag_action_cond else 0
            ),  # pytorch version
            shape=size,
            layers=config.disag_layers,
            units=config.disag_units,
            act=config.act,
        )
        self._networks = nn.ModuleList(
            [networks.MLP(**kw) for _ in range(config.disag_models)]
        )
        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
        self._expl_opt = tools.Optimizer(
            "explorer",
            self._networks.parameters(),
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            **kw
        )
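
    # One training step: fit the disagreement ensemble on world-model outputs,
    # then train the imagination behavior against the intrinsic reward.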
    def train(self, start, context, data):
        with tools.RequiresGrad(self._networks):
            metrics = {}
            stoch = start["stoch"]
            if self._config.dyn_discrete:
                stoch = torch.reshape(
                    stoch, (stoch.shape[:-2] + ((stoch.shape[-2] * stoch.shape[-1]),))
                )
            target = {
                "embed": context["embed"],
                "stoch": stoch,
                "deter": start["deter"],
                "feat": context["feat"],
            }[self._config.disag_target]
            inputs = context["feat"]
            if self._config.disag_action_cond:
                inputs = torch.concat(
                    [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
                )
            metrics.update(self._train_ensemble(inputs, target))
            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
        return None, metrics
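
    # Intrinsic reward = scaled ensemble disagreement (std of the heads' modes),
    # optionally log-transformed and mixed with the extrinsic task reward.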
    def _intrinsic_reward(self, feat, state, action):
        inputs = feat
        if self._config.disag_action_cond:
            inputs = torch.concat([inputs, action], -1)
        preds = torch.cat(
            [head(inputs, torch.float32).mode()[None] for head in self._networks], 0
        )
        disag = torch.mean(torch.std(preds, 0), -1)[..., None]
        if self._config.disag_log:
            disag = torch.log(disag)
        reward = self._config.expl_intr_scale * disag
        if self._config.expl_extr_scale:
            reward += self._config.expl_extr_scale * self._reward(feat, state, action)
        return reward
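
    # Ensemble training: each head maximizes the log-likelihood of its target,
    # optionally predicting `disag_offset` steps ahead of the input.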
    def _train_ensemble(self, inputs, targets):
        with torch.cuda.amp.autocast(self._use_amp):
            if self._config.disag_offset:
                targets = targets[:, self._config.disag_offset :]
                inputs = inputs[:, : -self._config.disag_offset]
            targets = targets.detach()
            inputs = inputs.detach()
            preds = [head(inputs) for head in self._networks]
            likes = torch.cat(
                [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
            )
            loss = -torch.mean(likes)
        metrics = self._expl_opt(loss, self._networks.parameters())
        return metrics