diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index ce5bd64..2d82883 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1700,20 +1700,20 @@ class DynamicsWorldModel(Module):
         return list(set(params) - set(self.video_tokenizer.parameters()))
 
-    def learn_policy_from_generations(
+    def learn_from_experience(
         self,
-        generation: Experience,
+        experience: Experience,
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None
     ):
-        latents = generation.latents
-        actions = generation.actions
-        old_log_probs = generation.log_probs
-        old_values = generation.values
-        rewards = generation.rewards
+        latents = experience.latents
+        actions = experience.actions
+        old_log_probs = experience.log_probs
+        old_values = experience.values
+        rewards = experience.rewards
 
-        step_size = generation.step_size
-        agent_index = generation.agent_index
+        step_size = experience.step_size
+        agent_index = experience.agent_index
 
         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

diff --git a/pyproject.toml b/pyproject.toml
index daa3073..5dd5629 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.31"
+version = "0.0.32"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index ac983b2..7a0a0ca 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -238,7 +238,7 @@ def test_action_with_world_model():
 
     # take a reinforcement learning step
 
-    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)
+    actor_loss, critic_loss = dynamics.learn_from_experience(gen)
 
     actor_loss.backward(retain_graph = True)
     critic_loss.backward()
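
Migration note for call sites (a minimal sketch, not part of the diff above; dynamics and gen refer to the DynamicsWorldModel instance and the Experience object already set up in the updated test):

    # before (<= 0.0.31): old method name
    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)

    # after (>= 0.0.32): renamed method, same Experience argument and same (actor_loss, critic_loss) return
    actor_loss, critic_loss = dynamics.learn_from_experience(gen)

    # backprop as in the test; retain_graph = True on the first backward since both losses are backpropagated through the same graph
    actor_loss.backward(retain_graph = True)
    critic_loss.backward()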