1024 lines
34 KiB
Python
Raw Normal View History

2023-02-12 22:35:25 +09:00
import datetime
2023-04-15 15:25:25 +09:00
import collections
2023-02-12 22:35:25 +09:00
import io
import os
2023-02-12 22:35:25 +09:00
import json
import pathlib
import re
import time
import uuid
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    return x.detach().cpu().numpy()
2023-04-15 15:28:09 +09:00
def symlog(x):
    """Symmetric logarithm ``sign(x) * log(1 + |x|)`` applied element-wise.

    Args:
        x: A torch tensor.

    Returns:
        A tensor of the same shape as ``x``.
    """
    # log1p keeps precision for |x| << 1, where log(|x| + 1.0) rounds to zero.
    return torch.sign(x) * torch.log1p(torch.abs(x))
def symexp(x):
    """Symmetric exponential ``sign(x) * (exp(|x|) - 1)`` applied element-wise.

    Args:
        x: A torch tensor.

    Returns:
        A tensor of the same shape as ``x``.
    """
    # expm1 keeps precision for |x| << 1, where exp(|x|) - 1.0 rounds to zero.
    return torch.sign(x) * torch.expm1(torch.abs(x))
2023-02-12 22:35:25 +09:00
2023-04-15 15:28:09 +09:00
class RequiresGrad:
    """Context manager that enables gradients on a model for the ``with`` body.

    On exit gradients are switched off unconditionally; the state the model
    had before entering is not restored.
    """

    def __init__(self, model):
        self._model = model

    def _set(self, enabled):
        # Single place that forwards the flag to the wrapped model.
        self._model.requires_grad_(requires_grad=enabled)

    def __enter__(self):
        self._set(True)

    def __exit__(self, *args):
        self._set(False)
2023-02-12 22:35:25 +09:00
class TimeRecording:
    """Context manager that prints the elapsed time of its body using CUDA events.

    The printed value is ``comment`` followed by the elapsed time divided by
    1000 (``elapsed_time`` reports milliseconds, so the output is in seconds).
    """

    def __init__(self, comment):
        self._comment = comment

    def __enter__(self):
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        self._st, self._nd = start, stop
        start.record()

    def __exit__(self, *args):
        self._nd.record()
        # Wait for queued GPU work so that the stop event has been reached.
        torch.cuda.synchronize()
        seconds = self._st.elapsed_time(self._nd) / 1000
        print(self._comment, seconds)
2023-02-12 22:35:25 +09:00
class Logger:
    """Collects scalars, images and videos and writes them out in one go.

    ``write`` prints the scalars, appends them as one JSON line to
    ``<logdir>/metrics.jsonl`` and sends everything to a TensorBoard
    ``SummaryWriter``; the buffers are cleared afterwards.
    """

    def __init__(self, logdir, step):
        # ``logdir`` must support the ``/`` operator (it is used that way in ``write``).
        self._logdir = logdir
        self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
        self._last_step = None
        self._last_time = None
        self._scalars = {}
        self._images = {}
        self._videos = {}
        self.step = step

    @staticmethod
    def _to_video_grid(value):
        """Turn a (B, T, H, W, C) array into a (1, T, C, H, B * W) uint8-ready array."""
        if np.issubdtype(value.dtype, np.floating):
            # Float frames are scaled by 255 and clipped before the uint8 cast.
            value = np.clip(255 * value, 0, 255).astype(np.uint8)
        B, T, H, W, C = value.shape
        # Batch entries are laid out side by side along the width axis.
        return value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))

    def scalar(self, name, value):
        """Buffer ``value`` (converted to float) under ``name``."""
        self._scalars[name] = float(value)

    def image(self, name, value):
        """Buffer an image (converted to a NumPy array) under ``name``."""
        self._images[name] = np.array(value)

    def video(self, name, value):
        """Buffer a video (converted to a NumPy array) under ``name``."""
        self._videos[name] = np.array(value)

    def write(self, fps=False, step=False):
        """Flush all buffered entries.

        Args:
            fps: When true, an ``fps`` scalar computed from the step delta and
                wall-clock time since the previous fps computation is added.
            step: Step to log at; any falsy value selects ``self.step``.
        """
        step = step or self.step
        entries = list(self._scalars.items())
        if fps:
            entries.append(("fps", self._compute_fps(step)))
        print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in entries))
        with (self._logdir / "metrics.jsonl").open("a") as f:
            f.write(json.dumps({"step": step, **dict(entries)}) + "\n")
        for key, val in entries:
            self._writer.add_scalar("scalars/" + key, val, step)
        for key, val in self._images.items():
            self._writer.add_image(key, val, step)
        for key, val in self._videos.items():
            # Video names may arrive as bytes; TensorBoard needs str.
            tag = key if isinstance(key, str) else key.decode("utf-8")
            self._writer.add_video(tag, self._to_video_grid(val), step, 16)

        self._writer.flush()
        self._scalars = {}
        self._images = {}
        self._videos = {}

    def _compute_fps(self, step):
        """Return steps per second since the last call (0 on the first call)."""
        if self._last_step is None:
            self._last_time = time.time()
            self._last_step = step
            return 0
        delta_steps = step - self._last_step
        elapsed = time.time() - self._last_time
        self._last_time += elapsed
        self._last_step = step
        return delta_steps / elapsed

    def offline_scalar(self, name, value, step):
        """Write one scalar straight to TensorBoard, bypassing the buffers."""
        self._writer.add_scalar("scalars/" + name, value, step)

    def offline_video(self, name, value, step):
        """Write one (B, T, H, W, C) video straight to TensorBoard, bypassing the buffers."""
        self._writer.add_video(name, self._to_video_grid(value), step, 16)
2023-02-12 22:35:25 +09:00
2023-07-23 22:02:06 +09:00
def simulate(
    agent,
    envs,
    cache,
    directory,
    logger,
    is_eval=False,
    limit=None,
    steps=0,
    episodes=0,
    state=None,
):
    """Run ``agent`` in ``envs`` until a step or episode budget is used up.

    Args:
        agent: Callable ``agent(obs, done, agent_state) -> (action, agent_state)``.
        envs: Sequence of environments. ``reset()`` and ``step(action)`` must
            return zero-argument callables that produce the actual result, and
            each environment exposes an ``id`` attribute used as the cache key.
        cache: Mapping from environment id to the episode being collected. In
            eval mode ``cache.popitem(last=False)`` is called, so it must
            provide that method (e.g. an ``OrderedDict``).
        directory: Directory finished episodes are saved to.
        logger: Object with ``scalar``, ``video``, ``write`` and ``step``.
        is_eval: Log running eval averages instead of training statistics.
        limit: Dataset size passed to ``erase_over_episodes`` (training only).
        steps: Environment-step budget (0 disables it).
        episodes: Episode budget (0 disables it).
        state: Tuple returned by a previous call, used to resume.

    Returns:
        ``(step - steps, episode - episodes, done, length, obs, agent_state,
        reward)``, which can be passed back in as ``state``.
    """
    # initialize or unpack simulation state
    if state is None:
        step, episode = 0, 0
        done = np.ones(len(envs), bool)
        length = np.zeros(len(envs), np.int32)
        obs = [None] * len(envs)
        agent_state = None
        reward = [0] * len(envs)
    else:
        step, episode, done, length, obs, agent_state, reward = state
    # Running statistics over the episodes finished during this call (eval only).
    eval_lengths = []
    eval_scores = []
    eval_done = False
    while (steps and step < steps) or (episodes and episode < episodes):
        # reset envs if necessary
        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            results = [envs[i].reset() for i in indices]
            results = [r() for r in results]
            # ``results`` is ordered like ``indices``, so the two must be walked
            # together; indexing ``results`` by env index is wrong whenever
            # only a subset of the environments is reset.
            for index, result in zip(indices, results):
                t = {k: convert(v) for k, v in result.copy().items()}
                # action will be added to transition in add_to_cache
                t["reward"] = 0.0
                t["discount"] = 1.0
                # initial state should be added to cache
                add_to_cache(cache, envs[index].id, t)
                obs[index] = result
        # step agents
        obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
        action, agent_state = agent(obs, done, agent_state)
        if isinstance(action, dict):
            action = [
                {k: np.array(action[k][i].detach().cpu()) for k in action}
                for i in range(len(envs))
            ]
        else:
            action = np.array(action)
        assert len(action) == len(envs)
        # step envs
        results = [e.step(a) for e, a in zip(envs, action)]
        results = [r() for r in results]
        obs, reward, done = zip(*[p[:3] for p in results])
        obs = list(obs)
        reward = list(reward)
        done = np.stack(done)
        episode += int(done.sum())
        length += 1
        step += len(envs)
        # zero the per-env counters of environments that just finished
        length *= 1 - done
        # add to cache
        for a, result, env in zip(action, results, envs):
            o, r, d, info = result
            o = {k: convert(v) for k, v in o.items()}
            transition = o.copy()
            if isinstance(a, dict):
                transition.update(a)
            else:
                transition["action"] = a
            transition["reward"] = r
            transition["discount"] = info.get("discount", np.array(1 - float(d)))
            add_to_cache(cache, env.id, transition)

        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            # logging for done episode
            for i in indices:
                env_id = envs[i].id
                finished = cache[env_id]
                save_episodes(directory, {env_id: finished})
                # Separate names keep the per-env ``length`` array (part of the
                # returned state) from being overwritten by a scalar.
                ep_length = len(finished["reward"]) - 1
                ep_score = float(np.array(finished["reward"]).sum())
                if not is_eval:
                    step_in_dataset = erase_over_episodes(cache, limit)
                    logger.scalar("dataset_size", step_in_dataset)
                    logger.scalar("train_return", ep_score)
                    logger.scalar("train_length", ep_length)
                    logger.scalar("train_episodes", len(cache))
                    logger.write(step=logger.step)
                else:
                    video = finished["image"]
                    # start counting scores for evaluation
                    eval_scores.append(ep_score)
                    eval_lengths.append(ep_length)
                    mean_score = sum(eval_scores) / len(eval_scores)
                    mean_length = sum(eval_lengths) / len(eval_lengths)
                    logger.video("eval_policy", np.array(video)[None])
                    # Write the eval summary once, when enough episodes are in.
                    if len(eval_scores) >= episodes and not eval_done:
                        logger.scalar("eval_return", mean_score)
                        logger.scalar("eval_length", mean_length)
                        logger.scalar("eval_episodes", len(eval_scores))
                        logger.write(step=logger.step)
                        eval_done = True
    if is_eval:
        # keep only last item for saving memory. this cache is used for video_pred later
        while len(cache) > 1:
            # FIFO
            cache.popitem(last=False)
    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
2023-02-12 22:35:25 +09:00
class CollectDataset:
    """Environment wrapper that records transitions into an episode cache.

    Unknown attributes are forwarded to the wrapped environment. Transitions
    of the running episode are stored in the cache under a temporary UUID key;
    that entry is dropped when the episode ends and the complete episode is
    handed to the callbacks and returned in ``info["episode"]``.
    """

    def __init__(
        self, env, mode, train_eps, eval_eps=None, callbacks=None, precision=32
    ):
        """
        Args:
            env: Wrapped environment with ``reset()`` and ``step(action)``.
            mode: ``"train"`` or ``"eval"``; selects which cache is used.
            train_eps: Cache used in train mode.
            eval_eps: Cache used in eval mode; a fresh dict when omitted.
            callbacks: Callables invoked with each finished episode.
            precision: Bit width passed to ``convert`` (16, 32 or 64).
        """
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        self._episode = None
        # A ``dict()`` default would be shared between all instances.
        if eval_eps is None:
            eval_eps = {}
        self._cache = dict(train=train_eps, eval=eval_eps)[mode]
        self._temp_name = str(uuid.uuid4())

    def __getattr__(self, name):
        # Never forward dunder lookups or the lookup of ``_env`` itself:
        # if ``_env`` is not set yet, forwarding would recurse without end.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def _convert(self, value):
        """Cast ``value`` with the module-level ``convert`` at this wrapper's precision."""
        return convert(value, self._precision)

    def add_to_cache(self, transition):
        """Append ``transition`` to the running episode via the module-level ``add_to_cache``."""
        add_to_cache(self._cache, self._temp_name, transition)

    def step(self, action):
        """Step the environment and record the resulting transition.

        Returns:
            ``(obs, reward, done, info)``; when ``done`` is true,
            ``info["episode"]`` holds the converted episode.
        """
        obs, reward, done, info = self._env.step(action)
        obs = {k: self._convert(v) for k, v in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition["action"] = action
        transition["reward"] = reward
        transition["discount"] = info.get("discount", np.array(1 - float(done)))
        self._episode.append(transition)
        self.add_to_cache(transition)
        if done:
            # delete transitions before whole episode is stored
            del self._cache[self._temp_name]
            self._temp_name = str(uuid.uuid4())
            # The reset transition lacks action keys; fill them with zeros
            # shaped like the first real transition.
            for key, value in self._episode[1].items():
                if key not in self._episode[0]:
                    self._episode[0][key] = 0 * value
            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
            episode = {k: self._convert(v) for k, v in episode.items()}
            info["episode"] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        """Reset the environment and start a new episode record."""
        obs = self._env.reset()
        transition = obs.copy()
        # missing keys will be filled with a zeroed out version of the first
        # transition, because we do not know what action information the agent will
        # pass yet.
        transition["reward"] = 0.0
        transition["discount"] = 1.0
        self._episode = [transition]
        self.add_to_cache(transition)
        return obs
2023-07-23 22:02:06 +09:00
def add_to_cache(cache, id, transition):
    """Append ``transition`` to the episode stored under ``cache[id]``.

    Every value is passed through ``convert``. The first transition for an id
    creates the episode. A key that first appears in a later transition gets
    a zero entry (``0 * value``) inserted ahead of the new value.
    """
    if id not in cache:
        episode = cache[id] = dict()
        for key, val in transition.items():
            episode[key] = [convert(val)]
        return
    episode = cache[id]
    for key, val in transition.items():
        if key not in episode:
            # fill missing data(action, etc.) at second time
            episode[key] = [convert(0 * val)]
        episode[key].append(convert(val))
2023-07-23 22:02:06 +09:00
def erase_over_episodes(cache, dataset_size):
    """Delete episodes from ``cache`` that do not fit into ``dataset_size`` steps.

    Episodes are visited in descending key order. An episode contributes
    ``len(ep["reward"]) - 1`` steps and is kept if the running total stays
    within ``dataset_size``; otherwise it is deleted. A falsy ``dataset_size``
    keeps everything.

    Returns:
        Total number of steps in the episodes that were kept.
    """
    kept_steps = 0
    for name in sorted(cache, reverse=True):
        ep_steps = len(cache[name]["reward"]) - 1
        if dataset_size and kept_steps + ep_steps > dataset_size:
            del cache[name]
        else:
            kept_steps += ep_steps
    return kept_steps
2023-07-23 22:02:06 +09:00
def convert(value, precision=32):
    """Return ``value`` as a NumPy array cast to a canonical dtype.

    Floating-point and signed-integer inputs are cast to the float/int type
    of width ``precision`` (16, 32 or 64). ``uint8`` and ``bool`` inputs keep
    their dtype.

    Raises:
        NotImplementedError: For any other dtype.
        KeyError: If ``precision`` is not 16, 32 or 64 for float/int inputs.
    """
    array = np.array(value)
    source = array.dtype
    if np.issubdtype(source, np.floating):
        floats = {16: np.float16, 32: np.float32, 64: np.float64}
        return array.astype(floats[precision])
    if np.issubdtype(source, np.signedinteger):
        ints = {16: np.int16, 32: np.int32, 64: np.int64}
        return array.astype(ints[precision])
    if np.issubdtype(source, np.uint8):
        return array.astype(np.uint8)
    if np.issubdtype(source, bool):
        return array.astype(bool)
    raise NotImplementedError(source)
2023-07-23 22:02:06 +09:00
2023-02-12 22:35:25 +09:00
def save_episodes(directory, episodes):
    """Write each episode to ``<directory>/<name>-<length>.npz``.

    ``length`` is ``len(episode["reward"])``. The directory is created when
    missing. Each archive is compressed in memory first and then written to
    disk in a single write.

    Returns:
        Always ``True``.
    """
    out_dir = pathlib.Path(directory).expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, episode in episodes.items():
        steps = len(episode["reward"])
        target = out_dir / f"{name}-{steps}.npz"
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **episode)
            payload = buffer.getvalue()
            with target.open("wb") as sink:
                sink.write(payload)
    return True
2023-02-12 22:35:25 +09:00
def from_generator(generator, batch_size):
    """Yield batches built from ``batch_size`` consecutive items of ``generator``.

    Each item is a dict; a batch is a dict with the keys of the first item,
    where the values of all items are stacked along a new leading axis.
    """
    while True:
        samples = [next(generator) for _ in range(batch_size)]
        yield {
            key: np.stack([sample[key] for sample in samples], 0)
            for key in samples[0].keys()
        }
2023-02-12 22:35:25 +09:00
def sample_episodes(episodes, length, seed=0):
    """Endlessly yield chunks of ``length`` consecutive transitions.

    Episodes are drawn with probability proportional to their length. The
    first piece of a chunk starts at a random offset; if the episode ends
    before ``length`` transitions are collected, further episodes are appended
    from their beginning. Where an ``"is_first"`` key is present, the first
    transition of every piece is flagged ``True``.

    Args:
        episodes: Mapping whose values are episodes (dicts of equally long
            sequences). It is re-read for every chunk.
        length: Number of transitions per yielded chunk.
        seed: Seed of the private ``RandomState``.

    Note:
        Episodes with fewer than two transitions are skipped; if no longer
        episode exists the generator loops forever.
    """

    def _detached(chunk):
        # Slicing an ndarray yields a view; copy it so that flagging
        # ``is_first`` below never writes into the stored episode.
        return chunk.copy() if isinstance(chunk, np.ndarray) else chunk

    random = np.random.RandomState(seed)
    while True:
        size = 0
        ret = None
        candidates = list(episodes.values())
        p = np.array([len(next(iter(episode.values()))) for episode in candidates])
        p = p / np.sum(p)
        while size < length:
            episode = random.choice(candidates, p=p)
            total = len(next(iter(episode.values())))
            # make sure at least one transition included
            if total < 2:
                continue
            if not ret:
                index = int(random.randint(0, total - 1))
                ret = {
                    k: _detached(v[index : min(index + length, total)])
                    for k, v in episode.items()
                }
                if "is_first" in ret:
                    ret["is_first"][0] = True
            else:
                # 'is_first' comes after 'is_last'
                index = 0
                possible = length - size
                ret = {
                    k: np.append(
                        ret[k], v[index : min(index + possible, total)], axis=0
                    )
                    for k, v in episode.items()
                }
                if "is_first" in ret:
                    ret["is_first"][size] = True
            size = len(next(iter(ret.values())))
        yield ret
2023-02-12 22:35:25 +09:00
def load_episodes(directory, limit=None, reverse=True):
    """Load ``*.npz`` episodes from ``directory`` into an ``OrderedDict``.

    Files are visited in sorted filename order, reversed when ``reverse`` is
    true. Loading stops once the accumulated step count
    (``len(episode["reward"]) - 1`` per episode) reaches ``limit``. Files that
    fail to load are reported on stdout and skipped.

    Note:
        With ``reverse=True`` episodes are keyed by the file name without its
        extension; with ``reverse=False`` they are keyed by the full path
        string.
    """
    directory = pathlib.Path(directory).expanduser()
    episodes = collections.OrderedDict()
    total = 0
    paths = sorted(directory.glob("*.npz"))
    if reverse:
        paths.reverse()
    for path in paths:
        try:
            with path.open("rb") as f:
                archive = np.load(f)
                # Materialise the arrays while the file is still open.
                episode = {k: archive[k] for k in archive.keys()}
        except Exception as e:
            print(f"Could not load episode: {e}")
            continue
        key = path.stem if reverse else str(path)
        episodes[key] = episode
        total += len(episode["reward"]) - 1
        if limit and total >= limit:
            break
    return episodes
2023-02-12 22:35:25 +09:00
class SampleDist:
    """Wraps a distribution and estimates statistics from ``samples`` draws.

    Attributes not defined here are forwarded to the wrapped distribution.

    NOTE(review): the sample count is passed to ``dist.sample`` as a plain
    int; stock torch distributions expect a shape tuple -- confirm that the
    wrapped distribution accepts an int.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return "SampleDist"

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def _draw(self):
        # One batch of ``samples`` draws from the wrapped distribution.
        return self._dist.sample(self._samples)

    def mean(self):
        """Average of the drawn samples over the sample axis (dim 0)."""
        return torch.mean(self._draw(), 0)

    def mode(self):
        """First entry of the drawn sample selected by the flat argmax of the log-probabilities."""
        draws = self._draw()
        scores = self._dist.log_prob(draws)
        best = torch.argmax(scores)
        return draws[best][0]

    def entropy(self):
        """Negative mean log-probability of the drawn samples (over dim 0)."""
        draws = self._draw()
        scores = self.log_prob(draws)
        return -torch.mean(scores, 0)
2023-02-12 22:35:25 +09:00
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """One-hot categorical with optional uniform mixing and straight-through gradients.

    With ``logits`` given and ``unimix_ratio > 0`` the class probabilities are
    mixed with a uniform distribution before the base class is initialised.
    ``mode`` and ``sample`` return one-hot values to which a zero-valued term
    depending on the logits/probabilities is added, so gradients can flow.
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        if logits is not None and unimix_ratio > 0.0:
            mixed = F.softmax(logits, dim=-1)
            mixed = mixed * (1.0 - unimix_ratio) + unimix_ratio / mixed.shape[-1]
            super().__init__(logits=torch.log(mixed), probs=None)
            return
        super().__init__(logits=logits, probs=probs)

    def mode(self):
        """One-hot vector of the most likely class, with gradients routed to the logits."""
        logits = super().logits
        hard = F.one_hot(torch.argmax(logits, axis=-1), logits.shape[-1])
        return hard.detach() + logits - logits.detach()

    def sample(self, sample_shape=(), seed=None):
        """Draw one-hot samples, with gradients routed to the probabilities.

        Raises:
            ValueError: If ``seed`` is given (seeding is not supported).
        """
        if seed is not None:
            raise ValueError("need to check")
        drawn = super().sample(sample_shape)
        probs = super().probs
        # Prepend axes until the probabilities have the rank of the sample.
        for _ in range(len(drawn.shape) - len(probs.shape)):
            probs = probs[None]
        drawn += probs - probs.detach()
        return drawn
2023-05-14 23:38:46 +09:00
class DiscDist:
    """Categorical distribution over 255 fixed buckets in a transformed space.

    The buckets are evenly spaced in ``[low, high]``. Targets are mapped into
    that space with ``transfwd`` and expectations are mapped back with
    ``transbwd``.
    """

    def __init__(
        self,
        logits,
        low=-20.0,
        high=20.0,
        transfwd=symlog,
        transbwd=symexp,
        device="cuda",
    ):
        """
        Args:
            logits: Unnormalised bucket scores; the last axis must match the
                255 buckets.
            low: Position of the first bucket.
            high: Position of the last bucket.
            transfwd: Transform applied to targets in ``log_prob``.
            transbwd: Transform applied to the expectation in ``mean``/``mode``.
            device: Device the bucket positions are placed on.
        """
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255).to(device)
        self.width = (self.buckets[-1] - self.buckets[0]) / 255
        self.transfwd = transfwd
        self.transbwd = transbwd

    def _expected_value(self):
        # Probability-weighted bucket position, mapped back by ``transbwd``.
        weighted = self.probs * self.buckets
        return self.transbwd(torch.sum(weighted, dim=-1, keepdim=True))

    def mean(self):
        """Expected value, with a trailing axis of size 1."""
        return self._expected_value()

    def mode(self):
        """Same quantity as ``mean``."""
        return self._expected_value()

    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
    def log_prob(self, x):
        """Cross-entropy style log-probability of ``x`` under a two-hot target.

        The transformed target is spread over the two neighbouring buckets,
        weighted by proximity.
        """
        x = self.transfwd(x)
        # x(time, batch, 1) per the original author's note -- a trailing
        # axis of size 1 is required by the ``squeeze(-2)`` below.
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = len(self.buckets) - torch.sum(
            (self.buckets > x[..., None]).to(torch.int32), dim=-1
        )
        below = torch.clip(below, 0, len(self.buckets) - 1)
        above = torch.clip(above, 0, len(self.buckets) - 1)
        # When both indices coincide, equal dummy distances give 0.5 + 0.5
        # weight on the same bucket.
        equal = below == above
        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
            + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
        )
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        target = target.squeeze(-2)
        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        """Log-probability of an already discretised ``target`` over the buckets."""
        # This class has no base class providing ``logits``; the instance
        # attribute is the only valid source.
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)
2023-05-14 23:38:46 +09:00
class MSEDist:
    """Point prediction whose ``log_prob`` is the negative squared error.

    The error is aggregated (``"sum"`` or ``"mean"``) over every axis after
    the first two.
    """

    def __init__(self, mode, agg="sum"):
        self._mode = mode
        self._agg = agg

    def mode(self):
        return self._mode

    def mean(self):
        return self._mode

    def log_prob(self, value):
        """Negative aggregated squared difference between the prediction and ``value``.

        Raises:
            NotImplementedError: For an unknown aggregation name.
        """
        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
        squared = (self._mode - value) ** 2
        trailing = list(range(len(squared.shape)))[2:]
        if self._agg == "mean":
            return -squared.mean(trailing)
        if self._agg == "sum":
            return -squared.sum(trailing)
        raise NotImplementedError(self._agg)
2023-04-15 15:28:09 +09:00
class SymlogDist:
    """Point prediction in symlog space.

    ``mode``/``mean`` map the stored prediction back with ``symexp``.
    ``log_prob`` is the negative distance between the prediction and the
    symlog of the target. Distances below ``tol`` are set to zero, and the
    result is aggregated over every axis after the first two.
    """

    def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
        self._mode = mode
        self._dist = dist
        self._agg = agg
        self._tol = tol

    def mode(self):
        return symexp(self._mode)

    def mean(self):
        return symexp(self._mode)

    def log_prob(self, value):
        """Negative aggregated distance to ``symlog(value)``.

        Raises:
            NotImplementedError: For an unknown distance or aggregation name.
        """
        assert self._mode.shape == value.shape
        if self._dist not in ("mse", "abs"):
            raise NotImplementedError(self._dist)
        residual = self._mode - symlog(value)
        distance = residual**2.0 if self._dist == "mse" else torch.abs(residual)
        distance = torch.where(distance < self._tol, 0, distance)
        trailing = list(range(len(distance.shape)))[2:]
        if self._agg == "mean":
            return -distance.mean(trailing)
        if self._agg == "sum":
            return -distance.sum(trailing)
        raise NotImplementedError(self._agg)
2023-02-12 22:35:25 +09:00
class ContDist:
    """Thin wrapper around a continuous distribution.

    ``sample`` uses the wrapped distribution's ``rsample``, ``mode`` returns
    its mean, and unknown attributes are forwarded. The ``mean`` attribute is
    captured once at construction time.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        # Evaluated eagerly: ``dist`` must not be None despite the default.
        self.mean = dist.mean

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def entropy(self):
        """Entropy of the wrapped distribution."""
        wrapped = self._dist
        return wrapped.entropy()

    def mode(self):
        """Mean of the wrapped distribution, used as its mode."""
        wrapped = self._dist
        return wrapped.mean

    def sample(self, sample_shape=()):
        """Reparameterised sample of shape ``sample_shape`` (plus batch/event shape)."""
        wrapped = self._dist
        return wrapped.rsample(sample_shape)

    def log_prob(self, x):
        """Log-probability of ``x`` under the wrapped distribution."""
        wrapped = self._dist
        return wrapped.log_prob(x)
2023-02-12 22:35:25 +09:00
class Bernoulli:
    """Wrapper around a Bernoulli-based distribution.

    ``log_prob`` reads ``dist.base_dist.logits``, so ``dist`` is expected to
    expose a ``base_dist`` (as ``torch.distributions.Independent`` does).
    Unknown attributes are forwarded to the wrapped distribution.
    """

    def __init__(self, dist=None):
        super().__init__()
        self._dist = dist
        # Evaluated eagerly: ``dist`` must not be None despite the default.
        self.mean = dist.mean

    def __getattr__(self, name):
        return getattr(self._dist, name)

    def entropy(self):
        """Entropy of the wrapped distribution."""
        return self._dist.entropy()

    def mode(self):
        """Rounded mean, with gradients routed to the mean (straight-through)."""
        _mode = torch.round(self._dist.mean)
        return _mode.detach() + self._dist.mean - self._dist.mean.detach()

    def sample(self, sample_shape=()):
        """Draw a sample of shape ``sample_shape`` (plus batch/event shape).

        torch's Bernoulli does not implement ``rsample``; the plain sampler
        is used whenever the wrapped distribution reports that it has no
        reparameterised one.
        """
        if getattr(self._dist, "has_rsample", True):
            return self._dist.rsample(sample_shape)
        return self._dist.sample(sample_shape)

    def log_prob(self, x):
        """Element-wise Bernoulli log-likelihood of ``x`` computed from the logits."""
        _logits = self._dist.base_dist.logits
        # log(1 - p) and log(p) in a numerically stable softplus form.
        log_probs0 = -F.softplus(_logits)
        log_probs1 = -F.softplus(-_logits)

        return log_probs0 * (1 - x) + log_probs1 * x
2023-02-12 22:35:25 +09:00
class UnnormalizedHuber(torchd.normal.Normal):
    """Normal distribution whose ``log_prob`` is a negative pseudo-Huber loss.

    The loss is ``sqrt((event - mean)^2 + threshold^2) - threshold``; it is
    not normalised and does not depend on ``scale``.
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        """Negative pseudo-Huber distance between ``event`` and the mean."""
        delta = self._threshold
        squared_error = (event - self.mean) ** 2
        return -(torch.sqrt(squared_error + delta**2) - delta)

    def mode(self):
        """The mean of the distribution."""
        return self.mean
2023-02-12 22:35:25 +09:00
class SafeTruncatedNormal(torchd.normal.Normal):
    """Normal distribution whose samples are clipped into ``(low, high)``.

    Samples are clipped to ``[low + clip, high - clip]`` and then multiplied
    by ``mult``. Either step is skipped when its parameter is falsy.
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        super().__init__(loc, scale)
        self._low = low
        self._high = high
        self._clip = clip
        self._mult = mult

    def sample(self, sample_shape):
        """Draw a clipped (and optionally scaled) sample of shape ``sample_shape``."""
        draw = super().sample(sample_shape)
        if self._clip:
            lower = self._low + self._clip
            upper = self._high - self._clip
            bounded = torch.clip(draw, lower, upper)
            # Value of the clipped sample, gradient path of the raw sample.
            draw = draw - draw.detach() + bounded.detach()
        if self._mult:
            draw *= self._mult
        return draw
2023-02-12 22:35:25 +09:00
class TanhBijector(torchd.Transform):
    """``tanh`` transform usable with ``torch.distributions``.

    The original TF-style hooks (``_forward``, ``_forward_log_det_jacobian``)
    are kept; the hooks torch's ``Transform`` actually calls (``_call``,
    ``log_abs_det_jacobian``) and the ``domain``/``codomain`` declarations
    delegate to them.
    """

    domain = torchd.constraints.real
    codomain = torchd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, validate_args=False, name="tanh"):
        # ``validate_args`` and ``name`` are accepted for call compatibility
        # and are not used.
        super().__init__()

    def _forward(self, x):
        return torch.tanh(x)

    def _call(self, x):
        # Hook invoked by ``Transform.__call__``.
        return self._forward(x)

    def _inverse(self, y):
        # Keep values inside [-1, 1] away from the poles of atanh.
        y = torch.where(
            (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
        )
        y = torch.atanh(y)
        return y

    def _forward_log_det_jacobian(self, x):
        import math

        # log(1 - tanh(x)^2) in a numerically stable softplus form.
        log2 = math.log(2.0)
        return 2.0 * (log2 - x - F.softplus(-2.0 * x))

    def log_abs_det_jacobian(self, x, y):
        # Hook used by ``TransformedDistribution.log_prob``.
        return self._forward_log_det_jacobian(x)
2023-02-12 22:35:25 +09:00
def static_scan_for_lambda_return(fn, inputs, start):
    """Scan ``fn`` backwards over the leading axis of ``inputs``.

    Args:
        fn: Callable ``fn(carry, *slices) -> carry``; each carry is expected
            to be 2-D with a last axis of size 1, because the carries are
            concatenated along the last axis and reshaped to
            ``(rows, steps, 1)`` afterwards.
        inputs: Sequence of tensors sharing the same leading (time) axis.
        start: Initial carry.

    Returns:
        Tuple with one ``(steps, 1)`` tensor per row of the carry, in forward
        time order.

    Raises:
        ValueError: If the time axis is empty.
    """
    last = start
    carries = []
    for index in reversed(range(inputs[0].shape[0])):
        # (inputs, pcont) -> (inputs[index], pcont[index])
        last = fn(last, *(_input[index] for _input in inputs))
        carries.append(last)
    if not carries:
        raise ValueError("static_scan_for_lambda_return needs at least one time step")
    # One concatenation at the end instead of growing a tensor in every step.
    outputs = torch.cat(carries, dim=-1)
    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
    # The carries were collected last-step-first; restore forward time order.
    outputs = torch.flip(outputs, [1])
    return torch.unbind(outputs, dim=0)
def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
    """Compute lambda-returns along ``axis``.

    Setting lambda=1 gives a discounted Monte Carlo return.
    Setting lambda=0 gives a fixed 1-step return.

    Args:
        reward: Reward tensor.
        value: Value tensor with the same rank as ``reward``.
        pcont: Discount tensor, or a Python number that is broadcast.
        bootstrap: Value after the last step; zeros when ``None``.
        lambda_: Mixing coefficient.
        axis: Time axis of the inputs.

    Returns:
        For ``axis == 0`` the tuple produced by
        ``static_scan_for_lambda_return`` (one tensor per row of the carry).
        For ``axis != 0`` a single tensor in the input layout.
    """
    assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
    if isinstance(pcont, (int, float)):
        pcont = pcont * torch.ones_like(reward)
    dims = list(range(len(reward.shape)))
    # Permutation that swaps axis 0 with ``axis`` (only used when axis != 0).
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :]
    if axis != 0:
        reward = reward.permute(dims)
        value = value.permute(dims)
        pcont = pcont.permute(dims)
    if bootstrap is None:
        bootstrap = torch.zeros_like(value[-1])
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    inputs = reward + pcont * next_values * (1 - lambda_)
    returns = static_scan_for_lambda_return(
        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, (inputs, pcont), bootstrap
    )
    if axis != 0:
        # The scan returns a tuple, which has no ``permute``; restack it to
        # the time-major layout first. NOTE(review): assumes the inputs are
        # 3-D (time, batch, 1) after the permutation -- confirm with callers.
        returns = torch.stack(returns, dim=1).permute(dims)
    return returns
class Optimizer:
    """Wraps a torch optimizer with AMP scaling, gradient clipping and weight decay."""

    def __init__(
        self,
        name,
        parameters,
        lr,
        eps=1e-4,
        clip=None,
        wd=None,
        wd_pattern=r".*",
        opt="adam",
        use_amp=False,
    ):
        """
        Args:
            name: Prefix of the metric names returned by ``__call__``.
            parameters: Parameters handed to the torch optimizer.
            lr: Learning rate.
            eps: Epsilon for the Adam-family optimizers.
            clip: Maximum gradient norm; ``None`` disables clipping.
            wd: Weight-decay factor in ``[0, 1)``; falsy disables it.
            wd_pattern: Only the default ``r".*"`` is supported.
            opt: One of ``"adam"``, ``"adamax"``, ``"sgd"``, ``"momentum"``.
            use_amp: Enable the gradient scaler.

        Raises:
            NotImplementedError: For ``opt="nadam"``.
            KeyError: For any other unknown ``opt``.
        """
        # ``wd=None`` is the documented default and must pass validation.
        assert wd is None or 0 <= wd < 1
        assert not clip or 1 <= clip
        self._name = name
        self._parameters = parameters
        self._clip = clip
        self._wd = wd
        self._wd_pattern = wd_pattern
        if opt == "nadam":
            raise NotImplementedError(f"{opt} is not implemented")
        factories = {
            "adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
            "adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
            "sgd": lambda: torch.optim.SGD(parameters, lr=lr),
            "momentum": lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
        }
        self._opt = factories[opt]()
        self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def __call__(self, loss, params, retain_graph=False):
        """Run one optimisation step on scalar ``loss``.

        Args:
            loss: 0-dimensional loss tensor.
            params: Parameters whose gradients are clipped (and decayed).
            retain_graph: Forwarded to ``backward``.

        Returns:
            Dict with ``<name>_loss`` and ``<name>_grad_norm``.
        """
        assert len(loss.shape) == 0, loss.shape
        metrics = {}
        metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
        self._scaler.scale(loss).backward(retain_graph=retain_graph)
        # Unscale first so that the clipping threshold applies to true gradients.
        self._scaler.unscale_(self._opt)
        # Without a threshold the norm is still measured, nothing is rescaled.
        max_norm = self._clip if self._clip is not None else float("inf")
        norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
        if self._wd:
            self._apply_weight_decay(params)
        self._scaler.step(self._opt)
        self._scaler.update()
        self._opt.zero_grad()
        metrics[f"{self._name}_grad_norm"] = norm.item()
        return metrics

    def _apply_weight_decay(self, varibs):
        """Shrink every parameter in ``varibs`` by the factor ``1 - wd``."""
        nontrivial = self._wd_pattern != r".*"
        if nontrivial:
            raise NotImplementedError
        for var in varibs:
            var.data = (1 - self._wd) * var.data
2023-02-12 22:35:25 +09:00
def args_type(default):
    """Build a parser that converts a value to the type of ``default``.

    Strings are parsed according to ``default``: ``None`` keeps the string,
    bools accept exactly ``"False"``/``"True"``, ints become float when the
    text contains ``"e"`` or ``"."``, and lists/tuples are split on commas and
    parsed element-wise using the type of ``default[0]``. Non-string input is
    returned unchanged, except that it is converted to a tuple when
    ``default`` is a list or tuple.
    """
    is_sequence = isinstance(default, (list, tuple))

    def from_text(text):
        if default is None:
            return text
        if isinstance(default, bool):
            # ``index`` raises ValueError for anything but the two literals.
            return bool(["False", "True"].index(text))
        if isinstance(default, int):
            if "e" in text or "." in text:
                return float(text)
            return int(text)
        if is_sequence:
            item_parser = args_type(default[0])
            return tuple(item_parser(part) for part in text.split(","))
        return type(default)(text)

    def from_object(obj):
        return tuple(obj) if is_sequence else obj

    def parse(x):
        if isinstance(x, str):
            return from_text(x)
        return from_object(x)

    return parse
2023-02-12 22:35:25 +09:00
def static_scan(fn, inputs, start):
    """Scan ``fn`` forwards over the leading axis of ``inputs``.

    Args:
        fn: Callable ``fn(carry, *slices) -> carry``. The carry is either a
            dict of tensors or a sequence whose items are tensors or dicts of
            tensors.
        inputs: Sequence of tensors sharing the same leading (time) axis.
        start: Initial carry.

    Returns:
        A list with one entry per carry component (a single entry when the
        carry is a dict); every tensor gains a new leading time axis.

    Raises:
        ValueError: If the time axis is empty.
    """

    def _stack(items):
        # Stack one carry component over time; dicts are stacked per key.
        if isinstance(items[0], dict):
            return {
                key: torch.stack([item[key] for item in items], dim=0)
                for key in items[0]
            }
        return torch.stack(list(items), dim=0)

    last = start
    history = []
    for index in range(inputs[0].shape[0]):
        last = fn(last, *(_input[index] for _input in inputs))
        history.append(last)
    if not history:
        raise ValueError("static_scan needs at least one time step")
    # Stacking once at the end avoids re-copying the growing output in every step.
    if isinstance(last, dict):
        return [_stack(history)]
    return [
        _stack([carry[j] for carry in history]) for j in range(len(history[0]))
    ]
2023-02-12 22:35:25 +09:00
# Original version
2023-04-15 15:28:09 +09:00
# def static_scan2(fn, inputs, start, reverse=False):
2023-02-12 22:35:25 +09:00
# last = start
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
# indices = range(inputs[0].shape[0])
# if reverse:
# indices = reversed(indices)
# for index in indices:
# inp = lambda x: (_input[x] for _input in inputs)
# last = fn(last, *inp(index))
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
# if reverse:
# outputs = [list(reversed(x)) for x in outputs]
# res = [[]] * len(outputs)
# for i in range(len(outputs)):
# if type(outputs[i][0]) == type({}):
# _res = {}
# for key in outputs[i][0].keys():
# _res[key] = []
# for j in range(len(outputs[i])):
# _res[key].append(outputs[i][j][key])
# #_res[key] = torch.stack(_res[key], 0)
# _res[key] = faster_stack(_res[key], 0)
# else:
# _res = outputs[i]
# #_res = torch.stack(_res, 0)
# _res = faster_stack(_res, 0)
# res[i] = _res
# return res
class Every:
    """Counts how many ``every``-sized intervals have passed since the last trigger.

    The first call returns 1. Later calls return the number of whole
    intervals elapsed since the last accounted step. A falsy ``every`` always
    yields 0.
    """

    def __init__(self, every):
        self._every = every
        self._last = None

    def __call__(self, step):
        interval = self._every
        if not interval:
            return 0
        if self._last is None:
            self._last = step
            return 1
        elapsed = step - self._last
        fired = int(elapsed / interval)
        # Advance only by whole intervals so that the remainder carries over.
        self._last += interval * fired
        return fired
2023-02-12 22:35:25 +09:00
class Once:
    """Callable that returns ``True`` on its first call and ``False`` afterwards."""

    def __init__(self):
        self._once = True

    def __call__(self):
        first_call, self._once = self._once, False
        return first_call
2023-02-12 22:35:25 +09:00
class Until:
    """Callable that is true while ``step`` is below ``until``.

    A falsy ``until`` means "no limit": every call returns ``True``.
    """

    def __init__(self, until):
        self._until = until

    def __call__(self, step):
        limit = self._until
        return True if not limit else step < limit
2023-02-12 22:35:25 +09:00
def schedule(string, step):
    """Evaluate a schedule description at ``step``.

    Supported forms:
        * a plain number, returned as ``float``;
        * ``linear(initial,final,duration)``;
        * ``warmup(warmup,value)``;
        * ``exp(initial,final,halflife)``;
        * ``horizon(initial,final,duration)``, returning ``1 - 1 / horizon``.

    The ``linear``, ``warmup`` and ``horizon`` forms return 0-dimensional
    torch tensors; the others return Python floats.

    Raises:
        NotImplementedError: If ``string`` matches none of the forms.
    """
    try:
        return float(string)
    except ValueError:
        pass

    def _fraction(x):
        # torch.clip needs a tensor; a bare Python float raises TypeError.
        return torch.clip(torch.Tensor([x]), 0, 1)[0]

    def _numbers(found):
        return [float(group) for group in found.groups()]

    match = re.match(r"linear\((.+),(.+),(.+)\)", string)
    if match:
        initial, final, duration = _numbers(match)
        mix = _fraction(step / duration)
        return (1 - mix) * initial + mix * final
    match = re.match(r"warmup\((.+),(.+)\)", string)
    if match:
        warmup, value = _numbers(match)
        scale = _fraction(step / warmup)
        return scale * value
    match = re.match(r"exp\((.+),(.+),(.+)\)", string)
    if match:
        initial, final, halflife = _numbers(match)
        return (initial - final) * 0.5 ** (step / halflife) + final
    match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
    if match:
        initial, final, duration = _numbers(match)
        mix = _fraction(step / duration)
        horizon = (1 - mix) * initial + mix * final
        return 1 - 1 / horizon
    raise NotImplementedError(string)
2023-02-12 22:35:25 +09:00
def weight_init(m):
    """Initialise module ``m`` in place (intended for ``nn.Module.apply``).

    ``nn.Linear``, ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are drawn
    from a normal distribution truncated at two standard deviations, with the
    variance derived from the average of fan-in and fan-out; biases are
    zeroed. ``nn.LayerNorm`` gets weight 1 and bias 0. Other modules are left
    untouched.
    """
    if isinstance(m, nn.Linear):
        in_num = m.in_features
        out_num = m.out_features
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        space = m.kernel_size[0] * m.kernel_size[1]
        in_num = space * m.in_channels
        out_num = space * m.out_channels
    elif isinstance(m, nn.LayerNorm):
        # Weight/bias are None when the layer has no affine parameters.
        if hasattr(m.weight, "data"):
            m.weight.data.fill_(1.0)
        if hasattr(m.bias, "data"):
            m.bias.data.fill_(0.0)
        return
    else:
        return

    denoms = (in_num + out_num) / 2.0
    scale = 1.0 / denoms
    # 0.8796... is the standard deviation of a unit normal truncated at +-2,
    # so the bounds below must be +-2 * std for the correction to hold.
    std = np.sqrt(scale) / 0.87962566103423978
    nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
    if hasattr(m.bias, "data"):
        m.bias.data.fill_(0.0)
def uniform_weight_init(given_scale):
    """Return an initialiser (for ``nn.Module.apply``) using a uniform distribution.

    ``nn.Linear`` weights are drawn uniformly from ``[-limit, limit]`` with
    ``limit = sqrt(3 * given_scale / mean(fan_in, fan_out))``, and the bias is
    zeroed. ``nn.LayerNorm`` gets weight 1 and bias 0. Other modules are left
    untouched.
    """

    def _zero_bias(module):
        # The bias is None when the layer was built without one.
        if hasattr(module.bias, "data"):
            module.bias.data.fill_(0.0)

    def f(m):
        if isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1.0)
            _zero_bias(m)
            return
        if not isinstance(m, nn.Linear):
            return
        fan_avg = (m.in_features + m.out_features) / 2.0
        limit = np.sqrt(3 * (given_scale / fan_avg))
        nn.init.uniform_(m.weight.data, a=-limit, b=limit)
        _zero_bias(m)

    return f
def tensorstats(tensor, prefix=None):
    """Return mean, std, min and max of ``tensor`` as NumPy values.

    Args:
        tensor: A torch tensor.
        prefix: When truthy, every key is renamed to ``"<prefix>_<key>"``.

    Returns:
        Dict with the keys ``mean``, ``std``, ``min`` and ``max`` (optionally
        prefixed).
    """
    reducers = (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
    )
    metrics = {}
    for key, reduce in reducers:
        label = f"{prefix}_{key}" if prefix else key
        # Detach and move to the CPU before the NumPy conversion.
        metrics[label] = reduce(tensor).detach().cpu().numpy()
    return metrics