From ca700ba8e169050cf9c6d6330fac5623e56f8c53 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 4 Oct 2025 09:44:46 -0700
Subject: [PATCH] prepare for the learning in dreams

---
 dreamer4/dreamer4.py | 33 ++++++++++++++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index b058962..ef4b058 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -16,6 +16,8 @@ from torchvision.models import VGG16_Weights
 from x_mlps_pytorch import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
+from assoc_scan import AssocScan
+
 from accelerate import Accelerator
 
 # ein related
@@ -224,6 +226,34 @@ class SymExpTwoHot(Module):
 
         return inverse_pack(encoded, '* l')
 
+# generalized advantage estimate
+
+@torch.no_grad()
+def calc_gae(
+    rewards,
+    values,
+    masks,
+    gamma = 0.99,
+    lam = 0.95,
+    use_accelerated = None
+):
+    assert values.shape[-1] == rewards.shape[-1]
+    use_accelerated = default(use_accelerated, rewards.is_cuda)
+
+    values = F.pad(values, (0, 1), value = 0.)
+    values, values_next = values[..., :-1], values[..., 1:]
+
+    delta = rewards + gamma * values_next * masks - values
+    gates = gamma * lam * masks
+
+    scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+    gae = scan(gates, delta)
+
+    returns = gae + values
+
+    return returns
+
 # golden gate rotary - Jerry Xiong, PhD student at UIUC
 # https://jerryxio.ng/posts/nd-rope/
 
@@ -1099,6 +1129,7 @@ class Dreamer(Module):
     def __init__(
         self,
         video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel
+        dynamics_model: DynamicsModel,
+        discount_factor = 0.9995
     ):
         super().__init__()