diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c152ada..972833b 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1,8 +1,22 @@ +from __future__ import annotations + +import math +from functools import partial + import torch import torch.nn.functional as F -from torch.nn import Module, ModuleList, RMSNorm, Identity +from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity from torch import cat, stack, tensor, Tensor, is_tensor +# ein related + +from einops import einsum, rearrange, repeat, reduce +from einops.layers.torch import Rearrange + +# constants + +LinearNoBias = partial(Linear, bias = False) + # helpers def exists(v):