allow `step_sizes` to be passed in, since log2 is not that intuitive
parent a8e14f4b7c
commit 36ccb08500
.github/workflows/test.yml (vendored): 5 changed lines

@@ -14,8 +14,9 @@ jobs:
         python-version: "3.10"
     - name: Install dependencies
       run: |
-        python -m pip install --upgrade pip
-        python -m pip install -e .[test]
+        python -m pip install uv
+        python -m uv pip install --upgrade pip
+        python -m uv pip install -e .[test]
     - name: Test with pytest
       run: |
         python -m pytest tests/
@@ -1264,7 +1264,6 @@ class DynamicsModel(Module):
         # derive step size

         step_size = self.max_steps // num_steps
-        step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)

         # denoising
         # teacher forcing to start with
@@ -1288,7 +1287,8 @@ class DynamicsModel(Module):
             pred = self.forward(
                 latents = noised_latent_with_context,
                 signal_levels = signal_levels_with_context,
-                step_sizes_log2 = step_size_log2,
+                step_sizes = step_size,
+                latent_is_noised = True,
                 return_pred_only = True
             )

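The sampling path above now hands the raw integer `step_size` (derived earlier as `self.max_steps // num_steps`) to `forward`, which performs the log2 encoding itself. A minimal sketch of that derivation, with `max_steps` and `num_steps` as hypothetical values (not taken from the repo):

# hypothetical values, only to illustrate the derivation used by the sampling loop
max_steps = 64
num_steps = 16

step_size = max_steps // num_steps          # 4
assert step_size & (step_size - 1) == 0     # forward() asserts the step size is a power of 2

print(step_size)                            # prints 4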
@@ -1330,10 +1330,12 @@ class DynamicsModel(Module):
         video = None,
         latents = None,            # (b t n d) | (b t d)
         signal_levels = None,      # () | (b) | (b t)
+        step_sizes = None,         # () | (b)
         step_sizes_log2 = None,    # () | (b)
         tasks = None,              # (b)
         rewards = None,            # (b t)
         return_pred_only = False,
+        latent_is_noised = False,
         return_all_losses = False
     ):
         # handle video or latents
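With the expanded signature, callers can pass either `step_sizes` or `step_sizes_log2` (but not both, per the assertion added in the next hunk); the two encode the same quantity, since `step_sizes == 2 ** step_sizes_log2`. A quick illustration with hypothetical per-sample values, not a call into the model:

import torch

# hypothetical per-sample values; both tensors describe the same shortcut step size
step_sizes      = torch.tensor([4, 8])
step_sizes_log2 = torch.tensor([2, 3])

assert torch.equal(step_sizes, 2 ** step_sizes_log2)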
@@ -1354,24 +1356,50 @@ class DynamicsModel(Module):

         batch, time, device = *latents.shape[:2], latents.device

-        # shape related
+        # signal and step size related input conforming

         if exists(signal_levels):
+            if isinstance(signal_levels, int):
+                signal_levels = tensor(signal_levels, device = self.device)
+
             if signal_levels.ndim == 0:
                 signal_levels = repeat(signal_levels, '-> b', b = batch)

             if signal_levels.ndim == 1:
                 signal_levels = repeat(signal_levels, 'b -> b t', t = time)

-        if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
-            step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+        if exists(step_sizes):
+            if isinstance(step_sizes, int):
+                step_sizes = tensor(step_sizes, device = self.device)
+
+            if step_sizes.ndim == 0:
+                step_sizes = repeat(step_sizes, '-> b', b = batch)
+
+        if exists(step_sizes_log2):
+            if isinstance(step_sizes_log2, int):
+                step_sizes_log2 = tensor(step_sizes_log2, device = self.device)
+
+            if step_sizes_log2.ndim == 0:
+                step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+
+        # handle step sizes -> step size log2
+
+        assert not (exists(step_sizes) and exists(step_sizes_log2))
+
+        if exists(step_sizes):
+            step_sizes_log2_maybe_float = torch.log2(step_sizes)
+            step_sizes_log2 = step_sizes_log2_maybe_float.long()
+
+            assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), f'`step_sizes` must be powers of 2'

         # flow related

         assert not (exists(signal_levels) ^ exists(step_sizes_log2))

         is_inference = exists(signal_levels)
-        return_pred_only = is_inference
+        no_shortcut_train = not is_inference

+        return_pred_only = return_pred_only or latent_is_noised
+
         # if neither signal levels or step sizes passed in, assume training
         # generate them randomly for training
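For reference, the conversion added above can be read as the following standalone sketch; `conform_step_sizes` is a hypothetical helper (not part of the repo) that mirrors the int / scalar-tensor conforming and the power-of-2 check:

import torch
from einops import repeat

def conform_step_sizes(step_sizes, batch, device = None):
    # hypothetical helper mirroring the logic added in the hunk above, not an API of the repo
    if isinstance(step_sizes, int):
        step_sizes = torch.tensor(step_sizes, device = device)

    if step_sizes.ndim == 0:
        step_sizes = repeat(step_sizes, '-> b', b = batch)

    # convert to the log2 encoding the model uses internally
    step_sizes_log2_maybe_float = torch.log2(step_sizes)
    step_sizes_log2 = step_sizes_log2_maybe_float.long()

    # .long() truncates, so a non power of 2 (say 6, log2 ~ 2.58) would silently become 4 without this check
    assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), '`step_sizes` must be powers of 2'

    return step_sizes_log2

print(conform_step_sizes(8, batch = 2))   # tensor([3, 3])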
@@ -1399,7 +1427,7 @@ class DynamicsModel(Module):

         times = self.get_times_from_signal_level(signal_levels, latents)

-        if not is_inference:
+        if not latent_is_noised:
             # get the noise

             noise = torch.randn_like(latents)
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.2"
+version = "0.0.3"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -67,8 +67,8 @@ def test_e2e(
     signal_levels = step_sizes_log2 = None

     if signal_and_step_passed_in:
-        signal_levels = torch.randint(0, 64, (2, 4))
-        step_sizes_log2 = torch.randint(1, 6, (2,))
+        signal_levels = torch.randint(0, 32, (2, 4))
+        step_sizes_log2 = torch.randint(1, 5, (2,))

     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
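The test now draws per-frame signal levels in [0, 32) with shape (batch = 2, time = 4) and log2 step sizes in [1, 5) with shape (batch = 2). A small illustrative sketch (names and printout are just for demonstration) of how those draws relate to the plain step sizes that can now be passed instead:

import torch

# signal levels per frame: shape (batch, time); shortcut step sizes per sample: shape (batch,)
# log2 values 1..4 correspond to step sizes 2..16
signal_levels   = torch.randint(0, 32, (2, 4))
step_sizes_log2 = torch.randint(1, 5, (2,))
step_sizes      = 2 ** step_sizes_log2    # the form that can now be passed in instead

print(signal_levels.shape, step_sizes_log2.shape, step_sizes)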