allow `step_sizes` to be passed in directly, since specifying `step_sizes_log2` is not that intuitive

This commit is contained in:
lucidrains 2025-10-07 08:36:46 -07:00
parent a8e14f4b7c
commit 36ccb08500
4 changed files with 41 additions and 12 deletions

View File

@ -14,8 +14,9 @@ jobs:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[test]
python -m pip install uv
python -m uv pip install --upgrade pip
python -m uv pip install -e .[test]
- name: Test with pytest
run: |
python -m pytest tests/

View File

@ -1264,7 +1264,6 @@ class DynamicsModel(Module):
# derive step size
step_size = self.max_steps // num_steps
step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)
# denoising
# teacher forcing to start with
@ -1288,7 +1287,8 @@ class DynamicsModel(Module):
pred = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes_log2 = step_size_log2,
step_sizes = step_size,
latent_is_noised = True,
return_pred_only = True
)
@ -1330,10 +1330,12 @@ class DynamicsModel(Module):
video = None,
latents = None, # (b t n d) | (b t d)
signal_levels = None, # () | (b) | (b t)
step_sizes = None, # () | (b)
step_sizes_log2 = None, # () | (b)
tasks = None, # (b)
rewards = None, # (b t)
return_pred_only = False,
latent_is_noised = False,
return_all_losses = False
):
# handle video or latents
@ -1354,24 +1356,50 @@ class DynamicsModel(Module):
batch, time, device = *latents.shape[:2], latents.device
# shape related
# signal and step size related input conforming
if exists(signal_levels):
if isinstance(signal_levels, int):
signal_levels = tensor(signal_levels, device = self.device)
if signal_levels.ndim == 0:
signal_levels = repeat(signal_levels, '-> b', b = batch)
if signal_levels.ndim == 1:
signal_levels = repeat(signal_levels, 'b -> b t', t = time)
if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
if exists(step_sizes):
if isinstance(step_sizes, int):
step_sizes = tensor(step_sizes, device = self.device)
if step_sizes.ndim == 0:
step_sizes = repeat(step_sizes, '-> b', b = batch)
if exists(step_sizes_log2):
if isinstance(step_sizes_log2, int):
step_sizes_log2 = tensor(step_sizes_log2, device = self.device)
if step_sizes_log2.ndim == 0:
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
# handle step sizes -> step size log2
assert not (exists(step_sizes) and exists(step_sizes_log2))
if exists(step_sizes):
step_sizes_log2_maybe_float = torch.log2(step_sizes)
step_sizes_log2 = step_sizes_log2_maybe_float.long()
assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), f'`step_sizes` must be powers of 2'
# flow related
assert not (exists(signal_levels) ^ exists(step_sizes_log2))
is_inference = exists(signal_levels)
return_pred_only = is_inference
no_shortcut_train = not is_inference
return_pred_only = return_pred_only or latent_is_noised
# if neither signal levels nor step sizes are passed in, assume training
# generate them randomly for training
@ -1399,7 +1427,7 @@ class DynamicsModel(Module):
times = self.get_times_from_signal_level(signal_levels, latents)
if not is_inference:
if not latent_is_noised:
# get the noise
noise = torch.randn_like(latents)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.2"
version = "0.0.3"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -67,8 +67,8 @@ def test_e2e(
signal_levels = step_sizes_log2 = None
if signal_and_step_passed_in:
signal_levels = torch.randint(0, 64, (2, 4))
step_sizes_log2 = torch.randint(1, 6, (2,))
signal_levels = torch.randint(0, 32, (2, 4))
step_sizes_log2 = torch.randint(1, 5, (2,))
if dynamics_with_video_input:
dynamics_input = dict(video = video)