for the value head, we will go with symexp encoding as well (following the "Stop Regressing" paper from Farebrother et al.), and also use a layernormed MLP given recent papers
parent a50e360502
commit c4e0f46528
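For reference, below is a minimal, hedged sketch of the symexp two-hot encoding the commit message refers to (after Farebrother et al.'s "Stop Regressing"): bins are spaced uniformly in symlog space and mapped back through symexp, and each scalar target becomes a two-hot distribution over its two neighboring bins. The helpers symexp, make_bins and two_hot are illustrative only, not the repository's SymExpTwoHot API.

# a minimal sketch of symexp two-hot encoding, illustrative only -
# not the repository's SymExpTwoHot module

import torch

def symexp(x):
    # inverse of symlog: sign(x) * (exp(|x|) - 1)
    return x.sign() * (torch.exp(x.abs()) - 1.)

def make_bins(reward_range = (-20., 20.), num_bins = 255):
    # bins uniform in symlog space, mapped back with symexp, so resolution
    # is finest around zero and coarsest at the extremes
    lo, hi = reward_range
    return symexp(torch.linspace(lo, hi, num_bins))

def two_hot(target, bin_values):
    # encode a batch of scalars as a distribution over bins, splitting the
    # probability mass between the two neighboring bins (linear interpolation)
    num_bins = bin_values.numel()
    idx = torch.searchsorted(bin_values, target.clamp(bin_values[0], bin_values[-1]))
    idx_hi = idx.clamp(1, num_bins - 1)
    idx_lo = idx_hi - 1
    lo, hi = bin_values[idx_lo], bin_values[idx_hi]
    weight_hi = ((target - lo) / (hi - lo)).clamp(0., 1.)
    out = torch.zeros(*target.shape, num_bins)
    out.scatter_(-1, idx_lo.unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    out.scatter_(-1, idx_hi.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return out

bins = make_bins()
targets = torch.tensor([0.5, -3., 100.])
print(two_hot(targets, bins).sum(dim = -1))   # each row sums to 1

Because the bins are symexp spaced, resolution stays high near zero while still covering very large reward and value magnitudes, which is the rationale for reusing the same encoding on the value head.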
@@ -1,5 +1,5 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel,
+    DynamicsWorldModel,
     Dreamer
 )
@@ -14,7 +14,7 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
 
-from x_mlps_pytorch import create_mlp
+from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
 from assoc_scan import AssocScan
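The swapped import above brings in the layer-normed MLP variant mentioned in the commit message. Below is a rough sketch of what such a create_mlp could build, under the assumption of a pre-LayerNorm before each linear; it is not the actual x_mlps_pytorch.normed_mlp implementation, only the keyword signature (dim_in, dim, dim_out, depth) matches how it is called later in this diff.

# assumed shape of a layernormed MLP, for illustration only -
# the real x_mlps_pytorch.normed_mlp.create_mlp may differ

import torch
from torch import nn

def create_mlp(dim_in, dim, dim_out, depth):
    # pre-layernorm before every linear keeps activations well scaled;
    # depth is taken here to mean the number of linear layers (an assumption)
    layers = []
    dims = (dim_in, *((dim,) * (depth - 1)), dim_out)

    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = i == len(dims) - 2
        layers += [nn.LayerNorm(d_in), nn.Linear(d_in, d_out)]
        if not is_last:
            layers.append(nn.SiLU())

    return nn.Sequential(*layers)

mlp = create_mlp(dim_in = 16, dim = 64, dim_out = 255, depth = 3)
print(mlp(torch.randn(2, 16)).shape)   # torch.Size([2, 255])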
@@ -176,17 +176,18 @@ def ramp_weight(times, slope = 0.9, intercept = 0.1):
 class SymExpTwoHot(Module):
     def __init__(
         self,
-        range = (-20., 20.),
+        reward_range = (-20., 20.),
         num_bins = 255,
         learned_embedding = False,
         dim_embed = None,
     ):
         super().__init__()
 
-        min_value, max_value = range
+        min_value, max_value = reward_range
         values = linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
 
+        self.reward_range = reward_range
         self.num_bins = num_bins
         self.register_buffer('bin_values', values)
 
@@ -1061,7 +1062,7 @@ class VideoTokenizer(Module):
 
 # dynamics model, axial space-time transformer
 
-class DynamicsModel(Module):
+class DynamicsWorldModel(Module):
     def __init__(
         self,
         dim,
@@ -1088,6 +1089,7 @@ class DynamicsModel(Module):
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         reward_loss_weight = 0.1,
+        value_head_mlp_depth = 3
     ):
         super().__init__()
 
@@ -1190,6 +1192,15 @@ class DynamicsModel(Module):
 
         self.reward_loss_weight = reward_loss_weight
 
+        # value head
+
+        self.value_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = self.reward_encoder.num_bins,
+            depth = value_head_mlp_depth,
+        )
+
         # attention
 
         self.attn_softclamp_value = attn_softclamp_value
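As a hedged sketch of how a bin-valued head like the one added above is typically trained and read out: cross entropy against a two-hot target, and an expected value over the bin values for the scalar readout. The tensors below are stand-ins, not code from this repository.

# illustrative training / readout pattern for a two-hot value head;
# names and values here are hypothetical, not from this repository

import torch

logits = torch.randn(2, 255)                 # value head output over num_bins
bin_values = torch.linspace(-20., 20., 255)
bin_values = bin_values.sign() * (torch.exp(bin_values.abs()) - 1.)

# training: cross entropy between the predicted bin distribution and a two-hot target
target_probs = torch.zeros(2, 255)
target_probs[:, 100] = 0.3                   # stand-in two-hot target
target_probs[:, 101] = 0.7
loss = -(target_probs * logits.log_softmax(dim = -1)).sum(dim = -1).mean()

# readout: expected value under the predicted bin distribution
value = (logits.softmax(dim = -1) * bin_values).sum(dim = -1)   # shape (2,)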
@@ -20,7 +20,7 @@ def test_e2e(
     signal_and_step_passed_in,
     add_reward_embed_to_agent_token
 ):
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
     tokenizer = VideoTokenizer(
         16,
@@ -45,7 +45,7 @@ def test_e2e(
 
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
-    dynamics = DynamicsModel(
+    dynamics = DynamicsWorldModel(
         dim = 16,
         video_tokenizer = tokenizer,
         dim_latent = 16,