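Modification of expl (exploration behaviors): pass act_space to expl.Random, take Plan2Explore's embed size from world_model.embed_size, and rename the actor_ent metric to actor_entropy.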

This commit is contained in:
NM512 2023-05-21 08:17:47 +09:00
parent b8ef214efa
commit 02c3d45fcf
3 changed files with 28 additions and 16 deletions

View File

@ -61,7 +61,7 @@ class Dreamer(nn.Module):
reward = lambda f, s, a: self._wm.heads["reward"](f).mean
self._expl_behavior = dict(
greedy=lambda: self._task_behavior,
random=lambda: expl.Random(config),
random=lambda: expl.Random(config, act_space),
plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
)[config.expl_behavior]().to(self._config.device)

View File

@ -8,22 +8,35 @@ import tools
class Random(nn.Module):
def __init__(self, config):
def __init__(self, config, act_space):
super(Random, self).__init__()
self._config = config
self._act_space = act_space
def actor(self, feat):
shape = feat.shape[:-1] + [self._config.num_actions]
if self._config.actor_dist == "onehot":
return tools.OneHotDist(torch.zeros(shape))
return tools.OneHotDist(
torch.zeros(self._config.num_actions)
.repeat(self._config.envs, 1)
.to(self._config.device)
)
else:
ones = torch.ones(shape)
return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
return torchd.independent.Independent(
torchd.uniform.Uniform(
torch.Tensor(self._act_space.low)
.repeat(self._config.envs, 1)
.to(self._config.device),
torch.Tensor(self._act_space.high)
.repeat(self._config.envs, 1)
.to(self._config.device),
),
1,
)
def train(self, start, context):
def train(self, start, context, data):
return None, {}
# class Plan2Explore(tools.Module):
class Plan2Explore(nn.Module):
def __init__(self, config, world_model, reward=None):
super(Plan2Explore, self).__init__()
@ -39,7 +52,7 @@ class Plan2Explore(nn.Module):
feat_size = config.dyn_stoch + config.dyn_deter
stoch = config.dyn_stoch
size = {
"embed": 32 * config.cnn_depth,
"embed": world_model.embed_size,
"stoch": stoch,
"deter": config.dyn_deter,
"feat": config.dyn_stoch + config.dyn_deter,

View File

@ -36,7 +36,7 @@ class WorldModel(nn.Module):
self._config = config
shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
self.encoder = networks.MultiEncoder(shapes, **config.encoder)
embed_size = self.encoder.outdim
self.embed_size = self.encoder.outdim
self.dynamics = networks.RSSM(
config.dyn_stoch,
config.dyn_deter,
@ -56,7 +56,7 @@ class WorldModel(nn.Module):
config.unimix_ratio,
config.initial,
config.num_actions,
embed_size,
self.embed_size,
config.device,
)
self.heads = nn.ModuleDict()
@ -228,7 +228,7 @@ class ImagBehavior(nn.Module):
else:
feat_size = config.dyn_stoch + config.dyn_deter
self.actor = networks.ActionHead(
feat_size, # pytorch version
feat_size,
config.num_actions,
config.actor_layers,
config.units,
@ -244,7 +244,7 @@ class ImagBehavior(nn.Module):
)
if config.value_head == "symlog_disc":
self.value = networks.MLP(
feat_size, # pytorch version
feat_size,
(255,),
config.value_layers,
config.units,
@ -256,7 +256,7 @@ class ImagBehavior(nn.Module):
)
else:
self.value = networks.MLP(
feat_size, # pytorch version
feat_size,
[],
config.value_layers,
config.units,
@ -356,7 +356,7 @@ class ImagBehavior(nn.Module):
)
else:
metrics.update(tools.tensorstats(imag_action, "imag_action"))
metrics["actor_ent"] = to_np(torch.mean(actor_ent))
metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
with tools.RequiresGrad(self):
metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
metrics.update(self._value_opt(value_loss, self.value.parameters()))
@ -462,7 +462,6 @@ class ImagBehavior(nn.Module):
if not self._config.future_entropy and (self._config.actor_entropy() > 0):
actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
actor_target += actor_entropy
metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
actor_target += state_entropy