700 lines
21 KiB
Python
700 lines
21 KiB
Python
import datetime
|
|
import io
|
|
import json
|
|
import pathlib
|
|
import pickle
|
|
import re
|
|
import time
|
|
import uuid
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch import distributions as torchd
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
class RequiresGrad:
    """Context manager that enables gradients for a module inside the block.

    On entry every parameter of ``model`` gets ``requires_grad=True``; on exit
    (normal or exceptional) they are all set back to ``False`` -- regardless of
    what they were before entry. ``with`` binds ``None``.
    """

    def __init__(self, model):
        self._model = model

    def _toggle(self, enabled):
        # Delegates to nn.Module.requires_grad_, which flips every parameter.
        self._model.requires_grad_(requires_grad=enabled)

    def __enter__(self):
        self._toggle(True)

    def __exit__(self, *args):
        # Returns None, so exceptions raised inside the block propagate.
        self._toggle(False)
|
|
|
|
|
|
class TimeRecording:
    """Context manager that prints how long its body took, in seconds.

    When CUDA is available the interval is measured with CUDA events followed
    by ``torch.cuda.synchronize()``, so queued GPU work is included. Without
    CUDA it falls back to ``time.perf_counter`` instead of crashing.

    ``with TimeRecording('label') as t:`` binds the recorder itself; after the
    block ``t.elapsed`` holds the measured seconds.
    """

    def __init__(self, comment):
        self._comment = comment
        self.elapsed = None  # seconds, filled in by __exit__

    def __enter__(self):
        self._use_cuda = torch.cuda.is_available()
        if self._use_cuda:
            self._st = torch.cuda.Event(enable_timing=True)
            self._nd = torch.cuda.Event(enable_timing=True)
            self._st.record()
        else:
            self._start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if self._use_cuda:
            self._nd.record()
            # Wait for the end event so elapsed_time is valid.
            torch.cuda.synchronize()
            # elapsed_time reports milliseconds.
            self.elapsed = self._st.elapsed_time(self._nd) / 1000
        else:
            self.elapsed = time.perf_counter() - self._start
        print(self._comment, self.elapsed)
|
|
|
|
|
|
class Logger:
    """Buffer metrics and flush them to TensorBoard and ``metrics.jsonl``.

    Values recorded with ``scalar``/``image``/``video`` are kept in memory
    until ``write`` emits them at ``self.step`` and clears the buffers.
    ``offline_scalar``/``offline_video`` bypass the buffers and write
    directly at an explicit step.
    """

    def __init__(self, logdir, step):
        # Normalise to a Path so ``write`` can append 'metrics.jsonl' even when
        # the caller passes a plain string.
        self._logdir = pathlib.Path(logdir)
        self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
        self._last_step = None
        self._last_time = None
        self._scalars = {}
        self._images = {}
        self._videos = {}
        self.step = step

    def scalar(self, name, value):
        """Buffer a scalar metric (coerced to ``float``)."""
        self._scalars[name] = float(value)

    def image(self, name, value):
        """Buffer an image as a NumPy array (passed as-is to ``add_image``)."""
        self._images[name] = np.array(value)

    def video(self, name, value):
        """Buffer a video batch; ``write`` expects a 5-D (B, T, H, W, C) array."""
        self._videos[name] = np.array(value)

    def write(self, fps=False):
        """Flush all buffered values at ``self.step`` and clear the buffers.

        Args:
          fps: if true, also log an 'fps' scalar computed from the step delta
            and wall-clock time since the previous fps computation.
        """
        scalars = list(self._scalars.items())
        if fps:
            scalars.append(('fps', self._compute_fps(self.step)))
        print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
        with (self._logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
        for name, value in scalars:
            self._writer.add_scalar('scalars/' + name, value, self.step)
        for name, value in self._images.items():
            self._writer.add_image(name, value, self.step)
        for name, value in self._videos.items():
            # Video names may arrive as bytes.
            name = name if isinstance(name, str) else name.decode('utf-8')
            self._writer.add_video(name, self._video_grid(value), self.step, 16)

        self._writer.flush()
        self._scalars = {}
        self._images = {}
        self._videos = {}

    def _compute_fps(self, step):
        """Return steps per second since the previous call (0 on the first call)."""
        if self._last_step is None:
            self._last_time = time.time()
            self._last_step = step
            return 0
        steps = step - self._last_step
        duration = time.time() - self._last_time
        self._last_time += duration
        self._last_step = step
        # Back-to-back calls can see identical timestamps on coarse clocks.
        return steps / duration if duration > 0 else 0.0

    def offline_scalar(self, name, value, step):
        """Write a scalar immediately at an explicit ``step``."""
        self._writer.add_scalar('scalars/' + name, value, step)

    def offline_video(self, name, value, step):
        """Write a (B, T, H, W, C) video immediately at an explicit ``step``."""
        self._writer.add_video(name, self._video_grid(value), step, 16)

    @staticmethod
    def _video_grid(value):
        """Tile a (B, T, H, W, C) batch side by side into (1, T, C, H, B*W).

        Floating-point input is treated as [0, 1] intensities and converted to
        uint8; integer input is passed through unchanged.
        """
        if np.issubdtype(value.dtype, np.floating):
            value = np.clip(255 * value, 0, 255).astype(np.uint8)
        B, T, H, W, C = value.shape
        return value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
|
|
|
|
|
|
def simulate(agent, envs, steps=0, episodes=0, state=None):
    """Run ``agent`` in a batch of environments until a step or episode budget.

    Args:
      agent: callable ``agent(obs, done, agent_state, reward)`` returning
        ``(action, agent_state)``. ``obs`` is a dict of stacked per-env arrays.
        ``action`` is either an array-like with one entry per env, or a dict
        of per-env torch tensors.
      envs: list of environments with ``reset()`` and ``step(action)``; only
        the first three items of each step result, ``(obs, reward, done)``,
        are used.
      steps: keep running while fewer steps than this have been counted
        (0 disables this criterion). Steps are only added when an episode
        finishes (``done * length``), so the budget is checked at episode
        granularity.
      episodes: keep running while fewer episodes than this have finished
        (0 disables this criterion).
      state: tuple returned by a previous call, to resume the rollout.

    Returns:
      ``(step - steps, episode - episodes, done, length, obs, agent_state,
      reward)``, which can be passed back in as ``state``.
    """
    # Initialize or unpack simulation state.
    if state is None:
        step, episode = 0, 0
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        done = np.ones(len(envs), bool)
        length = np.zeros(len(envs), np.int32)
        obs = [None] * len(envs)
        agent_state = None
        reward = [0] * len(envs)
    else:
        step, episode, done, length, obs, agent_state, reward = state
    while (steps and step < steps) or (episodes and episode < episodes):
        # Reset envs if necessary.
        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            results = [envs[i].reset() for i in indices]
            for index, result in zip(indices, results):
                obs[index] = result
            # Zero the reward of envs that were just reset.
            reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
        # Step agents.
        obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
        action, agent_state = agent(obs, done, agent_state, reward)
        if isinstance(action, dict):
            action = [
                {k: np.array(action[k][i].detach().cpu()) for k in action}
                for i in range(len(envs))]
        else:
            action = np.array(action)
        assert len(action) == len(envs)
        # Step envs.
        results = [e.step(a) for e, a in zip(envs, action)]
        obs, reward, done = zip(*[p[:3] for p in results])
        obs = list(obs)
        reward = list(reward)
        done = np.stack(done)
        episode += int(done.sum())
        length += 1
        # Credit the full length of each finished episode, then reset its counter.
        step += (done * length).sum()
        length *= (1 - done)

    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
|
|
|
|
|
def save_episodes(directory, episodes):
    """Write each episode dict to its own compressed ``.npz`` file.

    The directory is created if needed (``~`` is expanded). Each file is named
    ``{timestamp}-{uuid4 hex}-{len(episode['reward'])}.npz``, and every file
    from one call shares the same timestamp. Each archive is built in an
    in-memory buffer and then written to disk with a single write.

    Returns:
      List of ``pathlib.Path`` objects for the written files, in input order.
    """
    target = pathlib.Path(directory).expanduser()
    target.mkdir(parents=True, exist_ok=True)
    stamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    written = []
    for data in episodes:
        steps = len(data['reward'])
        path = target / f'{stamp}-{uuid.uuid4().hex}-{steps}.npz'
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **data)
            with path.open('wb') as handle:
                handle.write(buffer.getvalue())
        written.append(path)
    return written
|
|
|
|
|
|
def from_generator(generator, batch_size):
    """Endlessly group items of ``generator`` into batches.

    Pulls ``batch_size`` dicts at a time and yields one dict whose values are
    ``np.stack``-ed along a new leading axis. The keys are taken from the
    first item of each batch.
    """
    while True:
        items = [next(generator) for _ in range(batch_size)]
        yield {
            key: np.stack([item[key] for item in items], 0)
            for key in items[0].keys()
        }
|
|
|
|
|
|
def sample_episodes(episodes, length=None, balance=False, seed=0):
    """Yield random episodes, optionally cropped to a fixed-length window.

    ``episodes`` is re-read on every draw, so entries added between yields
    are eligible for later samples.

    Args:
      episodes: dict mapping an id to an episode dict of equal-length arrays.
      length: window length; falsy yields whole episodes. Episodes whose
        length minus ``length`` is below 1 are skipped (with a message).
      balance: if true, draw the start uniformly over the whole episode and
        clamp it to the last valid start. This favours windows that end at
        the episode's final step.
      seed: seed for the private ``np.random.RandomState``.
    """
    random = np.random.RandomState(seed)
    while True:
        episode = random.choice(list(episodes.values()))
        if length:
            total = len(next(iter(episode.values())))
            available = total - length
            if available < 1:
                # Report the episode's actual length (not the negative slack).
                print(f'Skipped short episode of length {total}.')
                continue
            if balance:
                index = min(random.randint(0, total), available)
            else:
                index = int(random.randint(0, available + 1))
            episode = {k: v[index: index + length] for k, v in episode.items()}
        yield episode
|
|
|
|
|
|
def load_episodes(directory, limit=None, reverse=True):
    """Load ``*.npz`` episode files from ``directory`` into a dict.

    Files are visited in sorted filename order, or reverse sorted order when
    ``reverse`` is true. Files that cannot be read, or that lack a 'reward'
    array, are reported and skipped.

    Args:
      directory: folder to scan (``~`` is expanded).
      limit: stop once the accumulated ``len(reward) - 1`` over loaded
        episodes reaches this value; falsy loads everything.
      reverse: visit files in descending filename order.

    Returns:
      Dict mapping ``str(path)`` to a dict of the arrays stored in that file.
    """
    directory = pathlib.Path(directory).expanduser()
    filenames = sorted(directory.glob('*.npz'))
    if reverse:
        filenames = reversed(filenames)
    episodes = {}
    total = 0
    for filename in filenames:
        try:
            with filename.open('rb') as f:
                episode = np.load(f)
                # Materialise the arrays while the file handle is still open.
                episode = {k: episode[k] for k in episode.keys()}
            steps = len(episode['reward']) - 1
        except Exception as e:
            # Deliberate best effort: one bad file should not abort the load.
            print(f'Could not load episode: {e}')
            continue
        episodes[str(filename)] = episode
        total += steps
        if limit and total >= limit:
            break
    return episodes
|
|
|
|
|
|
class SampleDist:
    """Monte-Carlo wrapper that estimates mean/mode/entropy by sampling.

    Attributes not defined here are forwarded to the wrapped distribution.
    ``mean``, ``mode`` and ``entropy`` are methods (not properties) and each
    call draws a fresh batch of ``samples`` draws.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return 'SampleDist'

    def __getattr__(self, name):
        # Guard '_dist' so a half-initialised instance (copy/pickle) raises
        # AttributeError instead of recursing forever.
        if name == '_dist':
            raise AttributeError(name)
        return getattr(self._dist, name)

    def _draw(self):
        # torch distributions require a shape tuple; a bare int is rejected.
        return self._dist.sample((self._samples,))

    def mean(self):
        """Average of the drawn samples over the sample axis."""
        return torch.mean(self._draw(), 0)

    def mode(self):
        """Highest-log-probability draw, chosen per batch element.

        Returns a tensor with the distribution's batch + event shape.
        """
        sample = self._draw()  # (S, *batch, *event)
        logprob = self._dist.log_prob(sample)  # (S, *batch)
        best = torch.argmax(logprob, dim=0)  # (*batch)
        event_dims = sample.dim() - logprob.dim()
        index = best.reshape(1, *best.shape, *([1] * event_dims))
        index = index.expand(1, *sample.shape[1:])
        return torch.gather(sample, 0, index)[0]

    def entropy(self):
        """Negative mean log-probability of the drawn samples."""
        sample = self._draw()
        logprob = self.log_prob(sample)
        return -torch.mean(logprob, 0)
|
|
|
|
|
|
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """One-hot categorical with optional uniform mixing and straight-through grads.

    If only ``logits`` are given and ``unimix_ratio > 0``, the softmax probs
    are mixed with a uniform distribution:
    ``probs * (1 - unimix_ratio) + unimix_ratio / num_classes``.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        mix_uniform = logits is not None and probs is None and unimix_ratio > 0.0
        if mix_uniform:
            softmax = F.softmax(logits, dim=-1)
            probs = softmax * (1.0 - unimix_ratio) + unimix_ratio / softmax.shape[-1]
            logits = None
        super().__init__(logits=logits, probs=probs)

    def mode(self):
        """Argmax one-hot whose gradient flows through the logits."""
        logits = super().logits
        hard = F.one_hot(torch.argmax(logits, axis=-1), logits.shape[-1])
        return hard.detach() + logits - logits.detach()

    def sample(self, sample_shape=(), seed=None):
        """Draw one-hot samples whose gradient flows through ``probs``."""
        if seed is not None:
            raise ValueError('need to check')
        draw = super().sample(sample_shape)
        probs = super().probs
        # Prepend singleton axes so probs broadcasts against the sample shape.
        missing = len(draw.shape) - len(probs.shape)
        if missing > 0:
            probs = probs.reshape((1,) * missing + tuple(probs.shape))
        draw += probs - probs.detach()
        return draw
|
|
|
|
|
|
class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """Categorical over 255 evenly spaced buckets on [-20, 20] with two-hot targets.

    ``mode`` is the probability-weighted mean of the bucket values.
    ``log_prob`` linearly splits a continuous target between its two
    neighbouring buckets.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'):
        if logits is not None and probs is None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = None
        super().__init__(logits=logits, probs=probs)

        self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device)
        # Spacing between adjacent bucket centres: 255 points span 254 gaps.
        self.width = (self.buckets[-1] - self.buckets[0]) / (len(self.buckets) - 1)

    def mode(self):
        """Expected bucket value, with a trailing singleton axis kept."""
        _mode = super().probs * self.buckets
        return torch.sum(_mode, dim=-1, keepdim=True)

    def log_prob(self, x):
        """Two-hot weighted log-likelihood of continuous targets.

        ``x`` carries a trailing singleton axis (e.g. (time, batch, 1)); the
        result drops it. Targets outside [-20, 20] are clamped onto the edge
        buckets, so the two weights are always non-negative and sum to 1.
        """
        num = len(self.buckets)
        # Fractional bucket position, clamped to the valid range [0, num - 1].
        pos = torch.clamp((x - self.buckets[0]) / self.width, 0, num - 1)
        # floor (not truncation); keep lower <= num - 2 so upper stays in range.
        lower_indices = torch.clamp(torch.floor(pos).to(torch.int64), max=num - 2)
        upper_indices = lower_indices + 1
        lower_weight = (upper_indices - pos).squeeze(-1)
        upper_weight = (pos - lower_indices).squeeze(-1)
        # OneHotCategorical.log_prob scores the index marked by the one-hot.
        lower_log_prob = super().log_prob(
            F.one_hot(lower_indices.squeeze(-1), num_classes=num))
        upper_log_prob = super().log_prob(
            F.one_hot(upper_indices.squeeze(-1), num_classes=num))
        return lower_weight * lower_log_prob + upper_weight * upper_log_prob
|
|
|
|
class ContDist:
    """Thin wrapper exposing mode/sample/log_prob of a continuous distribution.

    ``sample`` is reparameterised (``rsample``) and ``mode`` returns the
    wrapped distribution's mean. Any other attribute is forwarded to the
    wrapped distribution.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard '_dist' so a
        # half-initialised instance (e.g. during copy/pickle) raises
        # AttributeError instead of recursing forever.
        if name == '_dist':
            raise AttributeError(name)
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def mode(self):
        return self._dist.mean

    def sample(self, sample_shape=()):
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        return self._dist.log_prob(x)
|
|
|
|
|
|
class Bernoulli:
    """Wrapper around a Bernoulli-like distribution with straight-through mode.

    ``log_prob`` reads ``dist.base_dist.logits``, so the wrapped object is
    presumably a ``torchd.Independent`` over ``torchd.Bernoulli`` -- confirm
    at call sites. ``log_prob`` returns element-wise values; it does NOT sum
    over event dimensions. Other attributes are forwarded to the wrapped
    distribution.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        # Guard '_dist' so copy/pickle of a half-built instance does not recurse.
        if name == '_dist':
            raise AttributeError(name)
        return getattr(self._dist, name)

    def entropy(self):
        return self._dist.entropy()

    def mode(self):
        """Rounded mean whose gradient flows through the mean."""
        _mode = torch.round(self._dist.mean)
        return _mode.detach() + self._dist.mean - self._dist.mean.detach()

    def sample(self, sample_shape=()):
        # Bernoulli has no reparameterised sampler; rsample would raise.
        if self._dist.has_rsample:
            return self._dist.rsample(sample_shape)
        return self._dist.sample(sample_shape)

    def log_prob(self, x):
        """Element-wise ``x*log(sigmoid(l)) + (1-x)*log(1-sigmoid(l))``."""
        _logits = self._dist.base_dist.logits
        # Numerically stable log(1 - sigmoid(l)) and log(sigmoid(l)).
        log_probs0 = -F.softplus(_logits)
        log_probs1 = -F.softplus(-_logits)

        return log_probs0 * (1 - x) + log_probs1 * x
|
|
|
|
|
|
class UnnormalizedHuber(torchd.normal.Normal):
    """Normal whose ``log_prob`` is replaced by an unnormalised pseudo-Huber term.

    ``log_prob(e) = -(sqrt((e - mean)^2 + t^2) - t)`` with ``t = threshold``:
    roughly quadratic near the mean and linear far from it. ``mode`` is the
    mean.
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        squared = (event - self.mean) ** 2 + self._threshold ** 2
        return -(torch.sqrt(squared) - self._threshold)

    def mode(self):
        return self.mean
|
|
|
|
|
|
class SafeTruncatedNormal(torchd.normal.Normal):
    """Normal whose samples are clipped just inside ``(low, high)``.

    When ``clip`` is truthy the forward value is clipped to
    ``[low + clip, high - clip]``. The expression keeps the autograd graph of
    the raw sample, so the clip acts as a straight-through estimator. When
    ``mult`` is truthy the result is then scaled by ``mult``.
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        super().__init__(loc, scale)
        self._low = low
        self._high = high
        self._clip = clip
        self._mult = mult

    def sample(self, sample_shape):
        raw = super().sample(sample_shape)
        if self._clip:
            bounded = torch.clip(raw, self._low + self._clip, self._high - self._clip)
            # Forward value = bounded; backward path = raw sample.
            raw = raw - raw.detach() + bounded.detach()
        if self._mult:
            raw *= self._mult
        return raw
|
|
|
|
|
|
class TanhBijector(torchd.Transform):
    """Bijective ``tanh`` transform usable with ``torchd.TransformedDistribution``.

    Implements the torch ``Transform`` protocol (``_call``, ``_inverse``,
    ``log_abs_det_jacobian``, ``domain``/``codomain``). It also keeps the
    TFP-style ``_forward``/``_forward_log_det_jacobian`` names for existing
    callers.
    """

    domain = torchd.constraints.real
    codomain = torchd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1
    _LOG2 = 0.6931471805599453  # math.log(2.0)

    def __init__(self, validate_args=False, name='tanh'):
        super().__init__()

    def _forward(self, x):
        return torch.tanh(x)

    def _call(self, x):
        return self._forward(x)

    def _inverse(self, y):
        # Keep |y| strictly below 1 so atanh stays finite.
        y = torch.where(
            (torch.abs(y) <= 1.),
            torch.clamp(y, -0.99999997, 0.99999997), y)
        y = torch.atanh(y)
        return y

    def _forward_log_det_jacobian(self, x):
        # Stable form of log(1 - tanh(x)^2).
        return 2.0 * (self._LOG2 - x - F.softplus(-2.0 * x))

    def log_abs_det_jacobian(self, x, y):
        return self._forward_log_det_jacobian(x)
|
|
|
|
|
|
def static_scan_for_lambda_return(fn, inputs, start):
    """Reverse-time scan used by ``lambda_return``.

    Iterates ``t = T-1 .. 0`` computing ``last = fn(last, *(x[t] for x in
    inputs))``, starting from ``start``. The final reshape assumes each
    step's result is a 2-D (B, 1) tensor.

    Returns:
      Tuple of B tensors, each (T, 1), in forward time order (index t holds
      the value computed for step t).
    """
    last = start
    per_step = []
    for index in reversed(range(inputs[0].shape[0])):
        last = fn(last, *(x[index] for x in inputs))
        per_step.append(last)
    # Collected in reverse time; restore forward order before concatenating.
    per_step.reverse()
    outputs = torch.cat(per_step, dim=-1)  # (B, T)
    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
    return torch.unbind(outputs, dim=0)
|
|
|
|
|
|
def lambda_return(
        reward, value, pcont, bootstrap, lambda_, axis):
    """Compute TD(lambda) returns with a reverse scan over the time axis.

    Setting lambda=1 gives a discounted Monte Carlo return; lambda=0 gives a
    fixed 1-step return.

    Args:
      reward, value: tensors of equal rank with time along ``axis``.
      pcont: discount/continuation tensor like ``reward``, or a Python scalar
        broadcast to ``reward``'s shape.
      bootstrap: value after the last step (shape of one time slice); zeros
        when None.
      lambda_: mixing coefficient.
      axis: time axis of the inputs.

    Returns:
      For ``axis == 0``, the tuple produced by
      ``static_scan_for_lambda_return`` (one (T, 1) tensor per batch row;
      callers presumably stack it). For ``axis != 0``, those rows are
      restacked along dim 1 and permuted back into a tensor with time at
      ``axis``. This assumes 3-D (T, B, 1)-style inputs after the swap --
      TODO confirm for other ranks.
    """
    assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
    if isinstance(pcont, (int, float)):
        pcont = pcont * torch.ones_like(reward)
    # Permutation that swaps `axis` with dim 0 (its own inverse).
    dims = list(range(len(reward.shape)))
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
    if axis != 0:
        reward = reward.permute(dims)
        value = value.permute(dims)
        pcont = pcont.permute(dims)
    if bootstrap is None:
        bootstrap = torch.zeros_like(value[-1])
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    inputs = reward + pcont * next_values * (1 - lambda_)
    returns = static_scan_for_lambda_return(
        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
        (inputs, pcont), bootstrap)
    if axis != 0:
        # The scan returns a tuple; restack it (time first) before undoing the swap.
        returns = torch.stack(returns, dim=1).permute(dims)
    return returns
|
|
|
|
|
|
class Optimizer():
    """Torch optimizer wrapper with AMP scaling, grad-norm clipping and weight decay.

    Args:
      name: prefix for the metric keys returned by ``__call__``.
      parameters: parameters handed to the underlying torch optimizer.
      lr: learning rate.
      eps: epsilon for 'adam'/'adamax'.
      clip: maximum gradient norm; falsy disables clipping (the norm is
        still measured and reported).
      wd: multiplicative weight-decay factor in [0, 1); ``None`` means 0.
      wd_pattern: only the match-all pattern r'.*' is supported; other
        patterns raise NotImplementedError when weight decay is applied.
      opt: 'adam', 'adamax', 'sgd' or 'momentum'; 'nadam' raises
        NotImplementedError and unknown names raise KeyError.
      use_amp: enable ``torch.cuda.amp.GradScaler`` loss scaling.
    """

    def __init__(
            self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
            opt='adam', use_amp=False):
        # The default wd=None would fail the range check below; treat it as 0.
        wd = 0.0 if wd is None else wd
        assert 0 <= wd < 1
        assert not clip or 1 <= clip
        self._name = name
        self._parameters = parameters
        self._clip = clip
        self._wd = wd
        self._wd_pattern = wd_pattern

        def unsupported():
            raise NotImplementedError(f'{opt} is not implemented')

        self._opt = {
            'adam': lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
            'nadam': unsupported,
            'adamax': lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
            'sgd': lambda: torch.optim.SGD(parameters, lr=lr),
            'momentum': lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
        }[opt]()
        self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def __call__(self, loss, params, retain_graph=False):
        """Backprop ``loss``, clip, decay and step; return loss/grad-norm metrics."""
        assert len(loss.shape) == 0, loss.shape
        # Materialise in case a generator is passed: it is iterated twice.
        params = list(params)
        metrics = {}
        metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
        self._scaler.scale(loss).backward(retain_graph=retain_graph)
        # Unscale before clipping so the threshold applies to true gradients.
        self._scaler.unscale_(self._opt)
        # An infinite max norm only measures the norm without clipping.
        max_norm = self._clip if self._clip else float('inf')
        norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
        if self._wd:
            self._apply_weight_decay(params)
        self._scaler.step(self._opt)
        self._scaler.update()
        self._opt.zero_grad()
        metrics[f'{self._name}_grad_norm'] = norm.item()
        return metrics

    def _apply_weight_decay(self, varibs):
        """Scale every variable in place by ``1 - wd``."""
        nontrivial = (self._wd_pattern != r'.*')
        if nontrivial:
            raise NotImplementedError
        for var in varibs:
            var.data = (1 - self._wd) * var.data
|
|
|
|
|
|
def args_type(default):
    """Return a converter that coerces a CLI value to the type of ``default``.

    String values are parsed:
      * ``None`` default: returned unchanged.
      * ``bool`` default: exactly 'False' or 'True' (anything else raises).
      * ``int`` default: float if the text contains 'e' or '.', else int.
      * list/tuple default: comma-split, each part parsed like ``default[0]``,
        returned as a tuple.
      * anything else: ``type(default)(text)``.
    Non-string values are returned as-is, except that they are converted to
    a tuple when ``default`` is a list/tuple.
    """
    def parse_string(text):
        if default is None:
            return text
        # bool must be checked before int (bool is an int subclass).
        if isinstance(default, bool):
            return bool(['False', 'True'].index(text))
        if isinstance(default, int):
            looks_float = 'e' in text or '.' in text
            return float(text) if looks_float else int(text)
        if isinstance(default, (list, tuple)):
            element = args_type(default[0])
            return tuple(element(part) for part in text.split(','))
        return type(default)(text)

    def parse_object(obj):
        return tuple(obj) if isinstance(default, (list, tuple)) else obj

    def convert(value):
        if isinstance(value, str):
            return parse_string(value)
        return parse_object(value)

    return convert
|
|
|
|
|
|
def static_scan(fn, inputs, start):
    """Forward-time scan that stacks every step's output along a new dim 0.

    Computes ``last = fn(last, *(x[t] for x in inputs))`` for t = 0..T-1,
    starting from ``start``.

    Returns:
      If ``fn`` returns a dict: ``[dict_of_stacked_tensors]``.
      Otherwise ``fn`` returns a sequence; the result is a list with one
      entry per position: a stacked tensor, or a dict of stacked tensors
      when that position holds a dict.

    Raises:
      ValueError: if the inputs have zero time steps.
    """
    last = start
    steps = []
    for index in range(inputs[0].shape[0]):
        last = fn(last, *(x[index] for x in inputs))
        steps.append(last)
    if not steps:
        raise ValueError('static_scan requires at least one time step')

    def stack(items):
        # Stack once at the end (O(T)) instead of concatenating every step.
        first = items[0]
        if isinstance(first, dict):
            return {key: torch.stack([item[key] for item in items], 0) for key in first}
        return torch.stack(items, 0)

    if isinstance(steps[0], dict):
        return [stack(steps)]
    return [stack([step[j] for step in steps]) for j in range(len(steps[0]))]
|
|
|
|
|
|
# Original version
|
|
#def static_scan2(fn, inputs, start, reverse=False):
|
|
# last = start
|
|
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
|
|
# indices = range(inputs[0].shape[0])
|
|
# if reverse:
|
|
# indices = reversed(indices)
|
|
# for index in indices:
|
|
# inp = lambda x: (_input[x] for _input in inputs)
|
|
# last = fn(last, *inp(index))
|
|
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
|
|
# if reverse:
|
|
# outputs = [list(reversed(x)) for x in outputs]
|
|
# res = [[]] * len(outputs)
|
|
# for i in range(len(outputs)):
|
|
# if type(outputs[i][0]) == type({}):
|
|
# _res = {}
|
|
# for key in outputs[i][0].keys():
|
|
# _res[key] = []
|
|
# for j in range(len(outputs[i])):
|
|
# _res[key].append(outputs[i][j][key])
|
|
# #_res[key] = torch.stack(_res[key], 0)
|
|
# _res[key] = faster_stack(_res[key], 0)
|
|
# else:
|
|
# _res = outputs[i]
|
|
# #_res = torch.stack(_res, 0)
|
|
# _res = faster_stack(_res, 0)
|
|
# res[i] = _res
|
|
# return res
|
|
|
|
|
|
class Every:
    """Callable that returns True at most once per ``every`` steps.

    The first call always fires and anchors the schedule. Later calls fire
    when ``step`` has reached the next boundary; the boundary then advances
    by exactly ``every``, so a large jump in ``step`` fires on several
    consecutive calls until it catches up. A falsy ``every`` never fires.
    """

    def __init__(self, every):
        self._every = every
        self._last = None

    def __call__(self, step):
        if not self._every:
            return False
        if self._last is None:
            self._last = step
            return True
        due = step >= self._last + self._every
        if due:
            self._last += self._every
        return due
|
|
|
|
|
|
class Once:
    """Callable that returns True on its first call and False afterwards."""

    def __init__(self):
        self._once = True

    def __call__(self):
        first, self._once = self._once, False
        return first
|
|
|
|
|
|
class Until:
    """Callable that returns True while ``step < until``; always True if ``until`` is falsy."""

    def __init__(self, until):
        self._until = until

    def __call__(self, step):
        return (not self._until) or step < self._until
|
|
|
|
|
|
def schedule(string, step):
    """Evaluate a scheduled hyper-parameter spec at ``step``.

    Supported specs:
      * a plain number, e.g. '0.5' -> float;
      * 'linear(initial,final,duration)' -> linear interpolation, clamped;
      * 'warmup(warmup,value)' -> ``value`` scaled by min(step/warmup, 1);
      * 'exp(initial,final,halflife)' -> exponential decay toward ``final``;
      * 'horizon(initial,final,duration)' -> ``1 - 1/h`` for a linearly
        interpolated horizon ``h``.
    The linear/warmup/horizon forms return 0-d torch tensors; the others
    return Python floats.

    Raises:
      NotImplementedError: for unrecognised specs.
    """
    def clip01(ratio):
        # torch.clip needs a tensor; wrap the Python float first.
        return torch.clip(torch.Tensor([ratio]), 0, 1)[0]

    try:
        return float(string)
    except ValueError:
        match = re.match(r'linear\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, duration = [float(group) for group in match.groups()]
            mix = clip01(step / duration)
            return (1 - mix) * initial + mix * final
        match = re.match(r'warmup\((.+),(.+)\)', string)
        if match:
            warmup, value = [float(group) for group in match.groups()]
            scale = clip01(step / warmup)
            return scale * value
        match = re.match(r'exp\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, halflife = [float(group) for group in match.groups()]
            return (initial - final) * 0.5 ** (step / halflife) + final
        match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, duration = [float(group) for group in match.groups()]
            mix = clip01(step / duration)
            horizon = (1 - mix) * initial + mix * final
            return 1 - 1 / horizon
        raise NotImplementedError(string)
|
|
|
|
def weight_init(m):
    """Initialise a module in place (intended for ``nn.Module.apply``).

    * ``nn.Linear``: orthogonal weights (gain 1), zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: orthogonal weights with the
      ReLU gain, zero bias.
    * ``nn.LayerNorm``: zero bias (weight left untouched).
    Other modules, and missing biases, are left alone.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, relu_gain)
    elif not isinstance(m, nn.LayerNorm):
        return
    # Biases may be absent (bias=False), in which case m.bias is None.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)