modification of expl.
parent b8ef214efa
commit 02c3d45fcf
@@ -61,7 +61,7 @@ class Dreamer(nn.Module):
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean
         self._expl_behavior = dict(
             greedy=lambda: self._task_behavior,
-            random=lambda: expl.Random(config),
+            random=lambda: expl.Random(config, act_space),
             plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
         )[config.expl_behavior]().to(self._config.device)
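For context, the dict above is a lazy dispatch table: each candidate behavior sits behind a lambda, and only the entry named by config.expl_behavior is instantiated, so unselected (and potentially heavy) behaviors like Plan2Explore are never built. A minimal standalone sketch of the pattern, with placeholder values instead of the real classes:

    # Only the selected thunk runs; the others stay unevaluated.
    behaviors = dict(
        greedy=lambda: "task behavior",    # stand-in for self._task_behavior
        random=lambda: "random behavior",  # stand-in for expl.Random(...)
    )
    chosen = behaviors["random"]()  # -> "random behavior"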
@@ -8,22 +8,35 @@ import tools


 class Random(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, act_space):
         super(Random, self).__init__()
         self._config = config
+        self._act_space = act_space

     def actor(self, feat):
-        shape = feat.shape[:-1] + [self._config.num_actions]
         if self._config.actor_dist == "onehot":
-            return tools.OneHotDist(torch.zeros(shape))
+            return tools.OneHotDist(
+                torch.zeros(self._config.num_actions)
+                .repeat(self._config.envs, 1)
+                .to(self._config.device)
+            )
         else:
-            ones = torch.ones(shape)
-            return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
+            return torchd.independent.Independent(
+                torchd.uniform.Uniform(
+                    torch.Tensor(self._act_space.low)
+                    .repeat(self._config.envs, 1)
+                    .to(self._config.device),
+                    torch.Tensor(self._act_space.high)
+                    .repeat(self._config.envs, 1)
+                    .to(self._config.device),
+                ),
+                1,
+            )

-    def train(self, start, context):
+    def train(self, start, context, data):
         return None, {}


 # class Plan2Explore(tools.Module):
 class Plan2Explore(nn.Module):
     def __init__(self, config, world_model, reward=None):
         super(Plan2Explore, self).__init__()
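A quick sanity check of the new continuous branch, assuming a Box-style action space with per-dimension low/high bounds (shapes and values below are illustrative, not taken from the repo):

    import torch
    import torch.distributions as torchd

    envs, act_dim = 4, 3
    low = torch.full((act_dim,), -1.0)
    high = torch.full((act_dim,), 1.0)
    dist = torchd.independent.Independent(
        torchd.uniform.Uniform(low.repeat(envs, 1), high.repeat(envs, 1)),
        1,  # fold the action dimension into the event shape
    )
    action = dist.sample()        # (envs, act_dim), uniform in [low, high)
    logp = dist.log_prob(action)  # (envs,), one log-density per environment

The Independent(..., 1) wrapper matters for log_prob: without it each action dimension would be scored as a separate batch entry instead of being summed into one density per environment.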
@@ -39,7 +52,7 @@ class Plan2Explore(nn.Module):
         feat_size = config.dyn_stoch + config.dyn_deter
         stoch = config.dyn_stoch
         size = {
-            "embed": 32 * config.cnn_depth,
+            "embed": world_model.embed_size,
             "stoch": stoch,
             "deter": config.dyn_deter,
             "feat": config.dyn_stoch + config.dyn_deter,
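The replaced constant only matches a purely convolutional encoder; when MultiEncoder also carries an MLP branch over vector observations, the concatenated embedding is wider, so reading world_model.embed_size is the robust choice. A toy illustration with made-up sizes:

    cnn_depth = 32
    cnn_out = 32 * cnn_depth            # 1024: the old hard-coded assumption
    mlp_out = 512                       # hypothetical vector-feature width
    encoder_outdim = cnn_out + mlp_out  # 1536: what the encoder actually emits
    assert encoder_outdim != cnn_out    # the old "embed" entry would be stale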
models.py (13 lines changed)
@@ -36,7 +36,7 @@ class WorldModel(nn.Module):
         self._config = config
         shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
         self.encoder = networks.MultiEncoder(shapes, **config.encoder)
-        embed_size = self.encoder.outdim
+        self.embed_size = self.encoder.outdim
         self.dynamics = networks.RSSM(
             config.dyn_stoch,
             config.dyn_deter,
@@ -56,7 +56,7 @@ class WorldModel(nn.Module):
             config.unimix_ratio,
             config.initial,
             config.num_actions,
-            embed_size,
+            self.embed_size,
             config.device,
         )
         self.heads = nn.ModuleDict()
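These two hunks promote the local embed_size to an attribute; together with the exploration.py change above, Plan2Explore can size its inputs from the live model instead of re-deriving the number. A stripped-down sketch of the hand-off (both classes heavily simplified):

    import torch.nn as nn

    class WorldModel(nn.Module):
        def __init__(self, outdim):
            super().__init__()
            self.embed_size = outdim  # exposed for collaborators to read

    class Plan2Explore(nn.Module):
        def __init__(self, world_model):
            super().__init__()
            self.size = {"embed": world_model.embed_size}  # read, not recomputed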
@@ -228,7 +228,7 @@ class ImagBehavior(nn.Module):
         else:
             feat_size = config.dyn_stoch + config.dyn_deter
         self.actor = networks.ActionHead(
-            feat_size,  # pytorch version
+            feat_size,
             config.num_actions,
             config.actor_layers,
             config.units,
@@ -244,7 +244,7 @@ class ImagBehavior(nn.Module):
         )
         if config.value_head == "symlog_disc":
             self.value = networks.MLP(
-                feat_size,  # pytorch version
+                feat_size,
                 (255,),
                 config.value_layers,
                 config.units,
@@ -256,7 +256,7 @@ class ImagBehavior(nn.Module):
             )
         else:
             self.value = networks.MLP(
-                feat_size,  # pytorch version
+                feat_size,
                 [],
                 config.value_layers,
                 config.units,
@@ -356,7 +356,7 @@ class ImagBehavior(nn.Module):
             )
         else:
             metrics.update(tools.tensorstats(imag_action, "imag_action"))
-        metrics["actor_ent"] = to_np(torch.mean(actor_ent))
+        metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
         with tools.RequiresGrad(self):
             metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
             metrics.update(self._value_opt(value_loss, self.value.parameters()))
@@ -462,7 +462,6 @@ class ImagBehavior(nn.Module):
         if not self._config.future_entropy and (self._config.actor_entropy() > 0):
            actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
            actor_target += actor_entropy
-            metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
         if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
             state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
             actor_target += state_entropy
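A hedged reading of these last two hunks taken together: the hunk at line 356 now logs the raw policy entropy unconditionally under actor_entropy, so the conditional write here, which stored the coefficient-scaled bonus under that same key, is removed rather than letting the two writes clobber each other. With made-up numbers:

    metrics = {}
    metrics["actor_entropy"] = 0.8  # raw mean policy entropy, logged every step
    # The deleted line would then have overwritten it with the scaled bonus:
    # metrics["actor_entropy"] = 3e-4 * 0.8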