Prepare a mock for interacting with an online env
This commit is contained in:
parent
ea13d4fcab
commit
7ba3988fb9
49
dreamer4/mocks.py
Normal file
49
dreamer4/mocks.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from random import choice
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import tensor, randn, randint
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
# mock env
|
||||||
|
|
||||||
|
class MockEnv(Module):
    """Mock environment that emits random states and rewards.

    Intended as a lightweight stand-in for a real online environment: every
    state is freshly sampled noise and every reward is a random
    integer-valued float drawn from ``reward_range``.

    Args:
        image_shape: trailing dimensions of a state; each state has shape
            ``(3, *image_shape)`` (the leading 3 is presumably the channel
            dimension - confirm against the consumer).
        reward_range: ``(low, high)`` bounds for the sampled reward. Bounds
            are truncated to ints and ``high`` is exclusive.
        batch_size: number of copies returned by ``step`` when ``vectorized``.
        vectorized: when true, ``step`` returns the state and reward repeated
            along a new leading dimension of size ``batch_size``.
    """

    def __init__(
        self,
        image_shape,
        reward_range = (-100., 100.),
        batch_size = 1,
        vectorized = False
    ):
        super().__init__()
        self.image_shape = image_shape
        self.reward_range = reward_range

        self.batch_size = batch_size
        self.vectorized = vectorized

        # kept as a buffer so the step counter is part of the module state
        self.register_buffer('_step', tensor(0))

    def get_random_state(self):
        """Return a single random state of shape ``(3, *image_shape)``."""
        return randn(3, *self.image_shape)

    def reset(
        self,
        seed = None
    ):
        """Zero the step counter and return a fresh random state.

        ``seed`` is accepted for interface compatibility but is not used.

        NOTE(review): the returned state is not batched, even when
        ``vectorized`` is set, unlike the state returned by ``step``.
        """
        self._step.zero_()
        return self.get_random_state()

    def step(
        self,
        actions,
    ):
        """Advance one step and return ``(state, reward)``.

        ``actions`` is accepted but ignored; the mock's output is random.
        """
        state = self.get_random_state()

        # torch.randint only accepts integer bounds, while the default
        # reward_range is given as floats - cast before sampling
        low, high = self.reward_range
        reward = randint(int(low), int(high), ()).float()

        if self.vectorized:
            state = repeat(state, '... -> b ...', b = self.batch_size)
            reward = repeat(reward, ' -> b', b = self.batch_size)

        # reset() zeroes this counter, so step() is what advances it
        self._step.add_(1)

        return state, reward
|
||||||
@ -13,6 +13,8 @@ from dreamer4.dreamer4 import (
|
|||||||
DynamicsWorldModel
|
DynamicsWorldModel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ema_pytorch import EMA
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
|
||||||
def exists(v):
|
def exists(v):
|
||||||
@ -94,7 +96,6 @@ class VideoTokenizerTrainer(Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self
|
self
|
||||||
):
|
):
|
||||||
|
|
||||||
iter_train_dl = cycle(self.train_dataloader)
|
iter_train_dl = cycle(self.train_dataloader)
|
||||||
|
|
||||||
for _ in range(self.num_train_steps):
|
for _ in range(self.num_train_steps):
|
||||||
|
|||||||
@ -31,6 +31,7 @@ dependencies = [
|
|||||||
"assoc-scan",
|
"assoc-scan",
|
||||||
"einx>=0.3.0",
|
"einx>=0.3.0",
|
||||||
"einops>=0.8.1",
|
"einops>=0.8.1",
|
||||||
|
"ema-pytorch",
|
||||||
"hl-gauss-pytorch",
|
"hl-gauss-pytorch",
|
||||||
"hyper-connections>=0.2.1",
|
"hyper-connections>=0.2.1",
|
||||||
"torch>=2.4",
|
"torch>=2.4",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user