prepare for the learning in dreams

This commit is contained in:
lucidrains 2025-10-04 09:44:46 -07:00
parent e04f9ffec6
commit ca700ba8e1

View File

@ -16,6 +16,8 @@ from torchvision.models import VGG16_Weights
from x_mlps_pytorch import create_mlp
from x_mlps_pytorch.ensemble import Ensemble
from assoc_scan import AssocScan
from accelerate import Accelerator
# ein related
@ -224,6 +226,34 @@ class SymExpTwoHot(Module):
return inverse_pack(encoded, '* l')
# generalized advantage estimate
@torch.no_grad()
def calc_gae(
rewards,
values,
masks,
gamma = 0.99,
lam = 0.95,
use_accelerated = None
):
assert values.shape[-1] == rewards.shape[-1]
use_accelerated = default(use_accelerated, rewards.is_cuda)
values = F.pad(values, (0, 1), value = 0.)
values, values_next = values[..., :-1], values[..., 1:]
delta = rewards + gamma * values_next * masks - values
gates = gamma * lam * masks
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
gae = scan(gates, delta)
returns = gae + values
return returns
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
@ -1099,6 +1129,7 @@ class Dreamer(Module):
def __init__(
self,
video_tokenizer: VideoTokenizer,
dynamics_model: DynamicsModel
dynamics_model: DynamicsModel,
discount_factor = 0.9995
):
super().__init__()