prepare a mock for interacting with online env

This commit is contained in:
lucidrains 2025-10-21 09:03:20 -07:00
parent ea13d4fcab
commit 7ba3988fb9
3 changed files with 52 additions and 1 deletions

49
dreamer4/mocks.py Normal file
View File

@ -0,0 +1,49 @@
from __future__ import annotations
from random import choice
import torch
from torch import tensor, randn, randint
from torch.nn import Module
from einops import repeat
# mock env
class MockEnv(Module):
def __init__(
self,
image_shape,
reward_range = (-100., 100.),
batch_size = 1,
vectorized = False
):
super().__init__()
self.image_shape = image_shape
self.reward_range = reward_range
self.batch_size = batch_size
self.vectorized = vectorized
self.register_buffer('_step', tensor(0))
def get_random_state(self):
return randn(3, *self.image_shape)
def reset(
self,
seed = None
):
self._step.zero_()
return self.get_random_state()
def step(
self,
actions,
):
state = self.get_random_state()
reward = randint(*self.reward_range, ()).float()
if self.vectorized:
state = repeat(state, '... -> b ...', b = self.batch_size)
reward = repeat(rewardstate, ' -> b', b = self.batch_size)
return state, reward

View File

@ -13,6 +13,8 @@ from dreamer4.dreamer4 import (
DynamicsWorldModel
)
from ema_pytorch import EMA
# helpers
def exists(v):
@ -94,7 +96,6 @@ class VideoTokenizerTrainer(Module):
def forward(
self
):
iter_train_dl = cycle(self.train_dataloader)
for _ in range(self.num_train_steps):

View File

@ -31,6 +31,7 @@ dependencies = [
"assoc-scan",
"einx>=0.3.0",
"einops>=0.8.1",
"ema-pytorch",
"hl-gauss-pytorch",
"hyper-connections>=0.2.1",
"torch>=2.4",