separate action and agent embeds

lucidrains 2025-10-13 11:36:21 -07:00
parent 6dbdc3d7d8
commit ff81dd761b
2 changed files with 15 additions and 8 deletions
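
Note: reading the hunks below, the model previously built its per-agent tokens from the single self.action_learned_embed parameter; this commit introduces a separate self.agent_learned_embed for the agent tokens and keeps action_learned_embed for the action tokens. A minimal sketch of the resulting setup (illustrative only, not the repository code; num_agents and dim values are made up):

import torch
from torch import nn

num_agents, dim = 1, 512  # illustrative values

# two independent learned embeddings, both initialized small as in the diff
agent_learned_embed  = nn.Parameter(torch.randn(num_agents, dim) * 1e-2)  # seeds the per-agent tokens
action_learned_embed = nn.Parameter(torch.randn(num_agents, dim) * 1e-2)  # added onto the embedded actions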

View File

@@ -39,6 +39,7 @@ from assoc_scan import AssocScan
# vh, vw - video height and width
import einx
+from einx import add, multiply
from einops import einsum, rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
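
Note: the added import is what lets the later hunks call add(...) and multiply(...) directly instead of einx.add(...) / einx.multiply(...); the broadcasting semantics are unchanged. A rough illustration of the two patterns used below (shapes and values invented for the example):

import torch
from einx import add, multiply

actions = torch.randn(2, 3, 4)   # '... na'
offsets = torch.randn(4)         # 'na'
shifted = add('... na, na', actions, offsets)  # offsets broadcast over leading dims -> (2, 3, 4)

embed = torch.randn(4, 8)        # 'na d'
scale = torch.randn(2, 3, 4)     # '... na'
scaled = multiply('na d, ... na -> ... na d', embed, scale)  # -> (2, 3, 4, 8)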
@@ -579,7 +580,7 @@ class ActionEmbedder(Module):
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
-discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
if exists(continuous_actions):
@@ -599,7 +600,7 @@ class ActionEmbedder(Module):
# continuous embed is just the selected continuous action type with the scale
-continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
# return not pooled
@@ -769,7 +770,7 @@ class MultiHeadRMSNorm(Module):
):
normed = l2norm(x)
scale = (self.gamma + 1.) * self.scale
-return einx.multiply('... h n d, h d', normed, scale)
+return multiply('... h n d, h d', normed, scale)
# naive attend
@@ -1559,6 +1560,8 @@ class DynamicsWorldModel(Module):
# they sum all the actions into a single token
self.num_agents = num_agents
+self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
self.num_tasks = num_tasks
@@ -1927,13 +1930,13 @@
# reinforcement learning related
-agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
+agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)
if exists(tasks):
assert self.num_tasks > 0
task_embeds = self.task_embed(tasks)
-agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds)
+agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
# maybe evolution
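
Note: in the hunk above and the one just below, only the call style changes (einx.add becomes the directly imported add); the conditioning itself is the same: a per-batch vector (the task embedding here, a latent gene below) is broadcast-added across every time step and agent position of the agent tokens. Roughly (shapes invented for illustration):

import torch
from einx import add

batch, time, num_agents, dim = 2, 5, 1, 512               # illustrative
agent_tokens = torch.randn(batch, time, num_agents, dim)  # 'b ... d'
task_embeds  = torch.randn(batch, dim)                    # 'b d'

agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)  # broadcast over time and agents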
@@ -1941,7 +1944,7 @@
assert exists(self.latent_genes)
latent_genes = self.latent_genes[latent_gene_ids]
-agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)
# handle agent tokens w/ actions and task embeds
@@ -1962,12 +1965,13 @@
reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
-agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
+agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
# maybe create the action tokens
if exists(discrete_actions) or exists(continuous_actions):
assert self.action_embedder.has_actions
+assert self.num_agents == 1, 'only one agent allowed for now'
action_tokens = self.action_embedder(
discrete_actions = discrete_actions,
@@ -1975,6 +1979,9 @@ class DynamicsWorldModel(Module):
continuous_actions = continuous_actions,
continuous_action_types = continuous_action_types
)
+action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
else:
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
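
Note: this last hunk is where the separation lands: the tokens coming out of the ActionEmbedder now receive the new action_learned_embed, broadcast from shape '1 d' over batch and time, while the agent tokens keep their own agent_learned_embed; the assert restricts this path to a single agent for now. Sketched in isolation (shapes illustrative, the embedder output is stubbed with random values):

import torch
from einx import add

batch, time, dim = 2, 5, 512                   # illustrative
action_learned_embed = torch.randn(1, dim)     # one agent, matching the assert
action_tokens = torch.randn(batch, time, dim)  # stand-in for the ActionEmbedder output

action_tokens = add('1 d, b t d', action_learned_embed, action_tokens)  # -> (batch, time, dim)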

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.18"
version = "0.0.19"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }