Merge pull request #36 from WeberSamuel/main

Fix Plan2Explore
NM512 2023-08-22 22:08:28 +09:00 committed by GitHub
commit 1442ee82ed


@@ -58,9 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions
-            if config.disag_action_cond
-            else 0,  # pytorch version
+            inp_dim=feat_size + (config.num_actions
+            if config.disag_action_cond
+            else 0),  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
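
A note on the parenthesization above: Python's conditional expression binds more loosely than +, so the old expression evaluates as (feat_size + config.num_actions) if config.disag_action_cond else 0, which collapses the ensemble input width to 0 whenever disag_action_cond is False. A minimal sketch of the difference, using made-up sizes rather than the real config values:

# Minimal sketch of the precedence issue (illustrative numbers only).
feat_size, num_actions, disag_action_cond = 1536, 6, False

# Old expression: the whole "a + b if cond else 0" is one conditional expression,
# so the ensemble input size drops to 0 when disag_action_cond is False.
old_inp_dim = feat_size + num_actions if disag_action_cond else 0    # -> 0

# Fixed expression: only the action dimension is conditional.
new_inp_dim = feat_size + (num_actions if disag_action_cond else 0)  # -> 1536

assert old_inp_dim == 0 and new_inp_dim == 1536
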
@@ -72,7 +72,7 @@ class Plan2Explore(nn.Module):
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(
             "explorer",
-            self.parameters(),
+            self._networks.parameters(),
             config.model_lr,
             config.opt_eps,
             config.grad_clip,
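
Why restrict the optimizer to self._networks.parameters(): Plan2Explore is itself an nn.Module and also holds the behavior used further down (self._behavior), so self.parameters() can hand the exploration-ensemble optimizer those weights as well. A hedged toy sketch of the distinction (hypothetical module names, not the repository's classes):

import torch
from torch import nn

# Toy stand-in for a module that owns both an ensemble and a separate behavior
# network, mirroring the structure suggested by the diff above.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self._networks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])  # ensemble
        self._behavior = nn.Linear(4, 2)  # stands in for the behavior's networks

toy = Toy()
# self.parameters() also yields the behavior's weights ...
assert len(list(toy.parameters())) > len(list(toy._networks.parameters()))
# ... so an optimizer built from it would also step parameters that are meant to be
# trained elsewhere; restricting it to the ensemble avoids that.
ensemble_opt = torch.optim.Adam(toy._networks.parameters(), lr=3e-4)
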
@@ -80,7 +80,7 @@ class Plan2Explore(nn.Module):
         )

     def train(self, start, context, data):
-        with tools.RequiresGrad(self):
+        with tools.RequiresGrad(self._networks):
             metrics = {}
             stoch = start["stoch"]
             if self._config.dyn_discrete:
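
The same scoping applies to the gradient context: passing the whole module would toggle requires_grad on every registered parameter while the ensemble is trained, whereas self._networks limits the toggle to the disagreement heads. A rough sketch of what a RequiresGrad-style context manager typically does (illustrative only; see the repository's tools.py for the actual helper):

from torch import nn

class RequiresGradSketch:
    # Enables gradients for one module's parameters on entry, disables them on exit.
    def __init__(self, module: nn.Module):
        self._module = module

    def __enter__(self):
        self._module.requires_grad_(True)

    def __exit__(self, *exc):
        self._module.requires_grad_(False)

# Passing the full Plan2Explore module would also flip gradients for the behavior's
# networks during the ensemble update; passing self._networks keeps them untouched.
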
@@ -99,7 +99,7 @@ class Plan2Explore(nn.Module):
                     [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
                 )
             metrics.update(self._train_ensemble(inputs, target))
-            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
+        metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics

     def _intrinsic_reward(self, feat, state, action):
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
             [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
         )
         loss = -torch.mean(likes)
-        metrics = self._model_opt(loss, self.parameters())
+        metrics = self._model_opt(loss, self._networks.parameters())
         return metrics
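
For context on that last hunk: the ensemble heads are trained by maximizing the mean log-likelihood of their prediction targets (loss = -mean log prob), and the optimizer step is now restricted to the ensemble parameters. A self-contained toy version of that update, with made-up Gaussian heads and sizes rather than the repository's networks.MLP:

import torch
from torch import nn
from torch.distributions import Independent, Normal

# Toy disagreement ensemble: each head predicts a Gaussian over the target features.
heads = nn.ModuleList([nn.Linear(8, 5) for _ in range(4)])
opt = torch.optim.Adam(heads.parameters(), lr=1e-3)  # ensemble parameters only

inputs, targets = torch.randn(16, 8), torch.randn(16, 5)
preds = [Independent(Normal(h(inputs), 1.0), 1) for h in heads]
likes = torch.cat([torch.mean(p.log_prob(targets))[None] for p in preds], 0)
loss = -torch.mean(likes)  # maximize the mean log-likelihood across heads

opt.zero_grad()
loss.backward()
opt.step()
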