Fixes for Plan2Explore with actions and for windows.
This commit is contained in:
parent
308eeff09a
commit
b07badeee6
@ -55,10 +55,10 @@ class Dreamer(nn.Module):
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
if config.compile:
|
||||
if config.compile and os.name != 'nt': # compilation is not supported on windows
|
||||
self._wm = torch.compile(self._wm)
|
||||
self._task_behavior = torch.compile(self._task_behavior)
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
|
||||
self._expl_behavior = dict(
|
||||
greedy=lambda: self._task_behavior,
|
||||
random=lambda: expl.Random(config, act_space),
|
||||
|
@ -58,7 +58,7 @@ class Plan2Explore(nn.Module):
|
||||
"feat": config.dyn_stoch + config.dyn_deter,
|
||||
}[self._config.disag_target]
|
||||
kw = dict(
|
||||
inp_dim=feat_size, # pytorch version
|
||||
inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0, # pytorch version
|
||||
shape=size,
|
||||
layers=config.disag_layers,
|
||||
units=config.disag_units,
|
||||
@ -93,7 +93,7 @@ class Plan2Explore(nn.Module):
|
||||
}[self._config.disag_target]
|
||||
inputs = context["feat"]
|
||||
if self._config.disag_action_cond:
|
||||
inputs = torch.concat([inputs, data["action"]], -1)
|
||||
inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
|
||||
metrics.update(self._train_ensemble(inputs, target))
|
||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||
return None, metrics
|
||||
@ -110,10 +110,7 @@ class Plan2Explore(nn.Module):
|
||||
disag = torch.log(disag)
|
||||
reward = self._config.expl_intr_scale * disag
|
||||
if self._config.expl_extr_scale:
|
||||
reward += torch.cast(
|
||||
self._config.expl_extr_scale * self._reward(feat, state, action),
|
||||
torch.float32,
|
||||
)
|
||||
reward += self._config.expl_extr_scale * self._reward(feat, state, action)
|
||||
return reward
|
||||
|
||||
def _train_ensemble(self, inputs, targets):
|
||||
|
Loading…
x
Reference in New Issue
Block a user