applied formatter to tools

NM512 2023-04-15 15:28:09 +09:00
parent 55ed69bdf7
commit fba87a33e0

tools.py (228 lines changed)

@@ -20,14 +20,16 @@ from torch.utils.tensorboard import SummaryWriter
 to_np = lambda x: x.detach().cpu().numpy()

+
 def symlog(x):
     return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

+
 def symexp(x):
     return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

 class RequiresGrad:
     def __init__(self, model):
         self._model = model
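
Note (context, not part of the diff): symlog compresses large magnitudes toward zero and symexp is its exact inverse; DreamerV3 uses the pair to squash regression targets. A minimal round-trip check, assuming only torch:

    import torch

    def symlog(x):
        return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

    def symexp(x):
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

    x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    # symexp undoes symlog up to floating-point error
    assert torch.allclose(symexp(symlog(x)), x, atol=1e-4)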
@@ -39,7 +41,6 @@ class RequiresGrad:

-
 class TimeRecording:
     def __init__(self, comment):
         self._comment = comment
@@ -55,7 +56,6 @@ class TimeRecording:

-
 class Logger:
     def __init__(self, logdir, step):
         self._logdir = logdir
         self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
@@ -80,16 +80,16 @@ class Logger:
         step = self.step
         scalars = list(self._scalars.items())
         if fps:
-            scalars.append(('fps', self._compute_fps(step)))
-        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
-        with (self._logdir / 'metrics.jsonl').open('a') as f:
-            f.write(json.dumps({'step': step, **dict(scalars)}) + '\n')
+            scalars.append(("fps", self._compute_fps(step)))
+        print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
+        with (self._logdir / "metrics.jsonl").open("a") as f:
+            f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
         for name, value in scalars:
-            self._writer.add_scalar('scalars/' + name, value, step)
+            self._writer.add_scalar("scalars/" + name, value, step)
         for name, value in self._images.items():
             self._writer.add_image(name, value, step)
         for name, value in self._videos.items():
-            name = name if isinstance(name, str) else name.decode('utf-8')
+            name = name if isinstance(name, str) else name.decode("utf-8")
             if np.issubdtype(value.dtype, np.floating):
                 value = np.clip(255 * value, 0, 255).astype(np.uint8)
             B, T, H, W, C = value.shape
@@ -113,7 +113,7 @@ class Logger:
         return steps / duration

     def offline_scalar(self, name, value, step):
-        self._writer.add_scalar('scalars/'+name, value, step)
+        self._writer.add_scalar("scalars/" + name, value, step)

     def offline_video(self, name, value, step):
         if np.issubdtype(value.dtype, np.floating):
@@ -148,7 +148,8 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
         if isinstance(action, dict):
             action = [
                 {k: np.array(action[k][i].detach().cpu()) for k in action}
-                for i in range(len(envs))]
+                for i in range(len(envs))
+            ]
         else:
             action = np.array(action)
         assert len(action) == len(envs)
@@ -161,7 +162,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
         episode += int(done.sum())
         length += 1
         step += (done * length).sum()
-        length *= (1 - done)
+        length *= 1 - done

     return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
@@ -169,16 +170,16 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
 def save_episodes(directory, episodes):
     directory = pathlib.Path(directory).expanduser()
     directory.mkdir(parents=True, exist_ok=True)
-    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
+    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
     filenames = []
     for episode in episodes:
         identifier = str(uuid.uuid4().hex)
-        length = len(episode['reward'])
-        filename = directory / f'{timestamp}-{identifier}-{length}.npz'
+        length = len(episode["reward"])
+        filename = directory / f"{timestamp}-{identifier}-{length}.npz"
         with io.BytesIO() as f1:
             np.savez_compressed(f1, **episode)
             f1.seek(0)
-            with filename.open('wb') as f2:
+            with filename.open("wb") as f2:
                 f2.write(f1.read())
         filenames.append(filename)
     return filenames
@@ -206,7 +207,7 @@ def sample_episodes(episodes, length=None, balance=False, seed=0):
         total = len(next(iter(episode.values())))
         available = total - length
         if available < 1:
-            print(f'Skipped short episode of length {available}.')
+            print(f"Skipped short episode of length {available}.")
             continue
         if balance:
             index = min(random.randint(0, total), available)
@@ -221,43 +222,42 @@ def load_episodes(directory, limit=None, reverse=True):
     episodes = collections.OrderedDict()
     total = 0
     if reverse:
-        for filename in reversed(sorted(directory.glob('*.npz'))):
+        for filename in reversed(sorted(directory.glob("*.npz"))):
             try:
-                with filename.open('rb') as f:
+                with filename.open("rb") as f:
                     episode = np.load(f)
                     episode = {k: episode[k] for k in episode.keys()}
             except Exception as e:
-                print(f'Could not load episode: {e}')
+                print(f"Could not load episode: {e}")
                 continue
             episodes[str(filename)] = episode
-            total += len(episode['reward']) - 1
+            total += len(episode["reward"]) - 1
             if limit and total >= limit:
                 break
     else:
-        for filename in sorted(directory.glob('*.npz')):
+        for filename in sorted(directory.glob("*.npz")):
             try:
-                with filename.open('rb') as f:
+                with filename.open("rb") as f:
                     episode = np.load(f)
                     episode = {k: episode[k] for k in episode.keys()}
             except Exception as e:
-                print(f'Could not load episode: {e}')
+                print(f"Could not load episode: {e}")
                 continue
             episodes[str(filename)] = episode
-            total += len(episode['reward']) - 1
+            total += len(episode["reward"]) - 1
             if limit and total >= limit:
                 break
     return episodes

-
 class SampleDist:
     def __init__(self, dist, samples=100):
         self._dist = dist
         self._samples = samples

     @property
     def name(self):
-        return 'SampleDist'
+        return "SampleDist"

     def __getattr__(self, name):
         return getattr(self._dist, name)
@@ -278,7 +278,6 @@ class SampleDist:

-
 class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
     def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
         if logits is not None and unimix_ratio > 0.0:
             probs = F.softmax(logits, dim=-1)
@@ -289,12 +288,14 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
         super().__init__(logits=logits, probs=probs)

     def mode(self):
-        _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
+        _mode = F.one_hot(
+            torch.argmax(super().logits, axis=-1), super().logits.shape[-1]
+        )
         return _mode.detach() + super().logits - super().logits.detach()

     def sample(self, sample_shape=(), seed=None):
         if seed is not None:
-            raise ValueError('need to check')
+            raise ValueError("need to check")
         sample = super().sample(sample_shape)
         probs = super().probs
         while len(probs.shape) < len(sample.shape):
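
Note (context, not part of the diff): mode() above uses the straight-through gradient trick: the forward value is the hard argmax one-hot, while the detach() arithmetic routes gradients to the logits. A self-contained sketch of the pattern:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 8, requires_grad=True)
    hard = F.one_hot(torch.argmax(logits, dim=-1), logits.shape[-1]).float()
    # Value equals `hard`; the gradient flows through `logits` instead.
    straight_through = hard.detach() + logits - logits.detach()
    straight_through.sum().backward()
    assert logits.grad is not None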
@@ -303,9 +304,8 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
         return sample

-class TwoHotDistSymlog():
-
-    def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
+class TwoHotDistSymlog:
+    def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
         self.logits = logits
         self.probs = torch.softmax(logits, -1)
         self.buckets = torch.linspace(low, high, steps=255).to(device)
@@ -325,10 +325,12 @@ class TwoHotDistSymlog():
         x = symlog(x)
         # x(time, batch, 1)
         below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
-        above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
+        above = len(self.buckets) - torch.sum(
+            (self.buckets > x[..., None]).to(torch.int32), dim=-1
+        )
         below = torch.clip(below, 0, len(self.buckets) - 1)
         above = torch.clip(above, 0, len(self.buckets) - 1)
-        equal = (below == above)
+        equal = below == above
         dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
         dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
@@ -336,8 +338,9 @@ class TwoHotDistSymlog():
         weight_below = dist_to_above / total
         weight_above = dist_to_below / total
         target = (
-            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
-            F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
+            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
+            + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
+        )
         log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
         target = target.squeeze(-2)
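
Note (context, not part of the diff): this builds a two-hot target, as in DreamerV3's scalar heads: a value (after symlog) becomes a weighting of the two nearest of the 255 buckets, with weights inversely proportional to distance, so the expectation over buckets recovers the value. A toy version with 5 buckets (an assumption, for readability):

    import torch
    import torch.nn.functional as F

    buckets = torch.linspace(-2.0, 2.0, steps=5)  # [-2, -1, 0, 1, 2]
    x = torch.tensor(0.3)
    below = (buckets <= x).sum() - 1             # index 2 (bucket value 0.0)
    above = len(buckets) - (buckets > x).sum()   # index 3 (bucket value 1.0)
    total = (buckets[above] - x) + (x - buckets[below])
    weight_below = (buckets[above] - x) / total  # 0.7
    weight_above = (x - buckets[below]) / total  # 0.3
    target = (
        F.one_hot(below, num_classes=5) * weight_below
        + F.one_hot(above, num_classes=5) * weight_above
    )
    # target == [0, 0, 0.7, 0.3, 0]; its expectation recovers x
    assert torch.isclose((target * buckets).sum(), x)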
@@ -347,8 +350,11 @@ class TwoHotDistSymlog():
         log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
         return (target * log_pred).sum(-1)

-class SymlogDist():
-    def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
+class SymlogDist:
+    def __init__(
+        self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
+    ):
         self._mode = mode
         self._dist = dist
         self._agg = agg
@@ -363,24 +369,24 @@ class SymlogDist():
     def log_prob(self, value):
         assert self._mode.shape == value.shape
-        if self._dist == 'mse':
+        if self._dist == "mse":
             distance = (self._mode - symlog(value)) ** 2.0
             distance = torch.where(distance < self._tol, 0, distance)
-        elif self._dist == 'abs':
+        elif self._dist == "abs":
             distance = torch.abs(self._mode - symlog(value))
             distance = torch.where(distance < self._tol, 0, distance)
         else:
             raise NotImplementedError(self._dist)
-        if self._agg == 'mean':
+        if self._agg == "mean":
             loss = distance.mean(self._dim_to_reduce)
-        elif self._agg == 'sum':
+        elif self._agg == "sum":
             loss = distance.sum(self._dim_to_reduce)
         else:
             raise NotImplementedError(self._agg)
         return -loss

-
 class ContDist:
     def __init__(self, dist=None):
         super().__init__()
         self._dist = dist
@@ -403,7 +409,6 @@ class ContDist:

-
 class Bernoulli:
     def __init__(self, dist=None):
         super().__init__()
         self._dist = dist
@@ -431,21 +436,21 @@ class Bernoulli:
 class UnnormalizedHuber(torchd.normal.Normal):
     def __init__(self, loc, scale, threshold=1, **kwargs):
         super().__init__(loc, scale, **kwargs)
         self._threshold = threshold

     def log_prob(self, event):
-        return -(torch.sqrt(
-            (event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)
+        return -(
+            torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
+            - self._threshold
+        )

     def mode(self):
         return self.mean

 class SafeTruncatedNormal(torchd.normal.Normal):
     def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
         super().__init__(loc, scale)
         self._low = low
@@ -456,8 +461,7 @@ class SafeTruncatedNormal(torchd.normal.Normal):
     def sample(self, sample_shape):
         event = super().sample(sample_shape)
         if self._clip:
-            clipped = torch.clip(event, self._low + self._clip,
-                self._high - self._clip)
+            clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
             event = event - event.detach() + clipped.detach()
         if self._mult:
             event *= self._mult
@@ -465,8 +469,7 @@ class SafeTruncatedNormal(torchd.normal.Normal):
 class TanhBijector(torchd.Transform):
-
-    def __init__(self, validate_args=False, name='tanh'):
+    def __init__(self, validate_args=False, name="tanh"):
         super().__init__()

     def _forward(self, x):
@@ -474,8 +477,8 @@ class TanhBijector(torchd.Transform):
     def _inverse(self, y):
         y = torch.where(
-            (torch.abs(y) <= 1.),
-            torch.clamp(y, -0.99999997, 0.99999997), y)
+            (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
+        )
         y = torch.atanh(y)
         return y
@@ -504,8 +507,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
     return outputs

-def lambda_return(
-        reward, value, pcont, bootstrap, lambda_, axis):
+def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
     # Setting lambda=1 gives a discounted Monte Carlo return.
     # Setting lambda=0 gives a fixed 1-step return.
     # assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
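
Note (context, not part of the diff): per the comments, the lambda-return interpolates between the 1-step return (lambda=0) and the Monte Carlo return (lambda=1) via the recursion R_t = r_t + pcont_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}). A plain-loop reference sketch of what the optimized scan computes, assuming time-major tensors and that pcont carries the discount:

    import torch

    def lambda_return_reference(reward, value, pcont, bootstrap, lambda_):
        # V_{t+1} sequence, using the bootstrap value after the last step
        next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
        inputs = reward + pcont * next_values * (1 - lambda_)
        returns, agg = [], bootstrap
        for t in reversed(range(reward.shape[0])):
            agg = inputs[t] + pcont[t] * lambda_ * agg
            returns.append(agg)
        return torch.stack(returns[::-1], dim=0)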
@@ -527,18 +529,26 @@ def lambda_return(
     # (inputs, pcont), bootstrap, reverse=True)
     # reimplement to optimize performance
     returns = static_scan_for_lambda_return(
-        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
-        (inputs, pcont), bootstrap)
+        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, (inputs, pcont), bootstrap
+    )
     if axis != 0:
         returns = returns.permute(dims)
     return returns

-class Optimizer():
+class Optimizer:
     def __init__(
-        self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
-        opt='adam', use_amp=False):
+        self,
+        name,
+        parameters,
+        lr,
+        eps=1e-4,
+        clip=None,
+        wd=None,
+        wd_pattern=r".*",
+        opt="adam",
+        use_amp=False,
+    ):
         assert 0 <= wd < 1
         assert not clip or 1 <= clip
         self._name = name
@@ -547,26 +557,18 @@ class Optimizer():
         self._wd = wd
         self._wd_pattern = wd_pattern
         self._opt = {
-            'adam': lambda: torch.optim.Adam(parameters,
-                lr=lr,
-                eps=eps),
-            'nadam': lambda: NotImplemented(
-                f'{opt} is not implemented'),
-            'adamax': lambda: torch.optim.Adamax(parameters,
-                lr=lr,
-                eps=eps),
-            'sgd': lambda: torch.optim.SGD(parameters,
-                lr=lr),
-            'momentum': lambda: torch.optim.SGD(parameters,
-                lr=lr,
-                momentum=0.9),
+            "adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
+            "nadam": lambda: NotImplemented(f"{opt} is not implemented"),
+            "adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
+            "sgd": lambda: torch.optim.SGD(parameters, lr=lr),
+            "momentum": lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
         }[opt]()
         self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

     def __call__(self, loss, params, retain_graph=False):
         assert len(loss.shape) == 0, loss.shape
         metrics = {}
-        metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
+        metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
         self._scaler.scale(loss).backward()
         self._scaler.unscale_(self._opt)
         # loss.backward(retain_graph=retain_graph)
@@ -577,11 +579,11 @@ class Optimizer():
         self._scaler.update()
         # self._opt.step()
         self._opt.zero_grad()
-        metrics[f'{self._name}_grad_norm'] = norm.item()
+        metrics[f"{self._name}_grad_norm"] = norm.item()
         return metrics

     def _apply_weight_decay(self, varibs):
-        nontrivial = (self._wd_pattern != r'.*')
+        nontrivial = self._wd_pattern != r".*"
         if nontrivial:
             raise NotImplementedError
         for var in varibs:
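
Note (context, not part of the diff): a hypothetical usage sketch of the Optimizer wrapper. It scales the loss for AMP, backpropagates, unscales, and reports <name>_loss and <name>_grad_norm; the gradient clipping that produces norm lives in unchanged lines between the hunks above.

    import torch

    model = torch.nn.Linear(8, 1)
    opt = Optimizer("model", model.parameters(), lr=1e-3, eps=1e-4, clip=100, wd=0.0)
    loss = model(torch.randn(16, 8)).square().mean()
    metrics = opt(loss, model.parameters())
    # metrics == {"model_loss": ..., "model_grad_norm": ...}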
@@ -593,16 +595,18 @@ def args_type(default):
         if default is None:
             return x
         if isinstance(default, bool):
-            return bool(['False', 'True'].index(x))
+            return bool(["False", "True"].index(x))
         if isinstance(default, int):
-            return float(x) if ('e' in x or '.' in x) else int(x)
+            return float(x) if ("e" in x or "." in x) else int(x)
         if isinstance(default, (list, tuple)):
-            return tuple(args_type(default[0])(y) for y in x.split(','))
+            return tuple(args_type(default[0])(y) for y in x.split(","))
         return type(default)(x)
+
     def parse_object(x):
         if isinstance(default, (list, tuple)):
             return tuple(x)
         return x
+
     return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
@@ -615,27 +619,39 @@ def static_scan(fn, inputs, start):
         last = fn(last, *inp(index))
         if flag:
             if type(last) == type({}):
-                outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
+                outputs = {
+                    key: value.clone().unsqueeze(0) for key, value in last.items()
+                }
             else:
                 outputs = []
                 for _last in last:
                     if type(_last) == type({}):
-                        outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
+                        outputs.append(
+                            {
+                                key: value.clone().unsqueeze(0)
+                                for key, value in _last.items()
+                            }
+                        )
                     else:
                         outputs.append(_last.clone().unsqueeze(0))
             flag = False
         else:
             if type(last) == type({}):
                 for key in last.keys():
-                    outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
+                    outputs[key] = torch.cat(
+                        [outputs[key], last[key].unsqueeze(0)], dim=0
+                    )
             else:
                 for j in range(len(outputs)):
                     if type(last[j]) == type({}):
                         for key in last[j].keys():
-                            outputs[j][key] = torch.cat([outputs[j][key],
-                                last[j][key].unsqueeze(0)], dim=0)
+                            outputs[j][key] = torch.cat(
+                                [outputs[j][key], last[j][key].unsqueeze(0)], dim=0
+                            )
                     else:
-                        outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
+                        outputs[j] = torch.cat(
+                            [outputs[j], last[j].unsqueeze(0)], dim=0
+                        )
     if type(last) == type({}):
         outputs = [outputs]
     return outputs
@@ -673,7 +689,6 @@ def static_scan(fn, inputs, start):

-
 class Every:
     def __init__(self, every):
         self._every = every
         self._last = None
@@ -688,8 +703,8 @@ class Every:
         self._last += self._every * count
         return count

 class Once:
     def __init__(self):
         self._once = True
@@ -701,7 +716,6 @@ class Once:

-
 class Until:
     def __init__(self, until):
         self._until = until
@@ -715,21 +729,21 @@ def schedule(string, step):
     try:
         return float(string)
     except ValueError:
-        match = re.match(r'linear\((.+),(.+),(.+)\)', string)
+        match = re.match(r"linear\((.+),(.+),(.+)\)", string)
         if match:
             initial, final, duration = [float(group) for group in match.groups()]
             mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
             return (1 - mix) * initial + mix * final
-        match = re.match(r'warmup\((.+),(.+)\)', string)
+        match = re.match(r"warmup\((.+),(.+)\)", string)
         if match:
             warmup, value = [float(group) for group in match.groups()]
             scale = torch.clip(step / warmup, 0, 1)
             return scale * value
-        match = re.match(r'exp\((.+),(.+),(.+)\)', string)
+        match = re.match(r"exp\((.+),(.+),(.+)\)", string)
         if match:
             initial, final, halflife = [float(group) for group in match.groups()]
             return (initial - final) * 0.5 ** (step / halflife) + final
-        match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
+        match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
         if match:
             initial, final, duration = [float(group) for group in match.groups()]
             mix = torch.clip(step / duration, 0, 1)
@@ -737,6 +751,7 @@ def schedule(string, step):
             return 1 - 1 / horizon
     raise NotImplementedError(string)

+
 def weight_init(m):
     if isinstance(m, nn.Linear):
         in_num = m.in_features
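
Note (context, not part of the diff): schedule() accepts a bare float or one of the string forms matched above; illustrative values, assuming the arithmetic in each branch:

    schedule("0.3", step=0)              # constant -> 0.3
    schedule("linear(1.0,0.1,100)", 50)  # halfway from 1.0 to 0.1 -> 0.55
    schedule("warmup(100,0.5)", 50)      # 0.5 * min(step / 100, 1) -> 0.25
    schedule("exp(1.0,0.0,10)", 10)      # one halflife of decay -> 0.5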
@@ -745,7 +760,7 @@ def weight_init(m):
         scale = 1.0 / denoms
         std = np.sqrt(scale) / 0.87962566103423978
         nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
-        if hasattr(m.bias, 'data'):
+        if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)
     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
         space = m.kernel_size[0] * m.kernel_size[1]
@@ -755,13 +770,14 @@ def weight_init(m):
         scale = 1.0 / denoms
         std = np.sqrt(scale) / 0.87962566103423978
         nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
-        if hasattr(m.bias, 'data'):
+        if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)
     elif isinstance(m, nn.LayerNorm):
         m.weight.data.fill_(1.0)
-        if hasattr(m.bias, 'data'):
+        if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)

+
 def uniform_weight_init(given_scale):
     def f(m):
         if isinstance(m, nn.Linear):
@@ -771,21 +787,23 @@ def uniform_weight_init(given_scale):
             scale = given_scale / denoms
             limit = np.sqrt(3 * scale)
             nn.init.uniform_(m.weight.data, a=-limit, b=limit)
-            if hasattr(m.bias, 'data'):
+            if hasattr(m.bias, "data"):
                 m.bias.data.fill_(0.0)
         elif isinstance(m, nn.LayerNorm):
             m.weight.data.fill_(1.0)
-            if hasattr(m.bias, 'data'):
+            if hasattr(m.bias, "data"):
                 m.bias.data.fill_(0.0)

     return f

+
 def tensorstats(tensor, prefix=None):
     metrics = {
-        'mean': to_np(torch.mean(tensor)),
-        'std': to_np(torch.std(tensor)),
-        'min': to_np(torch.min(tensor)),
-        'max': to_np(torch.max(tensor)),
+        "mean": to_np(torch.mean(tensor)),
+        "std": to_np(torch.std(tensor)),
+        "min": to_np(torch.min(tensor)),
+        "max": to_np(torch.max(tensor)),
     }
     if prefix:
-        metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
+        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
     return metrics
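
Note (context, not part of the diff): tensorstats() is a small logging helper; a hypothetical call:

    import torch

    stats = tensorstats(torch.randn(32, 8), prefix="critic_value")
    # {"critic_value_mean": ..., "critic_value_std": ...,
    #  "critic_value_min": ..., "critic_value_max": ...}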