sim trainer needs to take care of agent embedding and old actions

This commit is contained in:
lucidrains 2025-10-29 11:15:11 -07:00
parent 586379f2c8
commit cb54121ace
2 changed files with 24 additions and 1 deletions

View File

@ -396,13 +396,20 @@ class SimTrainer(Module):
old_values = experience.values
rewards = experience.rewards
has_agent_embed = exists(experience.agent_embed)
agent_embed = experience.agent_embed
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
# handle empties
empty_tensor = torch.empty_like(rewards)
agent_embed = default(agent_embed, empty_tensor)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
@ -412,6 +419,9 @@ class SimTrainer(Module):
# swap in the placeholder for any value that is absent, so each of these
# names holds a tensor when passed to TensorDataset below
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
# default from the continuous value itself - reading the (already defaulted)
# discrete variable here would discard the continuous old action unembeds
continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
@ -420,6 +430,9 @@ class SimTrainer(Module):
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
)
@ -434,6 +447,9 @@ class SimTrainer(Module):
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
) in dataloader:
@ -448,10 +464,17 @@ class SimTrainer(Module):
continuous_log_probs if has_continuous else None
)
old_action_unembeds = (
discrete_old_action_unembeds if has_discrete else None,
continuous_old_action_unembeds if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
agent_embed = agent_embed if has_agent_embed else None,
old_action_unembeds = old_action_unembeds,
values = old_values,
rewards = rewards,
step_size = step_size,

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.95"
version = "0.0.96"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }