sim trainer needs to take care of agent embedding and old actions
parent 586379f2c8
commit cb54121ace
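The hunks below lean on two tiny helper functions used throughout the repository. They are not part of this diff; here is a minimal sketch, assuming the usual lucidrains-style definitions:

    # assumed helper definitions - not part of this commit

    def exists(v):
        # a value "exists" when it is not None
        return v is not None

    def default(v, d):
        # fall back to d when v does not exist
        return v if exists(v) else d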
@@ -396,13 +396,20 @@ class SimTrainer(Module):
         old_values = experience.values
         rewards = experience.rewards
 
+        has_agent_embed = exists(experience.agent_embed)
+        agent_embed = experience.agent_embed
+
         discrete_actions, continuous_actions = experience.actions
         discrete_log_probs, continuous_log_probs = experience.log_probs
 
+        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
+
         # handle empties
 
         empty_tensor = torch.empty_like(rewards)
 
+        agent_embed = default(agent_embed, empty_tensor)
+
         has_discrete = exists(discrete_actions)
         has_continuous = exists(continuous_actions)
 
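Why the empty tensors: TensorDataset can only store tensors, so an optional field like agent_embed cannot ride through it as None. The pattern here records a has_* flag up front, substitutes a placeholder via torch.empty_like(rewards), and restores None after batching; see the round-trip sketch after the dataloader hunks below.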
@@ -412,6 +419,9 @@ class SimTrainer(Module):
         discrete_log_probs = default(discrete_log_probs, empty_tensor)
         continuous_log_probs = default(continuous_log_probs, empty_tensor)
 
+        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+        continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
+
         # create the dataset and dataloader
 
         dataset = TensorDataset(
@@ -420,6 +430,9 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
             old_values,
             rewards
         )
@@ -434,6 +447,9 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
             old_values,
             rewards
         ) in dataloader:
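Note that TensorDataset yields plain tuples, so the unpack order in this for loop must match the construction order of the dataset above, field for field; the three new fields are inserted at the same position in both hunks.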
@@ -448,10 +464,17 @@ class SimTrainer(Module):
                 continuous_log_probs if has_continuous else None
             )
 
+            old_action_unembeds = (
+                discrete_old_action_unembeds if has_discrete else None,
+                continuous_old_action_unembeds if has_continuous else None
+            )
+
             batch_experience = Experience(
                 latents = latents,
                 actions = actions,
                 log_probs = log_probs,
+                agent_embed = agent_embed if has_agent_embed else None,
+                old_action_unembeds = old_action_unembeds,
                 values = old_values,
                 rewards = rewards,
                 step_size = step_size,
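Putting the two halves together, here is a hypothetical, self-contained sketch of the placeholder round-trip the diff implements for one optional field (the shapes, batch size, and the inlined None checks are illustrative only; the real code uses the exists/default helpers):

    import torch
    from torch.utils.data import TensorDataset, DataLoader

    # stand-ins for experience.rewards and an absent optional field
    rewards = torch.randn(4)
    agent_embed = None

    # record whether the field existed, then swap None for a dummy tensor,
    # since TensorDataset cannot store None
    has_agent_embed = agent_embed is not None
    agent_embed = agent_embed if has_agent_embed else torch.empty_like(rewards)

    dataset = TensorDataset(agent_embed, rewards)

    for agent_embed_batch, rewards_batch in DataLoader(dataset, batch_size = 2):
        # the saved flag turns the dummy batch back into None, mirroring
        # `agent_embed if has_agent_embed else None` in the diff
        batch_agent_embed = agent_embed_batch if has_agent_embed else None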
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.95"
+version = "0.0.96"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }