separate action and agent embeds

parent 6dbdc3d7d8
commit ff81dd761b

@@ -39,6 +39,7 @@ from assoc_scan import AssocScan
 # vh, vw - video height and width

 import einx
+from einx import add, multiply
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

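Annotation (not part of the commit): einx ops take an einops-style pattern string that spells out how the operands broadcast against each other, so the new import lets every einx.add(...) / einx.multiply(...) call site below shrink to a bare add(...) / multiply(...). A minimal sketch of the notation, with illustrative shapes:

    import torch
    from einx import add, multiply

    x = torch.randn(2, 3, 5)     # (batch, time, na)
    bias = torch.randn(5)        # (na,)

    # '... na, na': broadcast the length-na vector over all leading axes of x
    y = add('... na, na', x, bias)
    assert y.shape == (2, 3, 5)

    emb = torch.randn(5, 8)      # (na, d)
    coef = torch.randn(2, 3, 5)  # (..., na)

    # explicit output pattern: scale each of the na rows of emb by its coefficient
    z = multiply('na d, ... na -> ... na d', emb, coef)
    assert z.shape == (2, 3, 5, 8)
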
@@ -579,7 +580,7 @@ class ActionEmbedder(Module):

            # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset

-           discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+           discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
            discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)

        if exists(continuous_actions):

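Annotation: the offsets exist because all discrete action types share one embedding table, so each type's raw indices must be shifted by the total count of the types before it. A sketch with hypothetical per-type sizes:

    import torch
    import torch.nn as nn
    from einx import add

    num_actions = torch.tensor([4, 3, 5])                 # hypothetical cardinality per discrete action type
    offsets = num_actions.cumsum(dim = -1) - num_actions  # exclusive prefix sum: [0, 4, 7]

    discrete_action_embed = nn.Embedding(int(num_actions.sum()), 16)  # one shared table, 12 rows

    discrete_actions = torch.tensor([[2, 0, 4]])          # (batch, na) raw indices per type

    # shift each type's indices into its own slice of the table -> [[2, 4, 11]]
    discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
    discrete_embeds = discrete_action_embed(discrete_actions_offsetted)  # (batch, na, 16)
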
@@ -599,7 +600,7 @@ class ActionEmbedder(Module):

            # continuous embed is just the selected continuous action type with the scale

-           continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+           continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)

        # return not pooled

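Annotation: the continuous branch, sketched under the same assumptions. Each continuous action type owns one learned vector, and the raw scalar value just scales it:

    import torch
    from einx import multiply

    na, d = 3, 16
    continuous_action_embed = torch.randn(na, d)  # one learned vector per continuous action type
    continuous_actions = torch.randn(2, 10, na)   # (batch, time, na) raw scalar values

    # each scalar scales its type's vector, one embed per type per timestep
    continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
    assert continuous_embeds.shape == (2, 10, na, d)
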
@@ -769,7 +770,7 @@ class MultiHeadRMSNorm(Module):
    ):
        normed = l2norm(x)
        scale = (self.gamma + 1.) * self.scale
-       return einx.multiply('... h n d, h d', normed, scale)
+       return multiply('... h n d, h d', normed, scale)

# naive attend

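Annotation: the hunk above filled out into a runnable sketch. The l2norm helper, the zero init of gamma, and scale = dim ** 0.5 are assumptions consistent with the lines shown, not confirmed from the rest of the file:

    import torch
    import torch.nn.functional as F
    from torch.nn import Module, Parameter
    from einx import multiply

    def l2norm(t):
        return F.normalize(t, dim = -1)

    class MultiHeadRMSNorm(Module):
        def __init__(self, dim, heads = 8):
            super().__init__()
            self.scale = dim ** 0.5
            self.gamma = Parameter(torch.zeros(heads, dim))  # zero init -> effective scale starts at sqrt(dim)

        def forward(self, x):
            # x: (..., heads, seq, dim) - l2-normalize the features, then rescale per head
            normed = l2norm(x)
            scale = (self.gamma + 1.) * self.scale
            return multiply('... h n d, h d', normed, scale)
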
@@ -1559,6 +1560,8 @@ class DynamicsWorldModel(Module):
        # they sum all the actions into a single token

        self.num_agents = num_agents
+
+       self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
        self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)

        self.num_tasks = num_tasks

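Annotation: this is the commit in a nutshell. Before, the single action_learned_embed seeded the per-agent tokens (see the next hunk); now the agent tokens get their own agent_learned_embed, and action_learned_embed is instead added onto the embedded action tokens (last code hunk below). Dims are illustrative:

    import torch
    from torch import randn
    from torch.nn import Parameter

    num_agents, dim = 1, 512

    # small init keeps both token types near zero at the start of training
    agent_learned_embed  = Parameter(randn(num_agents, dim) * 1e-2)  # seeds the per-agent tokens
    action_learned_embed = Parameter(randn(num_agents, dim) * 1e-2)  # added onto embedded actions
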
@@ -1927,13 +1930,13 @@ class DynamicsWorldModel(Module):

        # reinforcement learning related

-       agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
+       agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)

        if exists(tasks):
            assert self.num_tasks > 0

            task_embeds = self.task_embed(tasks)
-           agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds)
+           agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)

        # maybe evolution

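Annotation: a sketch of the rewritten lines, with assumed shapes (task_embed as an nn.Embedding is an assumption). The learned agent embeds are tiled across the batch, then each sample's task embedding is broadcast-added over its agent tokens:

    import torch
    from torch import nn, randn
    from torch.nn import Parameter
    from einops import repeat
    from einx import add

    batch, num_agents, dim, num_tasks = 4, 1, 512, 10

    agent_learned_embed = Parameter(randn(num_agents, dim) * 1e-2)
    task_embed = nn.Embedding(num_tasks, dim)

    # one copy of the agent embeds per batch element: (num_agents, d) -> (b, num_agents, d)
    agent_tokens = repeat(agent_learned_embed, '... d -> b ... d', b = batch)

    tasks = torch.randint(0, num_tasks, (batch,))
    task_embeds = task_embed(tasks)  # (b, d)

    # broadcast each sample's task embedding across all of its agent tokens
    agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
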
@@ -1941,7 +1944,7 @@ class DynamicsWorldModel(Module):
            assert exists(self.latent_genes)
            latent_genes = self.latent_genes[latent_gene_ids]

-           agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+           agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)

        # handle agent tokens w/ actions and task embeds

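Annotation: the evolution branch reuses the same broadcast-add; the only wrinkle is that the conditioning vector comes from plain indexing into a parameter table rather than an nn.Embedding. Sizes are assumed:

    import torch
    from torch import randn
    from torch.nn import Parameter
    from einx import add

    num_genes, batch, dim = 16, 4, 512

    latent_genes = Parameter(randn(num_genes, dim) * 1e-2)

    latent_gene_ids = torch.randint(0, num_genes, (batch,))
    genes = latent_genes[latent_gene_ids]  # (b, d) - indexing a Parameter acts as an embedding lookup

    agent_tokens = torch.randn(batch, 1, dim)
    agent_tokens = add('b ... d, b d', agent_tokens, genes)
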
@@ -1962,12 +1965,13 @@ class DynamicsWorldModel(Module):

            reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward

-           agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
+           agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)

        # maybe create the action tokens

        if exists(discrete_actions) or exists(continuous_actions):
            assert self.action_embedder.has_actions
+           assert self.num_agents == 1, 'only one agent allowed for now'

            action_tokens = self.action_embedder(
                discrete_actions = discrete_actions,
@@ -1975,6 +1979,9 @@ class DynamicsWorldModel(Module):
                continuous_actions = continuous_actions,
                continuous_action_types = continuous_action_types
            )
+
+           action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
+
        else:
            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens

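Annotation: tying the last two hunks together. pad_at_dim is assumed to be the usual helper that pads (or, with a negative amount, truncates) a single dimension; shapes are illustrative. Reward embeds are shifted one step right so the agent token at time t carries the previous reward while predicting the next, and the new action_learned_embed is added onto the embedded actions before they join the sequence:

    import torch
    import torch.nn.functional as F
    from einx import add

    def pad_at_dim(t, pad, dim = -1, value = 0.):
        # assumed helper: apply (left, right) padding to one dim; a negative right pad truncates
        dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = (0, 0) * dims_from_right
        return F.pad(t, (*zeros, *pad), value = value)

    batch, time, d = 2, 6, 512

    # shift right along time: drop the last reward embed, prepend zeros
    reward_embeds = torch.randn(batch, time, d)
    reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.)

    agent_tokens = torch.randn(batch, time, 1, d)  # (b, t, num_agents, d)
    agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)

    # give the pooled action tokens their own learned embed (the new parameter in this commit)
    action_learned_embed = torch.randn(1, d)       # stands in for self.action_learned_embed
    action_tokens = torch.randn(batch, time, d)    # stands in for the action embedder output
    action_tokens = add('1 d, b t d', action_learned_embed, action_tokens)
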
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.18"
+version = "0.0.19"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }