From 61773c8219c13e4b9eef9fc7bd796526ad359448 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 17 Oct 2025 08:06:24 -0700
Subject: [PATCH] eventually we will need to learn from the outside stream of experience

---
 dreamer4/dreamer4.py | 10 ++++++----
 pyproject.toml       |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 8c88680..56ec538 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -65,7 +65,7 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
 
 @dataclass
-class WorldModelGenerations:
+class Experience:
     latents: Tensor
     video: Tensor | None = None
     rewards: Tensor | None = None
@@ -74,6 +74,7 @@ class WorldModelGenerations:
     values: Tensor | None = None
     step_size: int | None = None
     agent_index: int = 0
+    is_from_world_model: bool = True
 
 # helpers
 
@@ -1665,7 +1666,7 @@ class DynamicsWorldModel(Module):
 
     def learn_policy_from_generations(
         self,
-        generation: WorldModelGenerations
+        generation: Experience
     ):
         latents = generation.latents
         actions = generation.actions
@@ -1914,11 +1915,12 @@ class DynamicsWorldModel(Module):
 
         # returning agent actions, rewards, and log probs + values for policy optimization
 
-        gen = WorldModelGenerations(
+        gen = Experience(
             latents = latents,
             video = video,
             step_size = step_size,
-            agent_index = agent_index
+            agent_index = agent_index,
+            is_from_world_model = True
         )
 
         if return_rewards_per_frame:
diff --git a/pyproject.toml b/pyproject.toml
index df0cd54..1aed066 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.27"
+version = "0.0.29"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
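
A minimal sketch of what the rename to Experience and the new is_from_world_model flag enable: tagging a rollout collected from a real environment, as opposed to one imagined by the world model (which keeps the default of True, matching the explicit assignment in the patched sampling code). The tensor shape and the construction below are illustrative assumptions, not part of the patch.

    import torch
    from dreamer4.dreamer4 import Experience

    # latents from a hypothetical real-environment rollout;
    # the (batch, time, dim) shape here is assumed for illustration
    real_latents = torch.randn(1, 16, 512)

    outside_experience = Experience(
        latents = real_latents,
        is_from_world_model = False  # marks this as outside experience, not an imagined rollout
    )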