diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index f21027b..fff3b4f 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -896,7 +896,7 @@ class VideoTokenizer(Module): # modality can only attend to itself while latents can attend to everything # similar to agent token in dynamics model - encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True) + encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = False) # encoder @@ -937,7 +937,7 @@ class VideoTokenizer(Module): # decoder attend - decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len) + decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True) # decoder attention