replaced all tf function to torch
This commit is contained in:
parent
8bd69bfcd4
commit
57ac1c11d3
@ -62,7 +62,7 @@ class Dreamer(nn.Module):
|
|||||||
greedy=lambda: self._task_behavior,
|
greedy=lambda: self._task_behavior,
|
||||||
random=lambda: expl.Random(config),
|
random=lambda: expl.Random(config),
|
||||||
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
|
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
|
||||||
)[config.expl_behavior]()
|
)[config.expl_behavior]().to(self._config.device)
|
||||||
|
|
||||||
def __call__(self, obs, reset, state=None, reward=None, training=True):
|
def __call__(self, obs, reset, state=None, reward=None, training=True):
|
||||||
step = self._step
|
step = self._step
|
||||||
|
@ -26,83 +26,94 @@ class Random(nn.Module):
|
|||||||
# class Plan2Explore(tools.Module):
|
# class Plan2Explore(tools.Module):
|
||||||
class Plan2Explore(nn.Module):
|
class Plan2Explore(nn.Module):
|
||||||
def __init__(self, config, world_model, reward=None):
|
def __init__(self, config, world_model, reward=None):
|
||||||
|
super(Plan2Explore, self).__init__()
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._use_amp = True if config.precision == 16 else False
|
||||||
self._reward = reward
|
self._reward = reward
|
||||||
self._behavior = models.ImagBehavior(config, world_model)
|
self._behavior = models.ImagBehavior(config, world_model)
|
||||||
self.actor = self._behavior.actor
|
self.actor = self._behavior.actor
|
||||||
stoch_size = config.dyn_stoch
|
|
||||||
if config.dyn_discrete:
|
if config.dyn_discrete:
|
||||||
stoch_size *= config.dyn_discrete
|
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||||
|
stoch = config.dyn_stoch * config.dyn_discrete
|
||||||
|
else:
|
||||||
|
feat_size = config.dyn_stoch + config.dyn_deter
|
||||||
|
stoch = config.dyn_stoch
|
||||||
size = {
|
size = {
|
||||||
"embed": 32 * config.cnn_depth,
|
"embed": 32 * config.cnn_depth,
|
||||||
"stoch": stoch_size,
|
"stoch": stoch,
|
||||||
"deter": config.dyn_deter,
|
"deter": config.dyn_deter,
|
||||||
"feat": config.dyn_stoch + config.dyn_deter,
|
"feat": config.dyn_stoch + config.dyn_deter,
|
||||||
}[self._config.disag_target]
|
}[self._config.disag_target]
|
||||||
kw = dict(
|
kw = dict(
|
||||||
inp_dim=config.dyn_stoch, # pytorch version
|
inp_dim=feat_size, # pytorch version
|
||||||
shape=size,
|
shape=size,
|
||||||
layers=config.disag_layers,
|
layers=config.disag_layers,
|
||||||
units=config.disag_units,
|
units=config.disag_units,
|
||||||
act=config.act,
|
act=config.act,
|
||||||
)
|
)
|
||||||
self._networks = [networks.DenseHead(**kw) for _ in range(config.disag_models)]
|
self._networks = nn.ModuleList(
|
||||||
self._opt = tools.optimizer(
|
[networks.DenseHead(**kw) for _ in range(config.disag_models)]
|
||||||
config.opt,
|
)
|
||||||
|
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
|
||||||
|
self._model_opt = tools.Optimizer(
|
||||||
|
"explorer",
|
||||||
self.parameters(),
|
self.parameters(),
|
||||||
config.model_lr,
|
config.model_lr,
|
||||||
config.opt_eps,
|
config.opt_eps,
|
||||||
config.weight_decay,
|
config.grad_clip,
|
||||||
|
**kw
|
||||||
)
|
)
|
||||||
# self._opt = tools.Optimizer(
|
|
||||||
# 'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
|
|
||||||
# config.weight_decay, opt=config.opt)
|
|
||||||
|
|
||||||
def train(self, start, context, data):
|
def train(self, start, context, data):
|
||||||
metrics = {}
|
with tools.RequiresGrad(self):
|
||||||
stoch = start["stoch"]
|
metrics = {}
|
||||||
if self._config.dyn_discrete:
|
stoch = start["stoch"]
|
||||||
stoch = tf.reshape(
|
if self._config.dyn_discrete:
|
||||||
stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])
|
stoch = torch.reshape(
|
||||||
)
|
stoch, (stoch.shape[:-2] + ((stoch.shape[-2] * stoch.shape[-1]),))
|
||||||
target = {
|
)
|
||||||
"embed": context["embed"],
|
target = {
|
||||||
"stoch": stoch,
|
"embed": context["embed"],
|
||||||
"deter": start["deter"],
|
"stoch": stoch,
|
||||||
"feat": context["feat"],
|
"deter": start["deter"],
|
||||||
}[self._config.disag_target]
|
"feat": context["feat"],
|
||||||
inputs = context["feat"]
|
}[self._config.disag_target]
|
||||||
if self._config.disag_action_cond:
|
inputs = context["feat"]
|
||||||
inputs = tf.concat([inputs, data["action"]], -1)
|
if self._config.disag_action_cond:
|
||||||
metrics.update(self._train_ensemble(inputs, target))
|
inputs = torch.concat([inputs, data["action"]], -1)
|
||||||
metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
|
metrics.update(self._train_ensemble(inputs, target))
|
||||||
|
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||||
return None, metrics
|
return None, metrics
|
||||||
|
|
||||||
def _intrinsic_reward(self, feat, state, action):
|
def _intrinsic_reward(self, feat, state, action):
|
||||||
inputs = feat
|
inputs = feat
|
||||||
if self._config.disag_action_cond:
|
if self._config.disag_action_cond:
|
||||||
inputs = tf.concat([inputs, action], -1)
|
inputs = torch.concat([inputs, action], -1)
|
||||||
preds = [head(inputs, tf.float32).mean() for head in self._networks]
|
preds = torch.cat(
|
||||||
disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
|
[head(inputs, torch.float32).mode()[None] for head in self._networks], 0
|
||||||
|
)
|
||||||
|
disag = torch.mean(torch.std(preds, 0), -1)[..., None]
|
||||||
if self._config.disag_log:
|
if self._config.disag_log:
|
||||||
disag = tf.math.log(disag)
|
disag = torch.log(disag)
|
||||||
reward = self._config.expl_intr_scale * disag
|
reward = self._config.expl_intr_scale * disag
|
||||||
if self._config.expl_extr_scale:
|
if self._config.expl_extr_scale:
|
||||||
reward += tf.cast(
|
reward += torch.cast(
|
||||||
self._config.expl_extr_scale * self._reward(feat, state, action),
|
self._config.expl_extr_scale * self._reward(feat, state, action),
|
||||||
tf.float32,
|
torch.float32,
|
||||||
)
|
)
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
def _train_ensemble(self, inputs, targets):
|
def _train_ensemble(self, inputs, targets):
|
||||||
if self._config.disag_offset:
|
with torch.cuda.amp.autocast(self._use_amp):
|
||||||
targets = targets[:, self._config.disag_offset :]
|
if self._config.disag_offset:
|
||||||
inputs = inputs[:, : -self._config.disag_offset]
|
targets = targets[:, self._config.disag_offset :]
|
||||||
targets = tf.stop_gradient(targets)
|
inputs = inputs[:, : -self._config.disag_offset]
|
||||||
inputs = tf.stop_gradient(inputs)
|
targets = targets.detach()
|
||||||
with tf.GradientTape() as tape:
|
inputs = inputs.detach()
|
||||||
preds = [head(inputs) for head in self._networks]
|
preds = [head(inputs) for head in self._networks]
|
||||||
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
|
likes = torch.cat(
|
||||||
loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
|
[torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
|
||||||
metrics = self._opt(tape, loss, self._networks)
|
)
|
||||||
|
loss = -torch.mean(likes)
|
||||||
|
metrics = self._model_opt(loss, self.parameters())
|
||||||
return metrics
|
return metrics
|
||||||
|
2
tools.py
2
tools.py
@ -548,7 +548,7 @@ class Optimizer():
|
|||||||
lr=lr,
|
lr=lr,
|
||||||
eps=eps),
|
eps=eps),
|
||||||
'nadam': lambda: NotImplemented(
|
'nadam': lambda: NotImplemented(
|
||||||
f'{config.opt} is not implemented'),
|
f'{opt} is not implemented'),
|
||||||
'adamax': lambda: torch.optim.Adamax(parameters,
|
'adamax': lambda: torch.optim.Adamax(parameters,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
eps=eps),
|
eps=eps),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user