separate action and agent embeds

lucidrains 2025-10-13 11:36:21 -07:00
parent 6dbdc3d7d8
commit ff81dd761b
2 changed files with 15 additions and 8 deletions

View File

@@ -39,6 +39,7 @@ from assoc_scan import AssocScan
 # vh, vw - video height and width
 import einx
+from einx import add, multiply
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
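
The remaining hunks in this file swap the qualified `einx.add(...)` / `einx.multiply(...)` calls for the `add` / `multiply` names imported above; the call signatures are unchanged. A minimal sketch of the broadcasting these einx ops perform (shapes here are illustrative, not taken from the repo):

import torch
from einx import add, multiply

tokens = torch.randn(2, 10, 512)   # (batch, time, dim)
embed = torch.randn(2, 512)        # (batch, dim)

# 'b t d, b d' aligns the per-batch embedding against every timestep before the elementwise add
out = add('b t d, b d', tokens, embed)
assert out.shape == (2, 10, 512)

scales = torch.randn(10)           # one scalar per timestep
# 't, b t d' broadcasts the per-timestep scalar across batch and feature dims
scaled = multiply('t, b t d', scales, tokens)
assert scaled.shape == (2, 10, 512)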
@@ -579,7 +580,7 @@ class ActionEmbedder(Module):
 # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
-discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
 discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
 if exists(continuous_actions):
@@ -599,7 +600,7 @@ class ActionEmbedder(Module):
 # continuous embed is just the selected continuous action type with the scale
-continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
 # return not pooled
@@ -769,7 +770,7 @@ class MultiHeadRMSNorm(Module):
 ):
 normed = l2norm(x)
 scale = (self.gamma + 1.) * self.scale
-return einx.multiply('... h n d, h d', normed, scale)
+return multiply('... h n d, h d', normed, scale)
 # naive attend
@@ -1559,6 +1560,8 @@ class DynamicsWorldModel(Module):
 # they sum all the actions into a single token
 self.num_agents = num_agents
+self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 self.num_tasks = num_tasks
@@ -1927,13 +1930,13 @@ class DynamicsWorldModel(Module):
 # reinforcement learning related
-agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
+agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)
 if exists(tasks):
 assert self.num_tasks > 0
 task_embeds = self.task_embed(tasks)
-agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds)
+agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
 # maybe evolution
@@ -1941,7 +1944,7 @@ class DynamicsWorldModel(Module):
 assert exists(self.latent_genes)
 latent_genes = self.latent_genes[latent_gene_ids]
-agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)
 # handle agent tokens w/ actions and task embeds
@@ -1962,12 +1965,13 @@ class DynamicsWorldModel(Module):
 reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
-agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
+agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
 # maybe create the action tokens
 if exists(discrete_actions) or exists(continuous_actions):
 assert self.action_embedder.has_actions
+assert self.num_agents == 1, 'only one agent allowed for now'
 action_tokens = self.action_embedder(
 discrete_actions = discrete_actions,
@@ -1975,6 +1979,9 @@ class DynamicsWorldModel(Module):
 continuous_actions = continuous_actions,
 continuous_action_types = continuous_action_types
 )
+action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
 else:
 action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
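
Taken together, the new parameters give each role its own learned vector: agent_learned_embed now seeds the per-agent tokens (which then receive the task, latent-gene and reward embeds), while action_learned_embed is broadcast-added onto the tokens produced by the ActionEmbedder. A rough sketch of that flow under the single-agent assumption the diff asserts; the shapes, the stand-in action tokens, and everything outside the diff's own names are illustrative, not the repo's implementation:

import torch
from torch import nn, randn
from einx import add
from einops import repeat

batch, time, dim, num_agents = 2, 10, 512, 1

agent_learned_embed = nn.Parameter(randn(num_agents, dim) * 1e-2)   # seeds the agent tokens
action_learned_embed = nn.Parameter(randn(num_agents, dim) * 1e-2)  # added onto the action tokens

# agent tokens start from the agent embed, one set per batch element
agent_tokens = repeat(agent_learned_embed, '... d -> b ... d', b = batch)   # (b, num_agents, d)

# stand-in for the ActionEmbedder output - one token per timestep
action_tokens = randn(batch, time, dim)

# the single agent's action embed is broadcast over batch and time, as in the diff
action_tokens = add('1 d, b t d', action_learned_embed, action_tokens)
assert action_tokens.shape == (batch, time, dim)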

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.18"
+version = "0.0.19"
 description = "Dreamer 4"
 authors = [
 { name = "Phil Wang", email = "lucidrains@gmail.com" }