for the value head, go with symexp two-hot encoding as well (following the "Stop Regressing" paper from Farebrother et al.), and also use a layernormed MLP, given recent papers

lucidrains 2025-10-08 07:37:34 -07:00
parent a50e360502
commit c4e0f46528
3 changed files with 18 additions and 7 deletions
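
For reference, a minimal sketch of the symexp two-hot scheme the commit message points at (per Farebrother et al., "Stop Regressing: Training Value Functions via Classification for Scalable Deep RL"). Illustrative only, with hypothetical helper names (symexp, two_hot_encode), not the repository's SymExpTwoHot:

import torch

def symexp(x):
    # inverse of symlog(x) = sign(x) * log(1 + |x|)
    return x.sign() * (torch.exp(x.abs()) - 1.)

def two_hot_encode(value, bin_values):
    # value: (batch,), bin_values: (num_bins,) sorted ascending
    # a scalar target becomes a distribution with mass split between its two
    # nearest bins, so a categorical (cross entropy) head replaces regression
    value = value.clamp(bin_values[0], bin_values[-1])
    idx_hi = torch.searchsorted(bin_values, value).clamp(1, len(bin_values) - 1)
    idx_lo = idx_hi - 1
    lo, hi = bin_values[idx_lo], bin_values[idx_hi]
    weight_hi = (value - lo) / (hi - lo)   # linear interpolation between neighbors
    target = torch.zeros(*value.shape, len(bin_values))
    target.scatter_(-1, idx_lo.unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    target.scatter_(-1, idx_hi.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return target

# bins spaced by symexp, so resolution is finest near zero, as in the diff below
bin_values = symexp(torch.linspace(-20., 20., 255))

targets = two_hot_encode(torch.tensor([0.5, -3.]), bin_values)
assert torch.allclose(targets @ bin_values, torch.tensor([0.5, -3.]), atol = 1e-4)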

@@ -1,5 +1,5 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel,
+    DynamicsWorldModel,
     Dreamer
 )

@@ -14,7 +14,7 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
-from x_mlps_pytorch import create_mlp
+from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 from assoc_scan import AssocScan
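
The swapped import above pulls create_mlp from the normed_mlp module. As a rough idea of what a layernormed MLP looks like, assuming a pre-LayerNorm before each linear (the actual x_mlps_pytorch implementation may differ in structure and options):

import torch
from torch import nn

def create_normed_mlp(dim_in, dim, dim_out, depth):
    # layernorm before every linear, as advocated in recent MLP papers
    layers = []
    dims = [dim_in, *([dim] * depth)]
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.LayerNorm(d_in), nn.Linear(d_in, d_out), nn.SiLU()]
    layers += [nn.LayerNorm(dims[-1]), nn.Linear(dims[-1], dim_out)]
    return nn.Sequential(*layers)

mlp = create_normed_mlp(16, dim = 64, dim_out = 255, depth = 3)
assert mlp(torch.randn(2, 16)).shape == (2, 255)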
@@ -176,17 +176,18 @@ def ramp_weight(times, slope = 0.9, intercept = 0.1):
 class SymExpTwoHot(Module):
     def __init__(
         self,
-        range = (-20., 20.),
+        reward_range = (-20., 20.),
         num_bins = 255,
         learned_embedding = False,
         dim_embed = None,
     ):
         super().__init__()
-        min_value, max_value = range
+        min_value, max_value = reward_range
         values = linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
+        self.reward_range = reward_range
         self.num_bins = num_bins
         self.register_buffer('bin_values', values)
@@ -1061,7 +1062,7 @@ class VideoTokenizer(Module):
 # dynamics model, axial space-time transformer
-class DynamicsModel(Module):
+class DynamicsWorldModel(Module):
     def __init__(
         self,
         dim,
@@ -1088,6 +1089,7 @@ class DynamicsModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
+        value_head_mlp_depth = 3
     ):
         super().__init__()
@@ -1190,6 +1192,15 @@ class DynamicsModel(Module):
         self.reward_loss_weight = reward_loss_weight

+        # value head
+
+        self.value_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = self.reward_encoder.num_bins,
+            depth = value_head_mlp_depth,
+        )
+
         # attention

         self.attn_softclamp_value = attn_softclamp_value
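
The new value head emits logits over the reward encoder's bins rather than a raw scalar. A hedged sketch of the usual decoding for such a categorical head (not shown in this diff): take the softmax-weighted expectation over the symexp-spaced bin values:

import torch

num_bins = 255
bins = torch.linspace(-20., 20., num_bins)
bins = bins.sign() * (torch.exp(bins.abs()) - 1.)   # symexp, matching SymExpTwoHot above

logits = torch.randn(2, num_bins)                   # stand-in for value head output
values = (logits.softmax(dim = -1) * bins).sum(dim = -1)   # (batch,) predicted scalar values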

@@ -20,7 +20,7 @@ def test_e2e(
     signal_and_step_passed_in,
     add_reward_embed_to_agent_token
 ):
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
     tokenizer = VideoTokenizer(
         16,
@@ -45,7 +45,7 @@ def test_e2e(
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
-    dynamics = DynamicsModel(
+    dynamics = DynamicsWorldModel(
         dim = 16,
         video_tokenizer = tokenizer,
         dim_latent = 16,