prepare for the learning in dreams
This commit is contained in:
parent
e04f9ffec6
commit
ca700ba8e1
@ -16,6 +16,8 @@ from torchvision.models import VGG16_Weights
|
|||||||
from x_mlps_pytorch import create_mlp
|
from x_mlps_pytorch import create_mlp
|
||||||
from x_mlps_pytorch.ensemble import Ensemble
|
from x_mlps_pytorch.ensemble import Ensemble
|
||||||
|
|
||||||
|
from assoc_scan import AssocScan
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
# ein related
|
# ein related
|
||||||
@ -224,6 +226,34 @@ class SymExpTwoHot(Module):
|
|||||||
|
|
||||||
return inverse_pack(encoded, '* l')
|
return inverse_pack(encoded, '* l')
|
||||||
|
|
||||||
|
# generalized advantage estimate
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def calc_gae(
|
||||||
|
rewards,
|
||||||
|
values,
|
||||||
|
masks,
|
||||||
|
gamma = 0.99,
|
||||||
|
lam = 0.95,
|
||||||
|
use_accelerated = None
|
||||||
|
):
|
||||||
|
assert values.shape[-1] == rewards.shape[-1]
|
||||||
|
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
||||||
|
|
||||||
|
values = F.pad(values, (0, 1), value = 0.)
|
||||||
|
values, values_next = values[..., :-1], values[..., 1:]
|
||||||
|
|
||||||
|
delta = rewards + gamma * values_next * masks - values
|
||||||
|
gates = gamma * lam * masks
|
||||||
|
|
||||||
|
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
||||||
|
|
||||||
|
gae = scan(gates, delta)
|
||||||
|
|
||||||
|
returns = gae + values
|
||||||
|
|
||||||
|
return returns
|
||||||
|
|
||||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||||
# https://jerryxio.ng/posts/nd-rope/
|
# https://jerryxio.ng/posts/nd-rope/
|
||||||
|
|
||||||
@ -1099,6 +1129,7 @@ class Dreamer(Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
video_tokenizer: VideoTokenizer,
|
video_tokenizer: VideoTokenizer,
|
||||||
dynamics_model: DynamicsModel
|
dynamics_model: DynamicsModel,
|
||||||
|
discount_factor = 0.9995
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user