sim trainer needs to take care of agent embedding and old actions
This commit is contained in:
parent
586379f2c8
commit
cb54121ace
@ -396,13 +396,20 @@ class SimTrainer(Module):
|
|||||||
old_values = experience.values
|
old_values = experience.values
|
||||||
rewards = experience.rewards
|
rewards = experience.rewards
|
||||||
|
|
||||||
|
has_agent_embed = exists(experience.agent_embed)
|
||||||
|
agent_embed = experience.agent_embed
|
||||||
|
|
||||||
discrete_actions, continuous_actions = experience.actions
|
discrete_actions, continuous_actions = experience.actions
|
||||||
discrete_log_probs, continuous_log_probs = experience.log_probs
|
discrete_log_probs, continuous_log_probs = experience.log_probs
|
||||||
|
|
||||||
|
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
|
||||||
|
|
||||||
# handle empties
|
# handle empties
|
||||||
|
|
||||||
empty_tensor = torch.empty_like(rewards)
|
empty_tensor = torch.empty_like(rewards)
|
||||||
|
|
||||||
|
agent_embed = default(agent_embed, empty_tensor)
|
||||||
|
|
||||||
has_discrete = exists(discrete_actions)
|
has_discrete = exists(discrete_actions)
|
||||||
has_continuous = exists(continuous_actions)
|
has_continuous = exists(continuous_actions)
|
||||||
|
|
||||||
@ -412,6 +419,9 @@ class SimTrainer(Module):
|
|||||||
discrete_log_probs = default(discrete_log_probs, empty_tensor)
|
discrete_log_probs = default(discrete_log_probs, empty_tensor)
|
||||||
continuous_log_probs = default(continuous_log_probs, empty_tensor)
|
continuous_log_probs = default(continuous_log_probs, empty_tensor)
|
||||||
|
|
||||||
|
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
|
||||||
|
continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
|
||||||
|
|
||||||
# create the dataset and dataloader
|
# create the dataset and dataloader
|
||||||
|
|
||||||
dataset = TensorDataset(
|
dataset = TensorDataset(
|
||||||
@ -420,6 +430,9 @@ class SimTrainer(Module):
|
|||||||
continuous_actions,
|
continuous_actions,
|
||||||
discrete_log_probs,
|
discrete_log_probs,
|
||||||
continuous_log_probs,
|
continuous_log_probs,
|
||||||
|
agent_embed,
|
||||||
|
discrete_old_action_unembeds,
|
||||||
|
continuous_old_action_unembeds,
|
||||||
old_values,
|
old_values,
|
||||||
rewards
|
rewards
|
||||||
)
|
)
|
||||||
@ -434,6 +447,9 @@ class SimTrainer(Module):
|
|||||||
continuous_actions,
|
continuous_actions,
|
||||||
discrete_log_probs,
|
discrete_log_probs,
|
||||||
continuous_log_probs,
|
continuous_log_probs,
|
||||||
|
agent_embed,
|
||||||
|
discrete_old_action_unembeds,
|
||||||
|
continuous_old_action_unembeds,
|
||||||
old_values,
|
old_values,
|
||||||
rewards
|
rewards
|
||||||
) in dataloader:
|
) in dataloader:
|
||||||
@ -448,10 +464,17 @@ class SimTrainer(Module):
|
|||||||
continuous_log_probs if has_continuous else None
|
continuous_log_probs if has_continuous else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
old_action_unembeds = (
|
||||||
|
discrete_old_action_unembeds if has_discrete else None,
|
||||||
|
continuous_old_action_unembeds if has_continuous else None
|
||||||
|
)
|
||||||
|
|
||||||
batch_experience = Experience(
|
batch_experience = Experience(
|
||||||
latents = latents,
|
latents = latents,
|
||||||
actions = actions,
|
actions = actions,
|
||||||
log_probs = log_probs,
|
log_probs = log_probs,
|
||||||
|
agent_embed = agent_embed if has_agent_embed else None,
|
||||||
|
old_action_unembeds = old_action_unembeds,
|
||||||
values = old_values,
|
values = old_values,
|
||||||
rewards = rewards,
|
rewards = rewards,
|
||||||
step_size = step_size,
|
step_size = step_size,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.95"
|
version = "0.0.96"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user