prepare a mock for interacting with online env
This commit is contained in:
parent
ea13d4fcab
commit
7ba3988fb9
49
dreamer4/mocks.py
Normal file
49
dreamer4/mocks.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
from random import choice
|
||||
|
||||
import torch
|
||||
from torch import tensor, randn, randint
|
||||
from torch.nn import Module
|
||||
|
||||
from einops import repeat
|
||||
|
||||
# mock env
|
||||
|
||||
class MockEnv(Module):
    """Mock environment that emits random observations and rewards.

    Mimics the `reset` / `step` surface of an online environment so that
    interaction code can be exercised without a real simulator.

    Args:
        image_shape: spatial dimensions of an observation; each state is
            drawn as `randn(3, *image_shape)`.
        reward_range: `(low, high)` bounds; rewards are sampled uniformly
            from `[low, high)`.
        batch_size: number of copies returned by `step` when `vectorized`.
        vectorized: when true, `step` repeats the state and reward along a
            new leading dimension of size `batch_size`.
    """

    def __init__(
        self,
        image_shape,
        reward_range = (-100., 100.),
        batch_size = 1,
        vectorized = False
    ):
        super().__init__()
        self.image_shape = image_shape
        self.reward_range = reward_range

        self.batch_size = batch_size
        self.vectorized = vectorized

        # number of `step` calls since the last `reset`, stored as a buffer
        self.register_buffer('_step', tensor(0))

    def get_random_state(self):
        """Return one random observation of shape `(3, *image_shape)`."""
        return randn(3, *self.image_shape)

    def reset(
        self,
        seed = None
    ):
        """Zero the step counter and return a fresh random observation.

        `seed` is accepted for call compatibility but is not used.
        The returned state is not batched, even when `vectorized` is set.
        """
        self._step.zero_()
        return self.get_random_state()

    def step(
        self,
        actions,
    ):
        """Advance the mock by one step and return `(state, reward)`.

        `actions` is accepted but ignored - outputs are purely random.
        When `vectorized`, the single sampled state and reward are repeated
        `batch_size` times along a new leading dimension.
        """
        state = self.get_random_state()

        # the default bounds are floats, which `torch.randint` does not accept,
        # so the scalar reward is sampled uniformly from [low, high) instead
        low, high = self.reward_range
        reward = torch.empty(()).uniform_(low, high)

        if self.vectorized:
            state = repeat(state, '... -> b ...', b = self.batch_size)
            reward = repeat(reward, ' -> b', b = self.batch_size)

        # keep the counter meaningful - `reset` zeroes it
        self._step.add_(1)

        return state, reward
|
||||
@ -13,6 +13,8 @@ from dreamer4.dreamer4 import (
|
||||
DynamicsWorldModel
|
||||
)
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
@ -94,7 +96,6 @@ class VideoTokenizerTrainer(Module):
|
||||
def forward(
|
||||
self
|
||||
):
|
||||
|
||||
iter_train_dl = cycle(self.train_dataloader)
|
||||
|
||||
for _ in range(self.num_train_steps):
|
||||
|
||||
@ -31,6 +31,7 @@ dependencies = [
|
||||
"assoc-scan",
|
||||
"einx>=0.3.0",
|
||||
"einops>=0.8.1",
|
||||
"ema-pytorch",
|
||||
"hl-gauss-pytorch",
|
||||
"hyper-connections>=0.2.1",
|
||||
"torch>=2.4",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user