mpd-public/mpd/losses/gaussian_diffusion_loss.py
2023-10-23 15:45:14 +02:00

26 lines
660 B
Python

import torch
from mpd.models import build_context
class GaussianDiffusionLoss:
    """Stateless wrapper around the training loss of a diffusion model.

    The only entry point, ``loss_fn``, is a static method, so it can be
    called on the class itself without creating an instance.
    """

    def __init__(self):
        """Create an instance; no attributes are set."""

    @staticmethod
    def loss_fn(diffusion_model, input_dict, dataset, step=None):
        """Compute the diffusion loss for one batch of trajectories.

        Args:
            diffusion_model: Object exposing ``loss(x, context, hard_conds)``
                that returns a ``(loss, info)`` pair. It is also forwarded
                to ``build_context``.
            input_dict: Batch mapping. Must contain the key
                ``f'{dataset.field_key_traj}_normalized'``; the optional key
                ``'hard_conds'`` falls back to an empty dict when absent.
            dataset: Object providing the ``field_key_traj`` attribute. It
                is also forwarded to ``build_context``.
            step: Not read by this function; kept in the signature so
                callers that pass it keep working.

        Returns:
            A 2-tuple ``(loss_dict, info)`` where ``loss_dict`` is
            ``{'diffusion_loss': loss}`` and ``info`` is whatever
            ``diffusion_model.loss`` returned as its second value.

        Raises:
            KeyError: If the normalized-trajectory key is missing from
                ``input_dict``.
        """
        # Gather the three model inputs in the same order as before so that
        # any failure (missing key, context construction) surfaces identically.
        normalized_key = f'{dataset.field_key_traj}_normalized'
        trajectories = input_dict[normalized_key]
        model_context = build_context(diffusion_model, dataset, input_dict)
        hard_constraints = input_dict.get('hard_conds', {})

        loss_value, loss_info = diffusion_model.loss(
            trajectories, model_context, hard_constraints
        )
        return {'diffusion_loss': loss_value}, loss_info