for the value head, go with symexp two-hot encoding as well (following the "stop regressing" paper from Farebrother et al.), and also use a layernormed mlp, given recent papers

lucidrains 2025-10-08 07:37:34 -07:00
parent a50e360502
commit c4e0f46528
3 changed files with 18 additions and 7 deletions
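For context, the symexp two-hot scheme from Farebrother et al. spaces bin edges evenly in symlog space, maps them back with symexp, spreads a scalar target over its two neighboring bins, and recovers a scalar as the expectation over the bin values. Below is a minimal sketch of that idea, assuming the bin construction shown in the diff; it is illustrative and not the repo's exact `SymExpTwoHot` code, and the helper names are made up here.

import torch

# symexp-spaced bin values, mirroring the construction in the diff below:
# evenly spaced in symlog space, mapped back with sign(x) * (exp(|x|) - 1)

num_bins = 255
min_value, max_value = -20., 20.

values = torch.linspace(min_value, max_value, num_bins)
bin_values = values.sign() * (torch.exp(values.abs()) - 1.)

def two_hot(target):
    # spread a scalar target over its two neighboring bins
    target = torch.as_tensor(target).clamp(bin_values[0], bin_values[-1])
    idx = torch.searchsorted(bin_values, target.reshape(1)).clamp(1, num_bins - 1)
    left, right = bin_values[idx - 1], bin_values[idx]
    weight_right = (target - left) / (right - left)
    out = torch.zeros(num_bins)
    out[idx - 1] = 1. - weight_right
    out[idx] = weight_right
    return out

def decode(probs):
    # expectation over the bin values recovers the scalar
    return (probs * bin_values).sum()

print(decode(two_hot(3.7)))  # ≈ tensor(3.7000)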

View File

@@ -1,5 +1,5 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel,
+    DynamicsWorldModel,
     Dreamer
 )

View File

@@ -14,7 +14,7 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
-from x_mlps_pytorch import create_mlp
+from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 from assoc_scan import AssocScan
@@ -176,17 +176,18 @@ def ramp_weight(times, slope = 0.9, intercept = 0.1):
 class SymExpTwoHot(Module):
     def __init__(
         self,
-        range = (-20., 20.),
+        reward_range = (-20., 20.),
         num_bins = 255,
         learned_embedding = False,
         dim_embed = None,
     ):
         super().__init__()

-        min_value, max_value = range
+        min_value, max_value = reward_range
         values = linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)

+        self.reward_range = reward_range
         self.num_bins = num_bins
         self.register_buffer('bin_values', values)
@@ -1061,7 +1062,7 @@ class VideoTokenizer(Module):
 # dynamics model, axial space-time transformer

-class DynamicsModel(Module):
+class DynamicsWorldModel(Module):
     def __init__(
         self,
         dim,
@@ -1088,6 +1089,7 @@ class DynamicsModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
+        value_head_mlp_depth = 3
     ):
         super().__init__()
@@ -1190,6 +1192,15 @@ class DynamicsModel(Module):
         self.reward_loss_weight = reward_loss_weight

+        # value head
+
+        self.value_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = self.reward_encoder.num_bins,
+            depth = value_head_mlp_depth,
+        )
+
         # attention

         self.attn_softclamp_value = attn_softclamp_value
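Under the "stop regressing" recipe, a value head like this is trained as a classifier over the symexp bins rather than as a scalar regressor. The following is a hedged sketch of how the loss and readout could look, reusing `bin_values` and `two_hot` from the sketch above; `value_head`, `agent_embed`, and `returns` are stand-ins defined here for illustration, not the repo's API (the real head in the diff is built with `create_mlp` from `x_mlps_pytorch.normed_mlp`).

import torch
import torch.nn.functional as F

# stand-ins for illustration only: a tiny value head and fake inputs
dim, batch = 16, 4
value_head = torch.nn.Sequential(
    torch.nn.Linear(dim, dim * 4),
    torch.nn.GELU(),
    torch.nn.Linear(dim * 4, num_bins)
)
agent_embed = torch.randn(batch, dim)
returns = torch.randn(batch)

logits = value_head(agent_embed)                            # (batch, num_bins)

# train as classification against two-hot encoded return targets ("stop regressing")
target_probs = torch.stack([two_hot(r) for r in returns])   # (batch, num_bins)
value_loss = F.cross_entropy(logits, target_probs)

# read out scalar values as the expectation over the symexp bin values
values = (logits.softmax(dim = -1) * bin_values).sum(dim = -1)   # (batch,)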

View File

@@ -20,7 +20,7 @@ def test_e2e(
     signal_and_step_passed_in,
     add_reward_embed_to_agent_token
 ):
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

     tokenizer = VideoTokenizer(
         16,
@@ -45,7 +45,7 @@ def test_e2e(
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

-    dynamics = DynamicsModel(
+    dynamics = DynamicsWorldModel(
         dim = 16,
         video_tokenizer = tokenizer,
         dim_latent = 16,