allow for step_sizes to be passed in, log2 is not that intuitive
This commit is contained in:
parent
a8e14f4b7c
commit
36ccb08500
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@ -14,8 +14,9 @@ jobs:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -e .[test]
|
||||
python -m pip install uv
|
||||
python -m uv pip install --upgrade pip
|
||||
python -m uv pip install -e .[test]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python -m pytest tests/
|
||||
|
||||
@ -1264,7 +1264,6 @@ class DynamicsModel(Module):
|
||||
# derive step size
|
||||
|
||||
step_size = self.max_steps // num_steps
|
||||
step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)
|
||||
|
||||
# denoising
|
||||
# teacher forcing to start with
|
||||
@ -1288,7 +1287,8 @@ class DynamicsModel(Module):
|
||||
pred = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes_log2 = step_size_log2,
|
||||
step_sizes = step_size,
|
||||
latent_is_noised = True,
|
||||
return_pred_only = True
|
||||
)
|
||||
|
||||
@ -1330,10 +1330,12 @@ class DynamicsModel(Module):
|
||||
video = None,
|
||||
latents = None, # (b t n d) | (b t d)
|
||||
signal_levels = None, # () | (b) | (b t)
|
||||
step_sizes = None, # () | (b)
|
||||
step_sizes_log2 = None, # () | (b)
|
||||
tasks = None, # (b)
|
||||
rewards = None, # (b t)
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False
|
||||
):
|
||||
# handle video or latents
|
||||
@ -1354,24 +1356,50 @@ class DynamicsModel(Module):
|
||||
|
||||
batch, time, device = *latents.shape[:2], latents.device
|
||||
|
||||
# shape related
|
||||
# signal and step size related input conforming
|
||||
|
||||
if exists(signal_levels):
|
||||
if isinstance(signal_levels, int):
|
||||
signal_levels = tensor(signal_levels, device = self.device)
|
||||
|
||||
if signal_levels.ndim == 0:
|
||||
signal_levels = repeat(signal_levels, '-> b', b = batch)
|
||||
|
||||
if signal_levels.ndim == 1:
|
||||
signal_levels = repeat(signal_levels, 'b -> b t', t = time)
|
||||
|
||||
if exists(step_sizes_log2) and step_sizes_log2.ndim == 0:
|
||||
if exists(step_sizes):
|
||||
if isinstance(step_sizes, int):
|
||||
step_sizes = tensor(step_sizes, device = self.device)
|
||||
|
||||
if step_sizes.ndim == 0:
|
||||
step_sizes = repeat(step_sizes, '-> b', b = batch)
|
||||
|
||||
if exists(step_sizes_log2):
|
||||
if isinstance(step_sizes_log2, int):
|
||||
step_sizes_log2 = tensor(step_sizes_log2, device = self.device)
|
||||
|
||||
if step_sizes_log2.ndim == 0:
|
||||
step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
|
||||
|
||||
# handle step sizes -> step size log2
|
||||
|
||||
assert not (exists(step_sizes) and exists(step_sizes_log2))
|
||||
|
||||
if exists(step_sizes):
|
||||
step_sizes_log2_maybe_float = torch.log2(step_sizes)
|
||||
step_sizes_log2 = step_sizes_log2_maybe_float.long()
|
||||
|
||||
assert (step_sizes_log2 == step_sizes_log2_maybe_float).all(), f'`step_sizes` must be powers of 2'
|
||||
|
||||
# flow related
|
||||
|
||||
assert not (exists(signal_levels) ^ exists(step_sizes_log2))
|
||||
|
||||
is_inference = exists(signal_levels)
|
||||
return_pred_only = is_inference
|
||||
no_shortcut_train = not is_inference
|
||||
|
||||
return_pred_only = return_pred_only or latent_is_noised
|
||||
|
||||
# if neither signal levels or step sizes passed in, assume training
|
||||
# generate them randomly for training
|
||||
@ -1399,7 +1427,7 @@ class DynamicsModel(Module):
|
||||
|
||||
times = self.get_times_from_signal_level(signal_levels, latents)
|
||||
|
||||
if not is_inference:
|
||||
if not latent_is_noised:
|
||||
# get the noise
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -67,8 +67,8 @@ def test_e2e(
|
||||
signal_levels = step_sizes_log2 = None
|
||||
|
||||
if signal_and_step_passed_in:
|
||||
signal_levels = torch.randint(0, 64, (2, 4))
|
||||
step_sizes_log2 = torch.randint(1, 6, (2,))
|
||||
signal_levels = torch.randint(0, 32, (2, 4))
|
||||
step_sizes_log2 = torch.randint(1, 5, (2,))
|
||||
|
||||
if dynamics_with_video_input:
|
||||
dynamics_input = dict(video = video)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user