This commit is contained in:
lucidrains 2025-11-04 06:05:11 -08:00
parent c0a6cd56a1
commit 38cba80068

View File

@ -28,9 +28,16 @@ tokenizer = VideoTokenizer(
image_width = 256
)
video = torch.randn(2, 3, 10, 256, 256)
# learn the tokenizer
loss = tokenizer(video)
loss.backward() # ler
# dynamics world model
dynamics = DynamicsWorldModel(
world_model = DynamicsWorldModel(
dim = 512,
dim_latent = 32,
video_tokenizer = tokenizer,
@ -46,7 +53,7 @@ rewards = torch.randn(2, 10)
# learn dynamics / behavior cloned model
loss = dynamics(
loss = world_model(
video = video,
rewards = rewards,
discrete_actions = discrete_actions
@ -58,7 +65,7 @@ loss.backward()
# then generate dreams
dreams = dynamics.generate(
dreams = world_model.generate(
10,
batch_size = 2,
return_decoded_video = True,
@ -67,7 +74,19 @@ dreams = dynamics.generate(
# learn from the dreams
actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
# learn from environment
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```