readme
This commit is contained in:
parent
c0a6cd56a1
commit
38cba80068
27
README.md
27
README.md
@ -28,9 +28,16 @@ tokenizer = VideoTokenizer(
|
||||
image_width = 256
|
||||
)
|
||||
|
||||
video = torch.randn(2, 3, 10, 256, 256)
|
||||
|
||||
# learn the tokenizer
|
||||
|
||||
loss = tokenizer(video)
|
||||
loss.backward() # ler
|
||||
|
||||
# dynamics world model
|
||||
|
||||
dynamics = DynamicsWorldModel(
|
||||
world_model = DynamicsWorldModel(
|
||||
dim = 512,
|
||||
dim_latent = 32,
|
||||
video_tokenizer = tokenizer,
|
||||
@ -46,7 +53,7 @@ rewards = torch.randn(2, 10)
|
||||
|
||||
# learn dynamics / behavior cloned model
|
||||
|
||||
loss = dynamics(
|
||||
loss = world_model(
|
||||
video = video,
|
||||
rewards = rewards,
|
||||
discrete_actions = discrete_actions
|
||||
@ -58,7 +65,7 @@ loss.backward()
|
||||
|
||||
# then generate dreams
|
||||
|
||||
dreams = dynamics.generate(
|
||||
dreams = world_model.generate(
|
||||
10,
|
||||
batch_size = 2,
|
||||
return_decoded_video = True,
|
||||
@ -67,7 +74,19 @@ dreams = dynamics.generate(
|
||||
|
||||
# learn from the dreams
|
||||
|
||||
actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
|
||||
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
|
||||
|
||||
(actor_loss + critic_loss).backward()
|
||||
|
||||
# learn from environment
|
||||
|
||||
from dreamer4.mocks import MockEnv
|
||||
|
||||
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
|
||||
|
||||
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
|
||||
|
||||
actor_loss, critic_loss = world_model.learn_from_experience(experience)
|
||||
|
||||
(actor_loss + critic_loss).backward()
|
||||
```
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user