cleanup
This commit is contained in:
parent
0f4783f23c
commit
7cac3d28c5
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections import namedtuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -49,6 +50,8 @@ except ImportError:
|
|||||||
|
|
||||||
LinearNoBias = partial(Linear, bias = False)
|
LinearNoBias = partial(Linear, bias = False)
|
||||||
|
|
||||||
|
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
|
||||||
def exists(v):
|
def exists(v):
|
||||||
@ -780,7 +783,7 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
losses = (recon_loss, lpips_loss)
|
losses = (recon_loss, lpips_loss)
|
||||||
|
|
||||||
return total_loss, losses
|
return total_loss, TokenizerLosses(losses)
|
||||||
|
|
||||||
# dynamics model, axial space-time transformer
|
# dynamics model, axial space-time transformer
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user