prepare for the learning in dreams
This commit is contained in:
parent
e04f9ffec6
commit
ca700ba8e1
@ -16,6 +16,8 @@ from torchvision.models import VGG16_Weights
|
||||
from x_mlps_pytorch import create_mlp
|
||||
from x_mlps_pytorch.ensemble import Ensemble
|
||||
|
||||
from assoc_scan import AssocScan
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# ein related
|
||||
@ -224,6 +226,34 @@ class SymExpTwoHot(Module):
|
||||
|
||||
return inverse_pack(encoded, '* l')
|
||||
|
||||
# generalized advantage estimate
|
||||
|
||||
@torch.no_grad()
|
||||
def calc_gae(
|
||||
rewards,
|
||||
values,
|
||||
masks,
|
||||
gamma = 0.99,
|
||||
lam = 0.95,
|
||||
use_accelerated = None
|
||||
):
|
||||
assert values.shape[-1] == rewards.shape[-1]
|
||||
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
||||
|
||||
values = F.pad(values, (0, 1), value = 0.)
|
||||
values, values_next = values[..., :-1], values[..., 1:]
|
||||
|
||||
delta = rewards + gamma * values_next * masks - values
|
||||
gates = gamma * lam * masks
|
||||
|
||||
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
||||
|
||||
gae = scan(gates, delta)
|
||||
|
||||
returns = gae + values
|
||||
|
||||
return returns
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
@ -1099,6 +1129,7 @@ class Dreamer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
video_tokenizer: VideoTokenizer,
|
||||
dynamics_model: DynamicsModel
|
||||
dynamics_model: DynamicsModel,
|
||||
discount_factor = 0.9995
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user