readme
commit 38cba80068
parent c0a6cd56a1

README.md | 27 +++++++++++++++++++++++----
@@ -28,9 +28,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward()
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -46,7 +53,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -58,7 +65,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -67,7 +74,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
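For readability, here is the README usage example as it stands after this commit, assembled from the four hunks above. The top-level import path, the elided `VideoTokenizer` and `DynamicsWorldModel` constructor arguments, and the construction of `discrete_actions` are not visible in this diff and are assumptions; treat this as a sketch of the post-commit snippet, not the verbatim file.

```python
import torch

# assumed import path; the diff only shows `from dreamer4.mocks import MockEnv`
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv

tokenizer = VideoTokenizer(
    # ... constructor arguments elided in the diff ...
    image_width = 256
)

video = torch.randn(2, 3, 10, 256, 256)

# learn the tokenizer

loss = tokenizer(video)
loss.backward()

# dynamics world model

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    # ... constructor arguments elided in the diff ...
)

rewards = torch.randn(2, 10)
discrete_actions = torch.randint(0, 4, (2, 10))  # assumed shape and action count, not shown in the diff

# learn dynamics / behavior cloned model

loss = world_model(
    video = video,
    rewards = rewards,
    discrete_actions = discrete_actions
)

loss.backward()

# then generate dreams

dreams = world_model.generate(
    10,
    batch_size = 2,
    return_decoded_video = True,
    # ... arguments elided in the diff ...
)

# learn from the dreams

actor_loss, critic_loss = world_model.learn_from_experience(dreams)

(actor_loss + critic_loss).backward()

# learn from environment

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)

actor_loss, critic_loss = world_model.learn_from_experience(experience)

(actor_loss + critic_loss).backward()
```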