This commit is contained in:
lucidrains 2025-10-17 14:44:25 -07:00
parent c967404471
commit 22e13c45fc
3 changed files with 11 additions and 11 deletions

View File

@ -1700,20 +1700,20 @@ class DynamicsWorldModel(Module):
return list(set(params) - set(self.video_tokenizer.parameters()))
def learn_policy_from_generations(
def learn_from_experience(
self,
generation: Experience,
experience: Experience,
policy_optim: Optimizer | None = None,
value_optim: Optimizer | None = None
):
latents = generation.latents
actions = generation.actions
old_log_probs = generation.log_probs
old_values = generation.values
rewards = generation.rewards
latents = experience.latents
actions = experience.actions
old_log_probs = experience.log_probs
old_values = experience.values
rewards = experience.rewards
step_size = generation.step_size
agent_index = generation.agent_index
step_size = experience.step_size
agent_index = experience.agent_index
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the experience needs to contain the log probs, actions, values, rewards, and step size for policy optimization'

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.31"
version = "0.0.32"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -238,7 +238,7 @@ def test_action_with_world_model():
# take a reinforcement learning step
actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)
actor_loss, critic_loss = dynamics.learn_from_experience(gen)
actor_loss.backward(retain_graph = True)
critic_loss.backward()