diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py
index 7004bd4..bb4b198 100644
--- a/dreamer4/trainers.py
+++ b/dreamer4/trainers.py
@@ -396,13 +396,20 @@ class SimTrainer(Module):
         old_values = experience.values
         rewards = experience.rewards
 
+        has_agent_embed = exists(experience.agent_embed)
+        agent_embed = experience.agent_embed
+
         discrete_actions, continuous_actions = experience.actions
         discrete_log_probs, continuous_log_probs = experience.log_probs
 
+        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
+
         # handle empties
 
         empty_tensor = torch.empty_like(rewards)
 
+        agent_embed = default(agent_embed, empty_tensor)
+
         has_discrete = exists(discrete_actions)
         has_continuous = exists(continuous_actions)
 
@@ -412,6 +419,9 @@ class SimTrainer(Module):
         discrete_log_probs = default(discrete_log_probs, empty_tensor)
         continuous_log_probs = default(continuous_log_probs, empty_tensor)
 
+        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+        continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
+
         # create the dataset and dataloader
 
         dataset = TensorDataset(
@@ -420,6 +430,9 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
             old_values,
             rewards
         )
@@ -434,6 +447,9 @@ class SimTrainer(Module):
                 continuous_actions,
                 discrete_log_probs,
                 continuous_log_probs,
+                agent_embed,
+                discrete_old_action_unembeds,
+                continuous_old_action_unembeds,
                 old_values,
                 rewards
             ) in dataloader:
@@ -448,10 +464,17 @@ class SimTrainer(Module):
                     continuous_log_probs if has_continuous else None
                 )
 
+                old_action_unembeds = (
+                    discrete_old_action_unembeds if has_discrete else None,
+                    continuous_old_action_unembeds if has_continuous else None
+                )
+
                 batch_experience = Experience(
                     latents = latents,
                     actions = actions,
                     log_probs = log_probs,
+                    agent_embed = agent_embed if has_agent_embed else None,
+                    old_action_unembeds = old_action_unembeds,
                     values = old_values,
                     rewards = rewards,
                     step_size = step_size,
diff --git a/pyproject.toml b/pyproject.toml
index e68d8af..f583388 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.95"
+version = "0.0.96"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }