dreamerv3-torch/tools.py
2023-02-12 22:35:25 +09:00

700 lines
21 KiB
Python

import datetime
import io
import json
import pathlib
import pickle
import re
import time
import uuid
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
class RequiresGrad:
    """Context manager that turns gradients on for ``model`` inside the block.

    On entry every parameter gets ``requires_grad=True``; on exit (normal or
    exceptional) they are all set back to ``False``. Exceptions are not
    suppressed and ``__enter__`` yields ``None``.
    """

    def __init__(self, model):
        self._model = model

    def _toggle(self, enabled):
        # Module.requires_grad_ flips the flag on every parameter recursively.
        self._model.requires_grad_(requires_grad=enabled)

    def __enter__(self):
        self._toggle(True)

    def __exit__(self, *args):
        self._toggle(False)
class TimeRecording:
    """Context manager that prints how long its body took on the GPU.

    Two CUDA timing events bracket the block; on exit the device is
    synchronised and ``comment`` is printed followed by the elapsed time
    divided by 1000 (``elapsed_time`` reports milliseconds, so this is seconds).
    Requires a CUDA device.
    """

    def __init__(self, comment):
        self._comment = comment

    def __enter__(self):
        new_event = lambda: torch.cuda.Event(enable_timing=True)
        self._st, self._nd = new_event(), new_event()
        self._st.record()

    def __exit__(self, *args):
        self._nd.record()
        # Events are asynchronous; wait for the GPU before reading the timer.
        torch.cuda.synchronize()
        seconds = self._st.elapsed_time(self._nd) / 1000
        print(self._comment, seconds)
class Logger:
    """Buffer scalars, images and videos and flush them on ``write()``.

    Scalars are printed to stdout, appended to ``<logdir>/metrics.jsonl`` and
    sent to TensorBoard; images and videos go to TensorBoard only. All buffers
    are cleared after every ``write()``.

    Args:
        logdir: output directory (``str`` or path-like).
        step: initial value of the public ``step`` attribute, which every
            ``write()`` uses as the global step; callers update it directly.
    """

    def __init__(self, logdir, step):
        # Path() so ``self._logdir / 'metrics.jsonl'`` also works for str input.
        self._logdir = pathlib.Path(logdir)
        self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
        self._last_step = None
        self._last_time = None
        self._scalars = {}
        self._images = {}
        self._videos = {}
        self.step = step

    def scalar(self, name, value):
        self._scalars[name] = float(value)

    def image(self, name, value):
        self._images[name] = np.array(value)

    def video(self, name, value):
        self._videos[name] = np.array(value)

    def write(self, fps=False):
        """Flush all buffered values at ``self.step``; optionally add an ``fps`` scalar."""
        scalars = list(self._scalars.items())
        if fps:
            scalars.append(('fps', self._compute_fps(self.step)))
        print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
        with (self._logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
        for name, value in scalars:
            self._writer.add_scalar('scalars/' + name, value, self.step)
        for name, value in self._images.items():
            self._writer.add_image(name, value, self.step)
        for name, value in self._videos.items():
            self._writer.add_video(
                self._video_name(name), self._video_grid(value), self.step, 16)
        self._writer.flush()
        self._scalars = {}
        self._images = {}
        self._videos = {}

    def _compute_fps(self, step):
        """Return steps per second since the previous call (0 on the first call)."""
        if self._last_step is None:
            self._last_time = time.time()
            self._last_step = step
            return 0
        steps = step - self._last_step
        duration = time.time() - self._last_time
        self._last_time += duration
        self._last_step = step
        return steps / duration

    def offline_scalar(self, name, value, step):
        self._writer.add_scalar('scalars/' + name, value, step)

    def offline_video(self, name, value, step):
        # Same name/layout handling as write(), but at an explicit step.
        self._writer.add_video(
            self._video_name(name), self._video_grid(value), step, 16)

    @staticmethod
    def _video_name(name):
        """Decode a bytes tag to ``str``; ``str`` tags pass through unchanged."""
        return name if isinstance(name, str) else name.decode('utf-8')

    @staticmethod
    def _video_grid(value):
        """Tile a ``(B, T, H, W, C)`` batch of videos into one ``(1, T, C, H, B*W)`` clip.

        Floating-point input is scaled by 255, clipped to ``[0, 255]`` and cast
        to ``uint8``. The batch entries are placed side by side along width.
        """
        if np.issubdtype(value.dtype, np.floating):
            value = np.clip(255 * value, 0, 255).astype(np.uint8)
        B, T, H, W, C = value.shape
        return value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
def simulate(agent, envs, steps=0, episodes=0, state=None):
    """Roll ``agent`` out in a list of ``envs`` until a step or episode budget is met.

    Args:
        agent: callable ``agent(obs, done, agent_state, reward)`` returning
            ``(action, agent_state)``; ``obs`` is a dict of per-key arrays
            stacked across envs. ``action`` is either an array-like with one
            entry per env or a dict of batched (torch) tensors.
        envs: list of environments with ``reset() -> obs`` and
            ``step(action) -> (obs, reward, done, ...)``; ``obs`` is a dict.
        steps: stop once at least this many env steps were counted; 0 disables.
            Steps are only credited when an episode finishes.
        episodes: stop once at least this many episodes finished; 0 disables.
        state: tuple returned by a previous call, to resume where it stopped.

    Returns:
        ``(step - steps, episode - episodes, done, length, obs, agent_state,
        reward)`` — the budget overshoot plus everything needed to resume.
    """
    # Initialize or unpack simulation state.
    if state is None:
        step, episode = 0, 0
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        done = np.ones(len(envs), bool)
        length = np.zeros(len(envs), np.int32)
        obs = [None] * len(envs)
        agent_state = None
        reward = [0] * len(envs)
    else:
        step, episode, done, length, obs, agent_state, reward = state
    while (steps and step < steps) or (episodes and episode < episodes):
        # Reset envs if necessary.
        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            results = [envs[i].reset() for i in indices]
            for index, result in zip(indices, results):
                obs[index] = result
            # Envs that were just reset report a zero reward to the agent.
            reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
        # Step agents.
        obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
        action, agent_state = agent(obs, done, agent_state, reward)
        if isinstance(action, dict):
            # Split the batched tensor dict into one numpy action dict per env.
            action = [
                {k: np.array(action[k][i].detach().cpu()) for k in action}
                for i in range(len(envs))]
        else:
            action = np.array(action)
        assert len(action) == len(envs)
        # Step envs.
        results = [e.step(a) for e, a in zip(envs, action)]
        obs, reward, done = zip(*[p[:3] for p in results])
        obs = list(obs)
        reward = list(reward)
        done = np.stack(done)
        episode += int(done.sum())
        length += 1
        # Credit a finished episode's whole length at once, then reset its counter.
        step += (done * length).sum()
        length *= (1 - done)
    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
def save_episodes(directory, episodes):
    """Write each episode dict to its own compressed ``.npz`` file.

    Files are named ``<YYYYmmddTHHMMSS>-<uuid4 hex>-<len(reward)>.npz``; the
    timestamp is taken once per call. ``directory`` is user-expanded and
    created if missing.

    Returns:
        list of ``pathlib.Path`` objects, one per episode, in input order.
    """
    target = pathlib.Path(directory).expanduser()
    target.mkdir(parents=True, exist_ok=True)
    stamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    written = []
    for episode in episodes:
        uid = str(uuid.uuid4().hex)
        steps = len(episode['reward'])
        path = target / f'{stamp}-{uid}-{steps}.npz'
        # Compress into memory first, then dump the bytes with a single write.
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **episode)
            with path.open('wb') as handle:
                handle.write(buffer.getvalue())
        written.append(path)
    return written
def from_generator(generator, batch_size):
    """Endlessly group dict items from ``generator`` into batches.

    Each yielded batch maps every key of the first item in the batch to
    ``np.stack`` of that key over ``batch_size`` consecutive items (new
    leading axis). Exhausting ``generator`` surfaces as a RuntimeError
    from this generator.
    """
    while True:
        items = [next(generator) for _ in range(batch_size)]
        yield {key: np.stack([item[key] for item in items], 0) for key in items[0].keys()}
def sample_episodes(episodes, length=None, balance=False, seed=0):
    """Yield random episodes, or random fixed-length windows of them, forever.

    Args:
        episodes: dict whose values are episode dicts (key -> sequence). It is
            re-read on every draw, so entries the caller adds later are used.
        length: if truthy, yield a ``length``-step slice of every key instead of
            the whole episode; too-short episodes are skipped with a message.
        balance: when slicing, draw the start from ``[0, total)`` and clamp it to
            the last valid start, which puts extra mass on the final window.
        seed: seed of the private ``np.random.RandomState``.
    """
    random = np.random.RandomState(seed)
    while True:
        episode = random.choice(list(episodes.values()))
        if length:
            total = len(next(iter(episode.values())))
            available = total - length
            # NOTE(review): available == 0 (episode exactly ``length`` long) is
            # also skipped even though one window would fit; kept as-is so the
            # sampling stream is unchanged. If every episode is too short this
            # loop never yields.
            if available < 1:
                # Report the episode's actual length, not the (negative) slack.
                print(f'Skipped short episode of length {total}.')
                continue
            if balance:
                index = min(random.randint(0, total), available)
            else:
                index = int(random.randint(0, available + 1))
            episode = {k: v[index: index + length] for k, v in episode.items()}
        yield episode
def load_episodes(directory, limit=None, reverse=True):
    """Load ``*.npz`` episodes from ``directory`` into a dict keyed by file path.

    Files are visited in sorted filename order (descending when ``reverse``).
    Loading stops once the summed ``len(reward) - 1`` of the loaded episodes
    reaches ``limit`` (if given). Files that fail to load are reported on
    stdout and skipped.

    Returns:
        dict mapping ``str(path)`` to ``{key: np.ndarray}``, in visit order.
    """
    root = pathlib.Path(directory).expanduser()
    candidates = sorted(root.glob('*.npz'))
    if reverse:
        candidates.reverse()
    loaded = {}
    steps_seen = 0
    for path in candidates:
        try:
            with path.open('rb') as handle:
                archive = np.load(handle)
                episode = {key: archive[key] for key in archive.keys()}
        except Exception as e:
            print(f'Could not load episode: {e}')
            continue
        loaded[str(path)] = episode
        steps_seen += len(episode['reward']) - 1
        if limit and steps_seen >= limit:
            break
    return loaded
class SampleDist:
    """Monte-Carlo wrapper estimating statistics of ``dist`` from samples.

    ``mean()``, ``mode()`` and ``entropy()`` are methods (not properties) that
    draw ``samples`` fresh samples on every call. Any other attribute is
    forwarded to the wrapped distribution.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return 'SampleDist'

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def _draw(self):
        # torch distributions take ``sample_shape`` as a sequence / Size; a bare
        # int is rejected, so wrap the sample count in a tuple.
        return self._dist.sample((self._samples,))

    def mean(self):
        """Average of the drawn samples over the sample axis."""
        return torch.mean(self._draw(), 0)

    def mode(self):
        """Per batch element, the drawn sample with the highest log-probability."""
        samples = self._draw()
        logprob = self._dist.log_prob(samples)  # (samples, *batch_shape)
        best = torch.argmax(logprob, 0)  # (*batch_shape)
        event_dims = samples.dim() - logprob.dim()
        # Broadcast the winning sample index over the event dimensions.
        index = best.reshape((1,) + tuple(best.shape) + (1,) * event_dims)
        index = index.expand((1,) + tuple(samples.shape[1:]))
        return torch.gather(samples, 0, index)[0]

    def entropy(self):
        """Monte-Carlo entropy estimate: negative mean log-probability."""
        samples = self._draw()
        logprob = self.log_prob(samples)
        return -torch.mean(logprob, 0)
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """One-hot categorical with optional uniform mixing and straight-through gradients.

    Args:
        logits: unnormalised class scores (last axis = classes).
        probs: class probabilities; used as given when provided.
        unimix_ratio: when building from ``logits`` only, mix this fraction of
            the uniform distribution into the softmax probabilities.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        mix_uniform = logits is not None and probs is None and unimix_ratio > 0.0
        if mix_uniform:
            softmax = F.softmax(logits, dim=-1)
            num_classes = softmax.shape[-1]
            probs = softmax * (1.0 - unimix_ratio) + unimix_ratio / num_classes
            logits = None
        super().__init__(logits=logits, probs=probs)

    def mode(self):
        """Hard one-hot argmax whose gradient is that of the logits (straight-through)."""
        scores = self.logits
        hard = F.one_hot(torch.argmax(scores, dim=-1), scores.shape[-1])
        return hard.detach() + scores - scores.detach()

    def sample(self, sample_shape=(), seed=None):
        """Draw one-hot samples that carry the gradient of ``probs`` (straight-through)."""
        if seed is not None:
            raise ValueError('need to check')
        draws = super().sample(sample_shape)
        weights = self.probs
        # Left-pad probs with singleton axes so it broadcasts over sample_shape.
        for _ in range(draws.dim() - weights.dim()):
            weights = weights.unsqueeze(0)
        draws += weights - weights.detach()
        return draws
class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """Categorical over 255 evenly spaced buckets on ``[-20, 20]`` with a two-hot log-prob.

    ``mode()`` returns the probability-weighted bucket average (i.e. the
    expected value, keepdim on the last axis). ``log_prob(x)`` spreads a
    continuous target ``x`` linearly over its two neighbouring buckets.

    Args:
        logits / probs: parameters of the underlying categorical; the last axis
            presumably has 255 entries to match ``buckets`` (TODO confirm).
        unimix_ratio: fraction of uniform probability mixed in when built from
            ``logits`` only.
        device: device the bucket grid is placed on.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'):
        if logits is not None and probs is None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = None
        super().__init__(logits=logits, probs=probs)
        self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device)
        # Distance between adjacent bucket centres: N points span N - 1 gaps.
        self.width = (self.buckets[-1] - self.buckets[0]) / (len(self.buckets) - 1)

    def mode(self):
        _mode = super().probs * self.buckets
        return torch.sum(_mode, dim=-1, keepdim=True)

    # OneHotCategorical.log_prob only looks at the argmax of its target, so the
    # two-hot target is scored as a weighted sum of two one-hot log-probs.
    def log_prob(self, x):
        """Two-hot log-likelihood of ``x``, which carries a trailing singleton axis
        (e.g. ``(time, batch, 1)``); returns ``x``'s shape without that axis."""
        num_buckets = len(self.buckets)
        # Continuous bucket coordinate; clamping makes out-of-range targets fall
        # entirely into the edge bucket instead of yielding negative indices or
        # weights larger than one.
        pos = torch.clamp((x - self.buckets[0]) / self.width, 0, num_buckets - 1)
        # pos >= 0, so the int cast is a floor; keep room for an upper neighbour.
        lower_indices = torch.clip(pos.to(torch.int64), max=num_buckets - 2)
        upper_indices = lower_indices + 1
        # Linear interpolation weights; they are non-negative and sum to one.
        lower_weight = (upper_indices - pos).squeeze(-1)
        upper_weight = (pos - lower_indices).squeeze(-1)
        lower_log_prob = super().log_prob(
            F.one_hot(lower_indices.squeeze(-1), num_classes=num_buckets))
        upper_log_prob = super().log_prob(
            F.one_hot(upper_indices.squeeze(-1), num_classes=num_buckets))
        return lower_weight * lower_log_prob + upper_weight * upper_log_prob
class ContDist:
    """Adapter giving a continuous torch distribution the agent's dist interface.

    ``mode()`` returns the wrapped distribution's mean, ``sample()`` is a
    reparameterised draw (``rsample``), and ``mean`` is captured as an
    attribute at construction. Unknown attributes are forwarded to the
    wrapped distribution.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        # Only consulted when normal lookup fails.
        return getattr(self._dist, name)

    def log_prob(self, x):
        return self._dist.log_prob(x)

    def sample(self, sample_shape=()):
        # rsample keeps the draw differentiable w.r.t. the distribution params.
        return self._dist.rsample(sample_shape)

    def mode(self):
        return self._dist.mean

    def entropy(self):
        return self._dist.entropy()
class Bernoulli:
    """Adapter for a Bernoulli-style distribution exposing ``base_dist.logits``.

    ``mode()`` rounds the mean with a straight-through gradient and
    ``log_prob`` is computed elementwise from the base logits. Unknown
    attributes are forwarded to the wrapped distribution. The wrapped object is
    expected to have a ``base_dist`` (e.g. an ``Independent`` wrapper) — assumed
    from ``log_prob``; confirm against callers.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def log_prob(self, x):
        """Elementwise log-likelihood of ``x`` (no reduction over event axes)."""
        logits = self._dist.base_dist.logits
        # log P(0) = -softplus(l), log P(1) = -softplus(-l), in stable form.
        log_p_zero = -F.softplus(logits)
        log_p_one = -F.softplus(-logits)
        return log_p_zero * (1 - x) + log_p_one * x

    def mode(self):
        """Rounded mean; gradients flow as if it were the mean."""
        rounded = torch.round(self._dist.mean)
        return rounded.detach() + self._dist.mean - self._dist.mean.detach()

    def sample(self, sample_shape=()):
        # NOTE(review): forwards to rsample; torch's Bernoulli does not provide a
        # reparameterised sampler, so confirm this path is actually used.
        return self._dist.rsample(sample_shape)

    def entropy(self):
        return self._dist.entropy()
class UnnormalizedHuber(torchd.normal.Normal):
    """Normal whose ``log_prob`` is replaced by a negative pseudo-Huber loss.

    ``log_prob(e) = -(sqrt((e - mean)^2 + t^2) - t)`` with ``t = threshold``;
    it is 0 at the mean and is not a normalised density. ``mode()`` is the mean.
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        delta = event - self.mean
        smooth_abs = torch.sqrt(delta ** 2 + self._threshold ** 2)
        return -(smooth_abs - self._threshold)

    def mode(self):
        return self.mean
class SafeTruncatedNormal(torchd.normal.Normal):
    """Normal whose samples are clipped just inside ``(low, high)`` and optionally scaled.

    Args:
        loc, scale: Normal parameters.
        low, high: bounds; samples are clipped to ``[low + clip, high - clip]``
            when ``clip`` is truthy (values keep the unclipped sample's graph).
        clip: margin kept away from the bounds; falsy disables clipping.
        mult: factor applied to samples after clipping; falsy disables scaling.
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        super().__init__(loc, scale)
        self._low = low
        self._high = high
        self._clip = clip
        self._mult = mult

    def sample(self, sample_shape):
        draws = super().sample(sample_shape)
        margin = self._clip
        if margin:
            lo, hi = self._low + margin, self._high - margin
            bounded = torch.clip(draws, lo, hi)
            # Take the clipped values while keeping the original sample's graph.
            draws = draws - draws.detach() + bounded.detach()
        factor = self._mult
        if factor:
            draws *= factor
        return draws
class TanhBijector(torchd.Transform):
    """Bijective ``tanh`` transform from the reals onto ``(-1, 1)``.

    Implements torch's ``Transform`` API (``_call``, ``_inverse``,
    ``log_abs_det_jacobian``) so it works inside ``TransformedDistribution``.
    The TF-style names ``_forward`` / ``_forward_log_det_jacobian`` are kept as
    aliases. ``validate_args`` and ``name`` are accepted for signature
    compatibility and ignored.
    """

    domain = torchd.constraints.real
    codomain = torchd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1
    _LOG2 = float(np.log(2.0))

    def __init__(self, validate_args=False, name='tanh'):
        super().__init__()

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # Pull |y| <= 1 slightly inside the open interval so atanh stays finite.
        y = torch.where(
            (torch.abs(y) <= 1.),
            torch.clamp(y, -0.99999997, 0.99999997), y)
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """``log(1 - tanh(x)^2)`` in the numerically stable softplus form."""
        return 2.0 * (self._LOG2 - x - F.softplus(-2.0 * x))

    def _forward(self, x):
        return self._call(x)

    def _forward_log_det_jacobian(self, x):
        return self.log_abs_det_jacobian(x, None)
def static_scan_for_lambda_return(fn, inputs, start):
    """Run ``fn`` backwards over the leading (time) axis of ``inputs``.

    Computes ``last = fn(last, *[x[t] for x in inputs])`` for ``t = T-1 .. 0``
    starting from ``start``. Each step result is expected to be
    ``(batch, 1)`` (results are concatenated on the last axis and reshaped to
    ``(batch, T, 1)``).

    Returns:
        tuple with one ``(T, 1)`` tensor per batch entry, in forward time order
        (row ``t`` holds the value computed at time ``t``).
    """
    last = start
    per_step = []
    for index in reversed(range(inputs[0].shape[0])):
        last = fn(last, *(_input[index] for _input in inputs))
        per_step.append(last)
    # per_step runs T-1 .. 0; reverse it so column t corresponds to time t.
    # One concatenation instead of growing the tensor on every step.
    outputs = torch.cat(per_step[::-1], dim=-1)
    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
    outputs = torch.unbind(outputs, dim=0)
    return outputs
def lambda_return(
        reward, value, pcont, bootstrap, lambda_, axis):
    """Compute TD(lambda) returns along ``axis``.

    Setting lambda=1 gives a discounted Monte Carlo return; lambda=0 gives a
    fixed 1-step return.

    Args:
        reward, value: tensors of the same rank.
        pcont: per-step continuation/discount tensor, or a Python number that is
            broadcast to ``reward``'s shape.
        bootstrap: value after the last step; zeros if ``None``.
        lambda_: mixing coefficient.
        axis: time axis. The scan helper presumably expects 3-D
            ``(time, batch, 1)`` inputs once time is moved to the front
            (TODO confirm for other ranks).

    Returns:
        For ``axis == 0``: the scan's tuple of per-batch ``(time, 1)`` tensors.
        Otherwise: a tensor with the time axis moved back to ``axis``.
    """
    assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
    if isinstance(pcont, (int, float)):
        pcont = pcont * torch.ones_like(reward)
    # Permutation that swaps ``axis`` and 0 (its own inverse).
    dims = list(range(len(reward.shape)))
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
    if axis != 0:
        reward = reward.permute(dims)
        value = value.permute(dims)
        pcont = pcont.permute(dims)
    if bootstrap is None:
        bootstrap = torch.zeros_like(value[-1])
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    inputs = reward + pcont * next_values * (1 - lambda_)
    returns = static_scan_for_lambda_return(
        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
        (inputs, pcont), bootstrap)
    if axis != 0:
        # The scan returns a tuple of per-batch (time, 1) tensors, which has no
        # ``permute``: restack to (time, batch, 1) and undo the axis swap.
        returns = torch.stack(returns, dim=1).permute(dims)
    return returns
class Optimizer():
    """Torch optimizer wrapper with AMP loss scaling, grad-norm clipping and weight decay.

    Args:
        name: prefix for the metric keys returned by ``__call__``.
        parameters: iterable of parameters handed to the torch optimizer.
        lr: learning rate.
        eps: epsilon for Adam / Adamax.
        clip: max global gradient norm (>= 1); falsy disables clipping, but the
            norm is still measured and reported.
        wd: multiplicative weight-decay factor in ``[0, 1)``; ``None``/0 disables.
        wd_pattern: only the trivial pattern ``'.*'`` is supported.
        opt: ``'adam'``, ``'adamax'``, ``'sgd'``, ``'momentum'``
            (``'nadam'`` raises NotImplementedError).
        use_amp: enable ``torch.cuda.amp.GradScaler`` loss scaling.
    """

    def __init__(
            self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
            opt='adam', use_amp=False):
        # ``wd`` defaults to None, which cannot be compared with ``0 <= wd``.
        assert wd is None or 0 <= wd < 1
        assert not clip or 1 <= clip
        self._name = name
        self._parameters = parameters
        self._clip = clip
        self._wd = wd
        self._wd_pattern = wd_pattern

        def unsupported():
            raise NotImplementedError(f'{opt} is not implemented')

        self._opt = {
            'adam': lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
            'nadam': unsupported,
            'adamax': lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
            'sgd': lambda: torch.optim.SGD(parameters, lr=lr),
            'momentum': lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
        }[opt]()
        self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def __call__(self, loss, params, retain_graph=False):
        """Backpropagate ``loss``, clip, decay, step and zero the gradients.

        Args:
            loss: 0-dim loss tensor.
            params: tensor or iterable of parameters to clip and decay.
            retain_graph: forwarded to ``backward`` so the graph can be reused.

        Returns:
            dict with ``'<name>_loss'`` (numpy scalar) and
            ``'<name>_grad_norm'`` (float, norm before clipping).
        """
        assert len(loss.shape) == 0, loss.shape
        # Materialise once: params is iterated by clipping and by weight decay,
        # and a generator would be empty the second time.
        params = [params] if isinstance(params, torch.Tensor) else list(params)
        metrics = {}
        metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
        self._scaler.scale(loss).backward(retain_graph=retain_graph)
        # Unscale first so the clip threshold applies to the true gradients.
        self._scaler.unscale_(self._opt)
        # An infinite threshold measures the norm without rescaling anything.
        max_norm = self._clip if self._clip else float('inf')
        norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
        if self._wd:
            self._apply_weight_decay(params)
        self._scaler.step(self._opt)
        self._scaler.update()
        self._opt.zero_grad()
        metrics[f'{self._name}_grad_norm'] = norm.item()
        return metrics

    def _apply_weight_decay(self, varibs):
        """Scale every tensor in ``varibs`` in place by ``(1 - wd)``."""
        nontrivial = (self._wd_pattern != r'.*')
        if nontrivial:
            raise NotImplementedError
        for var in varibs:
            var.data = (1 - self._wd) * var.data
def args_type(default):
    """Build an argparse ``type=`` converter modelled on ``default``.

    Strings are parsed according to ``default``'s type: ``None`` keeps the
    string, bools accept only ``'False'``/``'True'``, ints become ``float`` when
    the text contains ``'e'`` or ``'.'`` (else ``int``), lists/tuples are split
    on commas with each item parsed like ``default[0]``, and anything else is
    passed to ``type(default)``. Non-string values are returned unchanged,
    except that they are converted to ``tuple`` when ``default`` is a sequence.
    """
    def parse_string(text):
        if default is None:
            return text
        # bool must be tested before int because bool subclasses int.
        if isinstance(default, bool):
            return bool(['False', 'True'].index(text))
        if isinstance(default, int):
            looks_float = 'e' in text or '.' in text
            return float(text) if looks_float else int(text)
        if isinstance(default, (list, tuple)):
            item_parser = args_type(default[0])
            return tuple(item_parser(part) for part in text.split(','))
        return type(default)(text)

    def parse_object(value):
        return tuple(value) if isinstance(default, (list, tuple)) else value

    def convert(value):
        if isinstance(value, str):
            return parse_string(value)
        return parse_object(value)

    return convert
def static_scan(fn, inputs, start):
    """Run ``fn`` forward over the leading (time) axis of ``inputs`` and stack results.

    Computes ``last = fn(last, *[x[t] for x in inputs])`` for ``t = 0 .. T-1``
    starting from ``start``. If each step returns a dict of tensors, the result
    is ``[{key: stacked}]``; if it returns a sequence whose items are tensors or
    dicts of tensors, the result is a list with one stacked entry per item.
    Stacking adds a new leading time axis of length ``T``.

    Raises:
        ValueError: if the leading axis of ``inputs[0]`` is empty.
    """
    last = start
    history = []
    for index in range(inputs[0].shape[0]):
        last = fn(last, *(_input[index] for _input in inputs))
        history.append(last)
    if not history:
        raise ValueError('static_scan needs at least one step along axis 0')

    def stack(items):
        # Stack a list of tensors, or a list of dicts of tensors key by key.
        if isinstance(items[0], dict):
            return {key: torch.stack([item[key] for item in items], 0) for key in items[0]}
        return torch.stack(items, 0)

    # Stack each output once at the end rather than re-concatenating the
    # accumulated history on every step (which copies O(T^2) elements).
    if isinstance(history[0], dict):
        return [stack(history)]
    return [stack([out[j] for out in history]) for j in range(len(history[0]))]
# Original version
#def static_scan2(fn, inputs, start, reverse=False):
# last = start
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
# indices = range(inputs[0].shape[0])
# if reverse:
# indices = reversed(indices)
# for index in indices:
# inp = lambda x: (_input[x] for _input in inputs)
# last = fn(last, *inp(index))
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
# if reverse:
# outputs = [list(reversed(x)) for x in outputs]
# res = [[]] * len(outputs)
# for i in range(len(outputs)):
# if type(outputs[i][0]) == type({}):
# _res = {}
# for key in outputs[i][0].keys():
# _res[key] = []
# for j in range(len(outputs[i])):
# _res[key].append(outputs[i][j][key])
# #_res[key] = torch.stack(_res[key], 0)
# _res[key] = faster_stack(_res[key], 0)
# else:
# _res = outputs[i]
# #_res = torch.stack(_res, 0)
# _res = faster_stack(_res, 0)
# res[i] = _res
# return res
class Every:
    """Periodic trigger: fires on the first call and then every ``every`` steps.

    A falsy ``every`` never fires. After a trigger the reference step advances
    by exactly ``every`` (not to the current step), so missed periods are
    caught up one call at a time.
    """

    def __init__(self, every):
        self._every = every
        self._last = None

    def __call__(self, step):
        period = self._every
        if not period:
            return False
        if self._last is None:
            self._last = step
            return True
        if not (step >= self._last + period):
            return False
        self._last += period
        return True
class Once:
    """Callable that returns ``True`` on its first call and ``False`` afterwards."""

    def __init__(self):
        self._once = True

    def __call__(self):
        if not self._once:
            return False
        self._once = False
        return True
class Until:
    """Callable that is ``True`` while ``step < until``; always ``True`` if ``until`` is falsy."""

    def __init__(self, until):
        self._until = until

    def __call__(self, step):
        return True if not self._until else step < self._until
def schedule(string, step):
    """Evaluate a hyper-parameter schedule string at ``step``.

    Supported forms:
        ``'<number>'``                      -> that number (float).
        ``'linear(initial,final,duration)'`` -> linear interpolation, held at
            ``final`` after ``duration`` steps.
        ``'warmup(steps,value)'``            -> ramps from 0 to ``value``.
        ``'exp(initial,final,halflife)'``    -> exponential decay toward ``final``.
        ``'horizon(initial,final,duration)'`` -> interpolates a horizon ``H``
            and returns ``1 - 1/H``.
    The ``linear``, ``warmup`` and ``horizon`` forms return 0-dim float32 torch
    tensors; the others return Python floats.

    Raises:
        NotImplementedError: if ``string`` matches none of the forms.
    """
    def clipped_fraction(value):
        # Clip to [0, 1] the same way for every branch. torch.clip needs a
        # tensor, so wrap plain Python numbers first (a bare float raises).
        return torch.clip(torch.Tensor([value]), 0, 1)[0]

    try:
        return float(string)
    except ValueError:
        match = re.match(r'linear\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, duration = [float(group) for group in match.groups()]
            mix = clipped_fraction(step / duration)
            return (1 - mix) * initial + mix * final
        match = re.match(r'warmup\((.+),(.+)\)', string)
        if match:
            warmup, value = [float(group) for group in match.groups()]
            scale = clipped_fraction(step / warmup)
            return scale * value
        match = re.match(r'exp\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, halflife = [float(group) for group in match.groups()]
            return (initial - final) * 0.5 ** (step / halflife) + final
        match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
        if match:
            initial, final, duration = [float(group) for group in match.groups()]
            mix = clipped_fraction(step / duration)
            horizon = (1 - mix) * initial + mix * final
            return 1 - 1 / horizon
        raise NotImplementedError(string)
def weight_init(m):
    """Initialise a module in place; meant for ``model.apply(weight_init)``.

    Linear weights get orthogonal init, Conv2d / ConvTranspose2d weights get
    orthogonal init with the ReLU gain, and Linear, conv and LayerNorm biases
    (when present) are zeroed. Other modules are left untouched.
    """
    def zero_bias(module):
        # A layer built with bias=False has ``bias is None``.
        if hasattr(module.bias, 'data'):
            module.bias.data.fill_(0.0)

    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        zero_bias(m)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, relu_gain)
        zero_bias(m)
    elif isinstance(m, nn.LayerNorm):
        zero_bias(m)