diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 067fd66..290c409 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from collections import namedtuple from functools import partial import torch @@ -49,6 +50,8 @@ except ImportError: LinearNoBias = partial(Linear, bias = False) +TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips')) + # helpers def exists(v): @@ -780,7 +783,7 @@ class VideoTokenizer(Module): losses = (recon_loss, lpips_loss) - return total_loss, losses + return total_loss, TokenizerLosses(losses) # dynamics model, axial space-time transformer