From 7ba3988fb95735edb2dcf9cc8fb1c76dd438f361 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 09:03:20 -0700 Subject: [PATCH] prepare a mock for interacting with online env --- dreamer4/mocks.py | 49 ++++++++++++++++++++++++++++++++++++++++++++ dreamer4/trainers.py | 3 ++- pyproject.toml | 1 + 3 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 dreamer4/mocks.py diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py new file mode 100644 index 0000000..bae22ed --- /dev/null +++ b/dreamer4/mocks.py @@ -0,0 +1,49 @@ +from __future__ import annotations +from random import choice + +import torch +from torch import tensor, randn, randint +from torch.nn import Module + +from einops import repeat + +# mock env + +class MockEnv(Module): + def __init__( + self, + image_shape, + reward_range = (-100., 100.), + batch_size = 1, + vectorized = False + ): + super().__init__() + self.image_shape = image_shape + self.reward_range = reward_range + + self.batch_size = batch_size + self.vectorized = vectorized + self.register_buffer('_step', tensor(0)) + + def get_random_state(self): + return randn(3, *self.image_shape) + + def reset( + self, + seed = None + ): + self._step.zero_() + return self.get_random_state() + + def step( + self, + actions, + ): + state = self.get_random_state() + reward = randint(*self.reward_range, ()).float() + + if self.vectorized: + state = repeat(state, '... -> b ...', b = self.batch_size) + reward = repeat(rewardstate, ' -> b', b = self.batch_size) + + return state, reward diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index b578f44..ce72f92 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -13,6 +13,8 @@ from dreamer4.dreamer4 import ( DynamicsWorldModel ) +from ema_pytorch import EMA + # helpers def exists(v): @@ -94,7 +96,6 @@ class VideoTokenizerTrainer(Module): def forward( self ): - iter_train_dl = cycle(self.train_dataloader) for _ in range(self.num_train_steps): diff --git a/pyproject.toml b/pyproject.toml index 3965716..63c5949 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "assoc-scan", "einx>=0.3.0", "einops>=0.8.1", + "ema-pytorch", "hl-gauss-pytorch", "hyper-connections>=0.2.1", "torch>=2.4",