commit 1442ee82ed
@@ -58,9 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions
+            inp_dim=feat_size + (config.num_actions
             if config.disag_action_cond
-            else 0,  # pytorch version
+            else 0),  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
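Note: the added parentheses matter because Python's conditional expression binds more loosely than `+`. A minimal sketch with made-up numbers (not the real config object) showing the difference:

    feat_size, num_actions, disag_action_cond = 32, 6, False

    # Without parentheses the ternary wraps the whole sum, so the ensemble
    # input width collapses to 0 whenever disag_action_cond is False:
    old_inp_dim = feat_size + num_actions if disag_action_cond else 0    # -> 0

    # With parentheses only the action dimension is conditional:
    new_inp_dim = feat_size + (num_actions if disag_action_cond else 0)  # -> 32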
@@ -72,7 +72,7 @@ class Plan2Explore(nn.Module):
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(
             "explorer",
-            self.parameters(),
+            self._networks.parameters(),
             config.model_lr,
             config.opt_eps,
             config.grad_clip,
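Note: passing `self._networks.parameters()` narrows the explorer optimizer to the disagreement ensemble, whereas `self.parameters()` on the Plan2Explore module also sweeps in every submodule it holds (for example the imagined behavior). A toy illustration of that difference, using stand-in modules rather than the real classes:

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self._networks = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))
            self._behavior = nn.Linear(4, 2)  # stand-in for the actor-critic

    toy = Toy()
    print(len(list(toy.parameters())))            # 8 tensors: ensemble + behavior
    print(len(list(toy._networks.parameters())))  # 6 tensors: ensemble only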
@@ -80,7 +80,7 @@ class Plan2Explore(nn.Module):
         )

     def train(self, start, context, data):
-        with tools.RequiresGrad(self):
+        with tools.RequiresGrad(self._networks):
             metrics = {}
             stoch = start["stoch"]
             if self._config.dyn_discrete:
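Note: `tools.RequiresGrad` is the repo's context manager for toggling `requires_grad` around a training step, so scoping it to `self._networks` mirrors the optimizer change above. As an assumption about its implementation (not a quote from tools.py), it behaves roughly like:

    class RequiresGrad:
        def __init__(self, model):
            self._model = model

        def __enter__(self):
            # enable gradients only for the wrapped module
            self._model.requires_grad_(requires_grad=True)

        def __exit__(self, *exc):
            self._model.requires_grad_(requires_grad=False)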
@@ -99,7 +99,7 @@ class Plan2Explore(nn.Module):
                     [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
                 )
             metrics.update(self._train_ensemble(inputs, target))
             metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics

     def _intrinsic_reward(self, feat, state, action):
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
                 [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
             )
             loss = -torch.mean(likes)
-        metrics = self._model_opt(loss, self.parameters())
+        metrics = self._model_opt(loss, self._networks.parameters())
         return metrics