52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn
|
|
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
|
|
|
class BaseLowdimDataset(torch.utils.data.Dataset):
    """Base interface for low-dimensional datasets.

    The base class itself behaves as an empty dataset (length 0). Concrete
    datasets are expected to override the hooks below; every hook other than
    ``get_validation_dataset`` and ``__len__`` raises ``NotImplementedError``
    here.
    """

    def __len__(self) -> int:
        """Return the number of samples; the base dataset holds none."""
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the sample at position ``idx``.

        output:
            obs: T, Do
            action: T, Da

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def get_validation_dataset(self) -> 'BaseLowdimDataset':
        """Return the dataset used for validation.

        The default is a fresh, empty ``BaseLowdimDataset`` instance,
        regardless of the concrete subclass this is called on.
        """
        empty_dataset = BaseLowdimDataset()
        return empty_dataset

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return the normalizer for this dataset.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def get_all_actions(self) -> torch.Tensor:
        """Return the actions of this dataset as a single tensor.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class BaseImageDataset(torch.utils.data.Dataset):
    """Base interface for image datasets.

    The base class itself behaves as an empty dataset (length 0). Concrete
    datasets are expected to override ``get_normalizer``,
    ``get_all_actions`` and ``__getitem__``, which all raise
    ``NotImplementedError`` here.
    """

    # The return annotation previously named 'BaseLowdimDataset' although the
    # method returns a BaseImageDataset; it now matches the returned type.
    def get_validation_dataset(self) -> 'BaseImageDataset':
        """Return the dataset used for validation.

        The default is a fresh, empty ``BaseImageDataset`` instance,
        regardless of the concrete subclass this is called on.
        """
        # return an empty dataset by default
        return BaseImageDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return the normalizer for this dataset.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        """Return the actions of this dataset as a single tensor.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def __len__(self) -> int:
        """Return the number of samples; the base dataset holds none."""
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the sample at position ``idx``.

        output:
            obs:
                key: T, *
            action: T, Da

        NOTE(review): the layout above describes ``obs`` as a nested
        mapping, which the flat ``Dict[str, torch.Tensor]`` annotation does
        not express -- confirm against concrete subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
|