"""
From Michael Janner
"""

import einops
import torch


#-----------------------------------------------------------------------------#
#--------------------------- multi-field normalizer --------------------------#
#-----------------------------------------------------------------------------#

class DatasetNormalizer:
    '''
        normalizes each field of a dataset with its own single-field normalizer
    '''

    def __init__(self, dataset, normalizer):
        dataset = flatten(dataset)

        ## resolve a normalizer passed by class name, e.g. 'GaussianNormalizer'
        if isinstance(normalizer, str):
            normalizer = eval(normalizer)

        self.normalizers = {}
        for key, val in dataset.items():
            self.normalizers[key] = normalizer(val)
            # try:
            #     self.normalizers[key] = normalizer(val)
            # except:
            #     print(f'[ utils/normalization ] Skipping {key} | {normalizer}')

    def __repr__(self):
        string = ''
        for key, normalizer in self.normalizers.items():
            string += f'{key}: {normalizer}\n'
        return string

    def __call__(self, *args, **kwargs):
        return self.normalize(*args, **kwargs)

    def normalize(self, x, key):
        return self.normalizers[key].normalize(x)

    def unnormalize(self, x, key):
        return self.normalizers[key].unnormalize(x)

    def get_field_normalizers(self):
        return self.normalizers
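
## usage sketch (illustrative, not part of the original module; field names are hypothetical):
# dataset = {'observations': torch.randn(8, 16, 4)}   # [ batch x horizon x dim ]
# normalizer = DatasetNormalizer(dataset, 'GaussianNormalizer')
# x_norm = normalizer(dataset['observations'][0], 'observations')
# x_recon = normalizer.unnormalize(x_norm, 'observations')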


# def flatten(dataset):
#     '''
#         flattens dataset of { key: [ batch x length x dim ] }
#             to { key : [ (batch * length) x dim ] }
#     '''
#     flattened = {}
#     for key, xs in dataset.items():
#         flattened[key] = einops.rearrange(xs, 'b h d -> (b h) d')
#     return flattened


def flatten(dataset):
    '''
        flattens dataset of { key: [ ... x dim ] }
            to { key : [ (...) x dim ] }
    '''
    flattened = {}
    for key, xs in dataset.items():
        xs_new = xs
        if xs.ndim == 2:
            ## environments (e d): already flat
            pass
        elif xs.ndim == 3:
            ## trajectories in fixed environments (b h d)
            xs_new = einops.rearrange(xs, 'b h d -> (b h) d')
        elif xs.ndim == 4:
            ## trajectories in variable environments (e b h d)
            xs_new = einops.rearrange(xs, 'e b h d -> (e b h) d')
        else:
            raise NotImplementedError(f'cannot flatten {key} with ndim {xs.ndim}')
        flattened[key] = xs_new
    return flattened
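
## shape sketch (illustrative, not part of the original module):
# flatten({'obs': torch.zeros(2, 5, 3)})['obs'].shape      # -> (10, 3)
# flatten({'obs': torch.zeros(4, 2, 5, 3)})['obs'].shape   # -> (40, 3)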


#-----------------------------------------------------------------------------#
#-------------------------- single-field normalizers -------------------------#
#-----------------------------------------------------------------------------#

class Normalizer:
    '''
        parent class, subclass by defining the `normalize` and `unnormalize` methods
    '''

    def __init__(self, X):
        self.X = X
        self.mins = X.min(dim=0).values
        self.maxs = X.max(dim=0).values

    def __repr__(self):
        return (
            f'''[ Normalizer ] dim: {len(self.mins)}\n    -: '''
            f'''{torch.round(self.mins, decimals=2)}\n    +: {torch.round(self.maxs, decimals=2)}\n'''
        )

    def __call__(self, x):
        return self.normalize(x)

    def normalize(self, *args, **kwargs):
        raise NotImplementedError()

    def unnormalize(self, *args, **kwargs):
        raise NotImplementedError()


class Identity(Normalizer):

    def normalize(self, x):
        return x

    def unnormalize(self, x):
        return x


class GaussianNormalizer(Normalizer):
    '''
        normalizes to zero mean and unit variance
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.means = self.X.mean(dim=0)
        self.stds = self.X.std(dim=0)
        self.z = 1

    def __repr__(self):
        return (
            f'''[ Normalizer ] dim: {len(self.mins)}\n    '''
            f'''means: {torch.round(self.means, decimals=2)}\n    '''
            f'''stds: {torch.round(self.z * self.stds, decimals=2)}\n'''
        )

    def normalize(self, x):
        return (x - self.means) / self.stds

    def unnormalize(self, x):
        return x * self.stds + self.means
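
## round-trip sketch (illustrative, not part of the original module):
# normalizer = GaussianNormalizer(torch.randn(100, 3))
# x = torch.randn(3)
# torch.allclose(normalizer.unnormalize(normalizer(x)), x)   # True up to float error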


class LimitsNormalizer(Normalizer):
    '''
        maps [ xmin, xmax ] to [ -1, 1 ]
    '''

    def normalize(self, x):
        ## [ 0, 1 ]
        x = (x - self.mins) / (self.maxs - self.mins)
        ## [ -1, 1 ]
        x = 2 * x - 1
        return x

    def unnormalize(self, x, eps=1e-4):
        '''
            x : [ -1, 1 ]
        '''
        if x.max() > 1 + eps or x.min() < -1 - eps:
            # print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})')
            x = torch.clip(x, -1, 1)

        ## [ -1, 1 ] --> [ 0, 1 ]
        x = (x + 1) / 2.

        return x * (self.maxs - self.mins) + self.mins
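
## mapping sketch (illustrative, not part of the original module):
# normalizer = LimitsNormalizer(torch.tensor([[0.], [10.]]))
# normalizer(torch.tensor([5.]))    # -> tensor([0.]) : midpoint maps to 0
# normalizer(torch.tensor([10.]))   # -> tensor([1.]) : max maps to 1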


class SafeLimitsNormalizer(LimitsNormalizer):
    '''
        functions like LimitsNormalizer, but can handle data for which a dimension is constant
    '''

    def __init__(self, *args, eps=1, **kwargs):
        super().__init__(*args, **kwargs)
        for i in range(len(self.mins)):
            if self.mins[i] == self.maxs[i]:
                print(
                    f'''[ utils/normalization ] Constant data in dimension {i} | '''
                    f'''max = min = {self.maxs[i]}'''
                )
                ## pad only the constant dimension so normalization does not divide by zero
                self.mins[i] -= eps
                self.maxs[i] += eps
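
## constant-dimension sketch (illustrative, not part of the original module):
# X = torch.tensor([[0., 5.], [1., 5.]])   # second dimension is constant
# normalizer = SafeLimitsNormalizer(X)     # pads its limits to [ 5 - eps, 5 + eps ]
# normalizer(torch.tensor([0.5, 5.]))      # -> tensor([0., 0.]) ; no division by zero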


class FixedLimitsNormalizer(LimitsNormalizer):
    '''
        functions like LimitsNormalizer, but with fixed limits not derived from the data
    '''

    def __init__(self, *args, min=-1, max=1, **kwargs):
        super().__init__(*args, **kwargs)
        ## `min` and `max` shadow the builtins here; kept to match call-site kwargs
        self.mins = torch.ones_like(self.mins) * min
        self.maxs = torch.ones_like(self.maxs) * max
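

if __name__ == '__main__':
    ## minimal smoke test (an editorial sketch, not part of the original module);
    ## the field names 'observations' and 'actions' are hypothetical
    dataset = {
        'observations': torch.randn(8, 16, 4),   # [ batch x horizon x obs_dim ]
        'actions': torch.rand(8, 16, 2),         # [ batch x horizon x act_dim ]
    }
    normalizer = DatasetNormalizer(dataset, 'LimitsNormalizer')
    print(normalizer)

    x = dataset['observations'][0]               # one trajectory, [ horizon x obs_dim ]
    x_norm = normalizer(x, 'observations')       # mapped into [ -1, 1 ]
    x_recon = normalizer.unnormalize(x_norm, 'observations')
    assert torch.allclose(x_recon, x, atol=1e-5), 'round trip should recover the input'
    print('round trip ok | range:', x_norm.min().item(), x_norm.max().item())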