From c4e0f46528d0a5ef9e1b58a9a743f387e2886630 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 8 Oct 2025 07:37:34 -0700
Subject: [PATCH] for the value head, we will go for symexp encoding as well
 (following the "stop regressing" paper from Farebrother et al), also use
 layernormed mlp given recent papers

---
 dreamer4/__init__.py  |  2 +-
 dreamer4/dreamer4.py  | 19 +++++++++++++++----
 tests/test_dreamer.py |  4 ++--
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/dreamer4/__init__.py b/dreamer4/__init__.py
index 59a245a..eeb1f2c 100644
--- a/dreamer4/__init__.py
+++ b/dreamer4/__init__.py
@@ -1,5 +1,5 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel,
+    DynamicsWorldModel,
     Dreamer
 )
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 9443e4b..9023cac 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -14,7 +14,7 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
 
-from x_mlps_pytorch import create_mlp
+from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
 from assoc_scan import AssocScan
@@ -176,17 +176,18 @@ def ramp_weight(times, slope = 0.9, intercept = 0.1):
 class SymExpTwoHot(Module):
     def __init__(
         self,
-        range = (-20., 20.),
+        reward_range = (-20., 20.),
         num_bins = 255,
         learned_embedding = False,
         dim_embed = None,
     ):
         super().__init__()
-        min_value, max_value = range
+        min_value, max_value = reward_range
 
         values = linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
 
+        self.reward_range = reward_range
         self.num_bins = num_bins
 
         self.register_buffer('bin_values', values)
@@ -1061,7 +1062,7 @@ class VideoTokenizer(Module):
 
 # dynamics model, axial space-time transformer
 
-class DynamicsModel(Module):
+class DynamicsWorldModel(Module):
     def __init__(
         self,
         dim,
@@ -1088,6 +1089,7 @@ class DynamicsModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
+        value_head_mlp_depth = 3
     ):
         super().__init__()
 
@@ -1190,6 +1192,15 @@ class DynamicsModel(Module):
 
         self.reward_loss_weight = reward_loss_weight
 
+        # value head
+
+        self.value_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = self.reward_encoder.num_bins,
+            depth = value_head_mlp_depth,
+        )
+
         # attention
 
         self.attn_softclamp_value = attn_softclamp_value
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 671f6dc..9e9e01e 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -20,7 +20,7 @@ def test_e2e(
     signal_and_step_passed_in,
     add_reward_embed_to_agent_token
 ):
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
     tokenizer = VideoTokenizer(
         16,
@@ -45,7 +45,7 @@ def test_e2e(
 
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
-    dynamics = DynamicsModel(
+    dynamics = DynamicsWorldModel(
         dim = 16,
         video_tokenizer = tokenizer,
         dim_latent = 16,
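
For reference, the symexp two-hot scheme named in the subject (from Farebrother et al., "Stop Regressing") works roughly as sketched below, which is why the value head's dim_out is set to reward_encoder.num_bins: bin values are spaced evenly in symlog space and mapped through symexp (sign(x) * (exp(|x|) - 1), matching the bin_values buffer in SymExpTwoHot above), a scalar target is spread as a convex weight over its two neighboring bins, and a prediction is decoded as the expectation of bin values under the predicted distribution. This is a minimal standalone sketch, not part of the patch; the helper names below are illustrative and are not the repository's API.

# illustrative sketch of symexp two-hot encode / decode (assumed helpers, not repo code)

import torch

def make_symexp_bins(reward_range = (-20., 20.), num_bins = 255):
    # linspace in symlog space, then map through symexp: sign(x) * (exp(|x|) - 1)
    lo, hi = reward_range
    values = torch.linspace(lo, hi, num_bins)
    return values.sign() * (torch.exp(values.abs()) - 1.)

def two_hot_encode(target, bin_values):
    # spread each scalar target as a convex weight over its two neighboring bins
    target = target.clamp(bin_values[0].item(), bin_values[-1].item())
    index = torch.searchsorted(bin_values, target).clamp(1, len(bin_values) - 1)
    left, right = bin_values[index - 1], bin_values[index]
    weight_right = (target - left) / (right - left)
    encoding = torch.zeros(*target.shape, len(bin_values))
    encoding.scatter_(-1, (index - 1).unsqueeze(-1), (1. - weight_right).unsqueeze(-1))
    encoding.scatter_(-1, index.unsqueeze(-1), weight_right.unsqueeze(-1))
    return encoding

def two_hot_decode(logits, bin_values):
    # expected bin value under the predicted categorical distribution
    return (logits.softmax(dim = -1) * bin_values).sum(dim = -1)

if __name__ == '__main__':
    bins = make_symexp_bins()
    target = torch.tensor([3.7, -120.0])
    probs = two_hot_encode(target, bins)
    recon = two_hot_decode(torch.log(probs + 1e-8), bins)  # near-exact round trip
    print(recon)

In training, the value head's logits would be scored with a cross-entropy against such a two-hot target rather than a scalar regression loss, which is the "classification instead of regression" point the commit message is gesturing at.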