Fixes for Plan2Explore with actions and for Windows.

ktolnos 2023-06-05 22:50:12 +03:00
parent 308eeff09a
commit b07badeee6
2 changed files with 5 additions and 8 deletions


@@ -55,10 +55,10 @@ class Dreamer(nn.Module):
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
-        if config.compile:
+        if config.compile and os.name != 'nt':  # compilation is not supported on windows
            self._wm = torch.compile(self._wm)
            self._task_behavior = torch.compile(self._task_behavior)
-        reward = lambda f, s, a: self._wm.heads["reward"](f).mean
+        reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
         self._expl_behavior = dict(
             greedy=lambda: self._task_behavior,
             random=lambda: expl.Random(config, act_space),
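
The first change guards torch.compile behind an OS check: at the time of this commit torch.compile was not supported on Windows, which os.name == 'nt' identifies, so compilation is simply skipped there. The second change calls mean() on the reward head's output; in this port the head appears to return a distribution-like object whose mean is a method, so without the parentheses the lambda returned the bound method rather than a tensor. A minimal sketch of the guard pattern, with maybe_compile as a hypothetical helper and config.compile taken from the diff:

    import os
    import torch

    def maybe_compile(module, want_compile):
        # torch.compile is only attempted on non-Windows platforms
        # (os.name == 'nt' means Windows).
        if want_compile and os.name != 'nt':
            return torch.compile(module)
        return module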


@@ -58,7 +58,7 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size,  # pytorch version
+            inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
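
When disag_action_cond is enabled, the disagreement ensemble is fed features concatenated with actions (see the next hunk), so its input width has to grow by config.num_actions; that is what the inp_dim change is for. One caveat: Python's conditional expression binds more loosely than +, so the unparenthesized form evaluates to 0 whenever disag_action_cond is disabled, which is unlikely to be the intended width. A standalone sketch with hypothetical sizes:

    # Hypothetical sizes; only the precedence behaviour is the point here.
    feat_size, num_actions, disag_action_cond = 1536, 6, False

    unparenthesized = feat_size + num_actions if disag_action_cond else 0   # evaluates to 0
    parenthesized = feat_size + (num_actions if disag_action_cond else 0)   # evaluates to 1536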
@@ -93,7 +93,7 @@ class Plan2Explore(nn.Module):
         }[self._config.disag_target]
         inputs = context["feat"]
         if self._config.disag_action_cond:
-            inputs = torch.concat([inputs, data["action"]], -1)
+            inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
         metrics.update(self._train_ensemble(inputs, target))
         metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
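
Here the actions come from the replay batch, which presumably holds NumPy arrays, so they must be turned into a tensor on the configured device before being concatenated with the feature tensor. Note that torch.Tensor(...) always produces float32; a sketch of the same conversion using torch.as_tensor, which accepts NumPy arrays and tensors alike and avoids a copy when dtype and device already match (helper name and shapes are illustrative):

    import numpy as np
    import torch

    def actions_to_device(action, device):
        # as_tensor handles both NumPy arrays and existing tensors.
        return torch.as_tensor(action, dtype=torch.float32, device=device)

    feat = torch.zeros(16, 1536)                     # illustrative feature batch
    action = np.zeros((16, 6), dtype=np.float32)     # illustrative action batch
    inputs = torch.concat([feat, actions_to_device(action, feat.device)], -1)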
@@ -110,10 +110,7 @@
            disag = torch.log(disag)
         reward = self._config.expl_intr_scale * disag
         if self._config.expl_extr_scale:
-            reward += torch.cast(
-                self._config.expl_extr_scale * self._reward(feat, state, action),
-                torch.float32,
-            )
+            reward += self._config.expl_extr_scale * self._reward(feat, state, action)
         return reward

     def _train_ensemble(self, inputs, targets):
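
The removed lines called torch.cast, which does not exist in PyTorch (it is TensorFlow's tf.cast), so the extrinsic-reward branch would have failed as soon as expl_extr_scale was non-zero; the fix drops the cast entirely. If an explicit float32 cast were still wanted, the PyTorch idiom is Tensor.to(dtype), roughly as in the sketch below, where the tensors stand in for the intrinsic and extrinsic terms:

    import torch

    intr_reward = torch.zeros(16, 1)                         # stand-in for expl_intr_scale * disag
    extr_reward = torch.ones(16, 1, dtype=torch.float64)     # stand-in for self._reward(feat, state, action)

    # PyTorch casts with Tensor.to(dtype) rather than a torch.cast function.
    reward = intr_reward + (0.5 * extr_reward).to(torch.float32)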