commit
6c861ca7cb
@ -55,10 +55,10 @@ class Dreamer(nn.Module):
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
if config.compile:
|
||||
if config.compile and os.name != 'nt':  # skip torch.compile on Windows (os.name == 'nt'), where it is not supported
|
||||
self._wm = torch.compile(self._wm)
|
||||
self._task_behavior = torch.compile(self._task_behavior)
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
|
||||
self._expl_behavior = dict(
|
||||
greedy=lambda: self._task_behavior,
|
||||
random=lambda: expl.Random(config, act_space),
|
||||
|
@ -58,7 +58,7 @@ class Plan2Explore(nn.Module):
|
||||
"feat": config.dyn_stoch + config.dyn_deter,
|
||||
}[self._config.disag_target]
|
||||
kw = dict(
|
||||
inp_dim=feat_size, # pytorch version
|
||||
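inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0, # pytorch version; NOTE(review): parses as (feat_size + config.num_actions) if ... else 0, so inp_dim is 0 without action conditioning — confirm intended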
|
||||
shape=size,
|
||||
layers=config.disag_layers,
|
||||
units=config.disag_units,
|
||||
@ -93,7 +93,7 @@ class Plan2Explore(nn.Module):
|
||||
}[self._config.disag_target]
|
||||
inputs = context["feat"]
|
||||
if self._config.disag_action_cond:
|
||||
inputs = torch.concat([inputs, data["action"]], -1)
|
||||
inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
|
||||
metrics.update(self._train_ensemble(inputs, target))
|
||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||
return None, metrics
|
||||
@ -110,10 +110,7 @@ class Plan2Explore(nn.Module):
|
||||
disag = torch.log(disag)
|
||||
reward = self._config.expl_intr_scale * disag
|
||||
if self._config.expl_extr_scale:
|
||||
reward += torch.cast(
|
||||
self._config.expl_extr_scale * self._reward(feat, state, action),
|
||||
torch.float32,
|
||||
)
|
||||
reward += self._config.expl_extr_scale * self._reward(feat, state, action)
|
||||
return reward
|
||||
|
||||
def _train_ensemble(self, inputs, targets):
|
||||
|
@ -215,7 +215,7 @@ class RSSM(nn.Module):
|
||||
is_first,
|
||||
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
||||
)
|
||||
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
|
Loading…
x
Reference in New Issue
Block a user