diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 07643b7..be2aa43 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -291,16 +291,29 @@ class SwiGLUFeedforward(Module): class VideoTokenizer(Module): def __init__( - self + self, + dim, + dim_latent ): super().__init__() + self.encoded_to_latents = Sequential( + LinearNoBias(dim, dim_latent), + nn.Tanh(), + ) + + self.latents_to_decoder = LinearNoBias(dim_latent, dim) + +# dynamics model + class DynamicsModel(Module): def __init__( self ): super().__init__() +# dreamer + class Dreamer(Module): def __init__( self,