diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 31df3d8..7d333d2 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -82,6 +82,7 @@ class Experience: log_probs: tuple[Tensor, Tensor] | None = None values: Tensor | None = None step_size: int | None = None + lens: Tensor | None = None, agent_index: int = 0 is_from_world_model: bool = True @@ -2560,12 +2561,16 @@ class DynamicsWorldModel(Module): # returning agent actions, rewards, and log probs + values for policy optimization + batch, device = latents.shape[0], latents.device + experience_lens = torch.full((batch,), time_steps, device = device) + gen = Experience( latents = latents, video = video, proprio = proprio if has_proprio else None, step_size = step_size, agent_index = agent_index, + lens = experience_lens, is_from_world_model = True ) diff --git a/pyproject.toml b/pyproject.toml index c126938..ab2e923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.72" +version = "0.0.73" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }