replaced all tf functions with torch

NM512 2023-04-03 08:06:34 +09:00
parent 8bd69bfcd4
commit 57ac1c11d3
3 changed files with 56 additions and 45 deletions

View File

@@ -62,7 +62,7 @@ class Dreamer(nn.Module):
             greedy=lambda: self._task_behavior,
             random=lambda: expl.Random(config),
             plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
-        )[config.expl_behavior]()
+        )[config.expl_behavior]().to(self._config.device)

     def __call__(self, obs, reset, state=None, reward=None, training=True):
         step = self._step
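Note: calling `.to(self._config.device)` on the selected behavior works because `Plan2Explore` is now an `nn.Module` whose ensemble is held in an `nn.ModuleList` (see the next file), so every parameter is registered on the parent module and moves with it. A minimal sketch of that behavior, using made-up module names rather than the repo's classes:

import torch
from torch import nn

class Ensemble(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        # each head is registered as a submodule, unlike a plain Python list
        self._networks = nn.ModuleList([nn.Linear(8, 4) for _ in range(n)])

ens = Ensemble().to("cuda" if torch.cuda.is_available() else "cpu")
print(next(ens.parameters()).device)  # all heads now live on the chosen device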

View File

@@ -26,83 +26,94 @@ class Random(nn.Module):
 # class Plan2Explore(tools.Module):
 class Plan2Explore(nn.Module):
     def __init__(self, config, world_model, reward=None):
+        super(Plan2Explore, self).__init__()
         self._config = config
+        self._use_amp = True if config.precision == 16 else False
         self._reward = reward
         self._behavior = models.ImagBehavior(config, world_model)
         self.actor = self._behavior.actor
-        stoch_size = config.dyn_stoch
         if config.dyn_discrete:
-            stoch_size *= config.dyn_discrete
+            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
+            stoch = config.dyn_stoch * config.dyn_discrete
+        else:
+            feat_size = config.dyn_stoch + config.dyn_deter
+            stoch = config.dyn_stoch
         size = {
             "embed": 32 * config.cnn_depth,
-            "stoch": stoch_size,
+            "stoch": stoch,
             "deter": config.dyn_deter,
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=config.dyn_stoch,  # pytorch version
+            inp_dim=feat_size,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
             act=config.act,
         )
-        self._networks = [networks.DenseHead(**kw) for _ in range(config.disag_models)]
-        self._opt = tools.optimizer(
-            config.opt,
+        self._networks = nn.ModuleList(
+            [networks.DenseHead(**kw) for _ in range(config.disag_models)]
+        )
+        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
+        self._model_opt = tools.Optimizer(
+            "explorer",
             self.parameters(),
             config.model_lr,
             config.opt_eps,
-            config.weight_decay,
+            config.grad_clip,
+            **kw
         )
-        # self._opt = tools.Optimizer(
-        #     'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
-        #     config.weight_decay, opt=config.opt)

     def train(self, start, context, data):
-        metrics = {}
-        stoch = start["stoch"]
-        if self._config.dyn_discrete:
-            stoch = tf.reshape(
-                stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])
-            )
-        target = {
-            "embed": context["embed"],
-            "stoch": stoch,
-            "deter": start["deter"],
-            "feat": context["feat"],
-        }[self._config.disag_target]
-        inputs = context["feat"]
-        if self._config.disag_action_cond:
-            inputs = tf.concat([inputs, data["action"]], -1)
-        metrics.update(self._train_ensemble(inputs, target))
-        metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
+        with tools.RequiresGrad(self):
+            metrics = {}
+            stoch = start["stoch"]
+            if self._config.dyn_discrete:
+                stoch = torch.reshape(
+                    stoch, (stoch.shape[:-2] + ((stoch.shape[-2] * stoch.shape[-1]),))
+                )
+            target = {
+                "embed": context["embed"],
+                "stoch": stoch,
+                "deter": start["deter"],
+                "feat": context["feat"],
+            }[self._config.disag_target]
+            inputs = context["feat"]
+            if self._config.disag_action_cond:
+                inputs = torch.concat([inputs, data["action"]], -1)
+            metrics.update(self._train_ensemble(inputs, target))
+            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics

     def _intrinsic_reward(self, feat, state, action):
         inputs = feat
         if self._config.disag_action_cond:
-            inputs = tf.concat([inputs, action], -1)
-        preds = [head(inputs, tf.float32).mean() for head in self._networks]
-        disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
+            inputs = torch.concat([inputs, action], -1)
+        preds = torch.cat(
+            [head(inputs, torch.float32).mode()[None] for head in self._networks], 0
+        )
+        disag = torch.mean(torch.std(preds, 0), -1)[..., None]
         if self._config.disag_log:
-            disag = tf.math.log(disag)
+            disag = torch.log(disag)
         reward = self._config.expl_intr_scale * disag
         if self._config.expl_extr_scale:
-            reward += tf.cast(
+            reward += torch.cast(
                 self._config.expl_extr_scale * self._reward(feat, state, action),
-                tf.float32,
+                torch.float32,
             )
         return reward

     def _train_ensemble(self, inputs, targets):
-        if self._config.disag_offset:
-            targets = targets[:, self._config.disag_offset :]
-            inputs = inputs[:, : -self._config.disag_offset]
-        targets = tf.stop_gradient(targets)
-        inputs = tf.stop_gradient(inputs)
-        with tf.GradientTape() as tape:
+        with torch.cuda.amp.autocast(self._use_amp):
+            if self._config.disag_offset:
+                targets = targets[:, self._config.disag_offset :]
+                inputs = inputs[:, : -self._config.disag_offset]
+            targets = targets.detach()
+            inputs = inputs.detach()
             preds = [head(inputs) for head in self._networks]
-            likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
-            loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
-        metrics = self._opt(tape, loss, self._networks)
+            likes = torch.cat(
+                [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
+            )
+            loss = -torch.mean(likes)
+        metrics = self._model_opt(loss, self.parameters())
         return metrics
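Note on the ported `_intrinsic_reward`: the exploration bonus is the disagreement (standard deviation across the ensemble heads' predictions), optionally log-transformed. Also, `torch.cast` does not exist in PyTorch; the usual translation of `tf.cast(x, tf.float32)` is `x.to(torch.float32)`, so the `expl_extr_scale` branch as committed would raise an AttributeError if that flag were enabled. A rough sketch of the computation on plain tensors (shapes and scales are illustrative, not taken from the config):

import torch

# ensemble x batch x time x target: one prediction per ensemble head
preds = torch.stack([torch.randn(16, 50, 32) for _ in range(10)], 0)
disag = torch.mean(torch.std(preds, 0), -1)[..., None]  # spread across heads -> (batch, time, 1)
reward = 0.1 * torch.log(disag)                          # expl_intr_scale * log-disagreement
extrinsic = torch.randn(16, 50, 1)
# tf.cast(x, tf.float32) -> x.to(torch.float32) in PyTorch
reward = reward + 0.5 * extrinsic.to(torch.float32)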

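In `_train_ensemble`, the `tf.GradientTape` block is replaced by the repo's `tools.Optimizer` plus `torch.cuda.amp.autocast`, with `tools.RequiresGrad` presumably toggling `requires_grad` on the module while it trains. A generic sketch of that mixed-precision training step using stock PyTorch pieces (a `GradScaler` standing in for whatever `tools.Optimizer` handles internally; names are illustrative):

import torch
from torch import nn

use_amp = torch.cuda.is_available()        # analogue of self._use_amp, gated on CUDA here
device = "cuda" if use_amp else "cpu"
net = nn.Linear(32, 32).to(device)
opt = torch.optim.Adam(net.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x, y = torch.randn(8, 32, device=device), torch.randn(8, 32, device=device)
with torch.cuda.amp.autocast(enabled=use_amp):   # forward pass and loss in mixed precision
    loss = (net(x) - y).pow(2).mean()
opt.zero_grad()
scaler.scale(loss).backward()   # loss scaling avoids fp16 gradient underflow
scaler.step(opt)                # unscales gradients, then applies the update
scaler.update()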
View File

@@ -548,7 +548,7 @@ class Optimizer():
                 lr=lr,
                 eps=eps),
             'nadam': lambda: NotImplemented(
-                f'{config.opt} is not implemented'),
+                f'{opt} is not implemented'),
             'adamax': lambda: torch.optim.Adamax(parameters,
                 lr=lr,
                 eps=eps),
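The one-character fix matters because no `config` object is in scope inside `tools.Optimizer`; the f-string has to reference the local `opt` argument. A small sketch of the same name-to-constructor dispatch (function name and signature are made up, and it raises `NotImplementedError` instead of calling `NotImplemented`, which is not callable):

import torch

def make_optimizer(opt, parameters, lr=1e-4, eps=1e-8):
    table = {
        "adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
        "adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
    }
    if opt not in table:
        # referencing the local `opt` keeps the message correct without a config in scope
        raise NotImplementedError(f"{opt} is not implemented")
    return table[opt]()  # lambdas defer construction, so only the selected optimizer is built

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = make_optimizer("adam", params)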