diff --git a/tools.py b/tools.py
index db7e911..bd73beb 100644
--- a/tools.py
+++ b/tools.py
@@ -20,629 +20,645 @@ from torch.utils.tensorboard import SummaryWriter

 to_np = lambda x: x.detach().cpu().numpy()

+
 def symlog(x):
-  return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
+    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

+
 def symexp(x):
-  return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
+    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

+
 class RequiresGrad:
+    def __init__(self, model):
+        self._model = model

-  def __init__(self, model):
-    self._model = model
+    def __enter__(self):
+        self._model.requires_grad_(requires_grad=True)

-  def __enter__(self):
-    self._model.requires_grad_(requires_grad=True)
-
-  def __exit__(self, *args):
-    self._model.requires_grad_(requires_grad=False)
+    def __exit__(self, *args):
+        self._model.requires_grad_(requires_grad=False)


 class TimeRecording:
+    def __init__(self, comment):
+        self._comment = comment

-  def __init__(self, comment):
-    self._comment = comment
+    def __enter__(self):
+        self._st = torch.cuda.Event(enable_timing=True)
+        self._nd = torch.cuda.Event(enable_timing=True)
+        self._st.record()

-  def __enter__(self):
-    self._st = torch.cuda.Event(enable_timing=True)
-    self._nd = torch.cuda.Event(enable_timing=True)
-    self._st.record()
-
-  def __exit__(self, *args):
-    self._nd.record()
-    torch.cuda.synchronize()
-    print(self._comment, self._st.elapsed_time(self._nd)/1000)
+    def __exit__(self, *args):
+        self._nd.record()
+        torch.cuda.synchronize()
+        print(self._comment, self._st.elapsed_time(self._nd) / 1000)


 class Logger:
+    def __init__(self, logdir, step):
+        self._logdir = logdir
+        self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
+        self._last_step = None
+        self._last_time = None
+        self._scalars = {}
+        self._images = {}
+        self._videos = {}
+        self.step = step

-  def __init__(self, logdir, step):
-    self._logdir = logdir
-    self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
-    self._last_step = None
-    self._last_time = None
-    self._scalars = {}
-    self._images = {}
-    self._videos = {}
-    self.step = step
+    def scalar(self, name, value):
+        self._scalars[name] = float(value)

-  def scalar(self, name, value):
-    self._scalars[name] = float(value)
+    def image(self, name, value):
+        self._images[name] = np.array(value)

-  def image(self, name, value):
-    self._images[name] = np.array(value)
+    def video(self, name, value):
+        self._videos[name] = np.array(value)

-  def video(self, name, value):
-    self._videos[name] = np.array(value)
+    def write(self, fps=False, step=False):
+        if not step:
+            step = self.step
+        scalars = list(self._scalars.items())
+        if fps:
+            scalars.append(("fps", self._compute_fps(step)))
+        print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
+        with (self._logdir / "metrics.jsonl").open("a") as f:
+            f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
+        for name, value in scalars:
+            self._writer.add_scalar("scalars/" + name, value, step)
+        for name, value in self._images.items():
+            self._writer.add_image(name, value, step)
+        for name, value in self._videos.items():
+            name = name if isinstance(name, str) else name.decode("utf-8")
+            if np.issubdtype(value.dtype, np.floating):
+                value = np.clip(255 * value, 0, 255).astype(np.uint8)
+            B, T, H, W, C = value.shape
+            value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
+            self._writer.add_video(name, value, step, 16)

-  def write(self, fps=False, step=False):
-    if not step:
-      step = self.step
-    scalars = list(self._scalars.items())
-    if fps:
-      scalars.append(('fps', self._compute_fps(step)))
-    print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
-    with (self._logdir / 'metrics.jsonl').open('a') as f:
-      f.write(json.dumps({'step': step, ** dict(scalars)}) + '\n')
-    for name, value in scalars:
-      self._writer.add_scalar('scalars/' + name, value, step)
-    for name, value in self._images.items():
-      self._writer.add_image(name, value, step)
-    for name, value in self._videos.items():
-      name = name if isinstance(name, str) else name.decode('utf-8')
-      if np.issubdtype(value.dtype, np.floating):
-        value = np.clip(255 * value, 0, 255).astype(np.uint8)
-      B, T, H, W, C = value.shape
-      value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
-      self._writer.add_video(name, value, step, 16)
+        self._writer.flush()
+        self._scalars = {}
+        self._images = {}
+        self._videos = {}

-    self._writer.flush()
-    self._scalars = {}
-    self._images = {}
-    self._videos = {}
+    def _compute_fps(self, step):
+        if self._last_step is None:
+            self._last_time = time.time()
+            self._last_step = step
+            return 0
+        steps = step - self._last_step
+        duration = time.time() - self._last_time
+        self._last_time += duration
+        self._last_step = step
+        return steps / duration

-  def _compute_fps(self, step):
-    if self._last_step is None:
-      self._last_time = time.time()
-      self._last_step = step
-      return 0
-    steps = step - self._last_step
-    duration = time.time() - self._last_time
-    self._last_time += duration
-    self._last_step = step
-    return steps / duration
+    def offline_scalar(self, name, value, step):
+        self._writer.add_scalar("scalars/" + name, value, step)

-  def offline_scalar(self, name, value, step):
-    self._writer.add_scalar('scalars/'+name, value, step)
-
-  def offline_video(self, name, value, step):
-    if np.issubdtype(value.dtype, np.floating):
-      value = np.clip(255 * value, 0, 255).astype(np.uint8)
-    B, T, H, W, C = value.shape
-    value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
-    self._writer.add_video(name, value, step, 16)
+    def offline_video(self, name, value, step):
+        if np.issubdtype(value.dtype, np.floating):
+            value = np.clip(255 * value, 0, 255).astype(np.uint8)
+        B, T, H, W, C = value.shape
+        value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
+        self._writer.add_video(name, value, step, 16)


 def simulate(agent, envs, steps=0, episodes=0, state=None):
-  # Initialize or unpack simulation state.
-  if state is None:
-    step, episode = 0, 0
-    done = np.ones(len(envs), np.bool)
-    length = np.zeros(len(envs), np.int32)
-    obs = [None] * len(envs)
-    agent_state = None
-    reward = [0]*len(envs)
-  else:
-    step, episode, done, length, obs, agent_state, reward = state
-  while (steps and step < steps) or (episodes and episode < episodes):
-    # Reset envs if necessary.
-    if done.any():
-      indices = [index for index, d in enumerate(done) if d]
-      results = [envs[i].reset() for i in indices]
-      for index, result in zip(indices, results):
-        obs[index] = result
-      reward = [reward[i]*(1-done[i]) for i in range(len(envs))]
-    # Step agents.
-    obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
-    action, agent_state = agent(obs, done, agent_state, reward)
-    if isinstance(action, dict):
-      action = [
-          {k: np.array(action[k][i].detach().cpu()) for k in action}
-          for i in range(len(envs))]
+    # Initialize or unpack simulation state.
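+    # `state` is the tuple (step, episode, done, length, obs, agent_state, reward)
+    # carried between calls so a rollout can resume where the previous one stopped.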
+    if state is None:
+        step, episode = 0, 0
+        done = np.ones(len(envs), bool)
+        length = np.zeros(len(envs), np.int32)
+        obs = [None] * len(envs)
+        agent_state = None
+        reward = [0] * len(envs)
     else:
-      action = np.array(action)
-    assert len(action) == len(envs)
-    # Step envs.
-    results = [e.step(a) for e, a in zip(envs, action)]
-    obs, reward, done = zip(*[p[:3] for p in results])
-    obs = list(obs)
-    reward = list(reward)
-    done = np.stack(done)
-    episode += int(done.sum())
-    length += 1
-    step += (done * length).sum()
-    length *= (1 - done)
+        step, episode, done, length, obs, agent_state, reward = state
+    while (steps and step < steps) or (episodes and episode < episodes):
+        # Reset envs if necessary.
+        if done.any():
+            indices = [index for index, d in enumerate(done) if d]
+            results = [envs[i].reset() for i in indices]
+            for index, result in zip(indices, results):
+                obs[index] = result
+            reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
+        # Step agents.
+        obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
+        action, agent_state = agent(obs, done, agent_state, reward)
+        if isinstance(action, dict):
+            action = [
+                {k: np.array(action[k][i].detach().cpu()) for k in action}
+                for i in range(len(envs))
+            ]
+        else:
+            action = np.array(action)
+        assert len(action) == len(envs)
+        # Step envs.
+        results = [e.step(a) for e, a in zip(envs, action)]
+        obs, reward, done = zip(*[p[:3] for p in results])
+        obs = list(obs)
+        reward = list(reward)
+        done = np.stack(done)
+        episode += int(done.sum())
+        length += 1
+        step += (done * length).sum()
+        length *= 1 - done
-  return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
+    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)


 def save_episodes(directory, episodes):
-  directory = pathlib.Path(directory).expanduser()
-  directory.mkdir(parents=True, exist_ok=True)
-  timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
-  filenames = []
-  for episode in episodes:
-    identifier = str(uuid.uuid4().hex)
-    length = len(episode['reward'])
-    filename = directory / f'{timestamp}-{identifier}-{length}.npz'
-    with io.BytesIO() as f1:
-      np.savez_compressed(f1, **episode)
-      f1.seek(0)
-      with filename.open('wb') as f2:
-        f2.write(f1.read())
-    filenames.append(filename)
-  return filenames
+    directory = pathlib.Path(directory).expanduser()
+    directory.mkdir(parents=True, exist_ok=True)
+    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
+    filenames = []
+    for episode in episodes:
+        identifier = str(uuid.uuid4().hex)
+        length = len(episode["reward"])
+        filename = directory / f"{timestamp}-{identifier}-{length}.npz"
+        with io.BytesIO() as f1:
+            np.savez_compressed(f1, **episode)
+            f1.seek(0)
+            with filename.open("wb") as f2:
+                f2.write(f1.read())
+        filenames.append(filename)
+    return filenames


 def from_generator(generator, batch_size):
-  while True:
-    batch = []
-    for _ in range(batch_size):
-      batch.append(next(generator))
-    data = {}
-    for key in batch[0].keys():
-      data[key] = []
-      for i in range(batch_size):
-        data[key].append(batch[i][key])
-      data[key] = np.stack(data[key], 0)
-    yield data
+    while True:
+        batch = []
+        for _ in range(batch_size):
+            batch.append(next(generator))
+        data = {}
+        for key in batch[0].keys():
+            data[key] = []
+            for i in range(batch_size):
+                data[key].append(batch[i][key])
+            data[key] = np.stack(data[key], 0)
+        yield data


 def sample_episodes(episodes, length=None, balance=False, seed=0):
-  random = np.random.RandomState(seed)
-  while True:
-    episode = random.choice(list(episodes.values()))
-    if length:
-      total = len(next(iter(episode.values())))
-      available = total - length
-      if available < 1:
-        print(f'Skipped short episode of length {available}.')
-        continue
-      if balance:
-        index = min(random.randint(0, total), available)
-      else:
-        index = int(random.randint(0, available + 1))
-      episode = {k: v[index: index + length] for k, v in episode.items()}
-    yield episode
+    random = np.random.RandomState(seed)
+    while True:
+        episode = random.choice(list(episodes.values()))
+        if length:
+            total = len(next(iter(episode.values())))
+            available = total - length
+            if available < 1:
+                print(f"Skipped short episode of length {total}.")
+                continue
+            if balance:
+                index = min(random.randint(0, total), available)
+            else:
+                index = int(random.randint(0, available + 1))
+            episode = {k: v[index : index + length] for k, v in episode.items()}
+        yield episode


 def load_episodes(directory, limit=None, reverse=True):
-  directory = pathlib.Path(directory).expanduser()
-  episodes = collections.OrderedDict()
-  total = 0
-  if reverse:
-    for filename in reversed(sorted(directory.glob('*.npz'))):
-      try:
-        with filename.open('rb') as f:
-          episode = np.load(f)
-          episode = {k: episode[k] for k in episode.keys()}
-      except Exception as e:
-        print(f'Could not load episode: {e}')
-        continue
-      episodes[str(filename)] = episode
-      total += len(episode['reward']) - 1
-      if limit and total >= limit:
-        break
-  else:
-    for filename in sorted(directory.glob('*.npz')):
-      try:
-        with filename.open('rb') as f:
-          episode = np.load(f)
-          episode = {k: episode[k] for k in episode.keys()}
-      except Exception as e:
-        print(f'Could not load episode: {e}')
-        continue
-      episodes[str(filename)] = episode
-      total += len(episode['reward']) - 1
-      if limit and total >= limit:
-        break
-  return episodes
+    directory = pathlib.Path(directory).expanduser()
+    episodes = collections.OrderedDict()
+    total = 0
+    if reverse:
+        for filename in reversed(sorted(directory.glob("*.npz"))):
+            try:
+                with filename.open("rb") as f:
+                    episode = np.load(f)
+                    episode = {k: episode[k] for k in episode.keys()}
+            except Exception as e:
+                print(f"Could not load episode: {e}")
+                continue
+            episodes[str(filename)] = episode
+            total += len(episode["reward"]) - 1
+            if limit and total >= limit:
+                break
+    else:
+        for filename in sorted(directory.glob("*.npz")):
+            try:
+                with filename.open("rb") as f:
+                    episode = np.load(f)
+                    episode = {k: episode[k] for k in episode.keys()}
+            except Exception as e:
+                print(f"Could not load episode: {e}")
+                continue
+            episodes[str(filename)] = episode
+            total += len(episode["reward"]) - 1
+            if limit and total >= limit:
+                break
+    return episodes


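+# Monte Carlo wrapper: approximates the mean, mode, and entropy of the wrapped
+# distribution from `samples` draws when no closed form is available.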
 class SampleDist:
+    def __init__(self, dist, samples=100):
+        self._dist = dist
+        self._samples = samples

-  def __init__(self, dist, samples=100):
-    self._dist = dist
-    self._samples = samples
+    @property
+    def name(self):
+        return "SampleDist"

-  @property
-  def name(self):
-    return 'SampleDist'
+    def __getattr__(self, name):
+        return getattr(self._dist, name)

-  def __getattr__(self, name):
-    return getattr(self._dist, name)
+    def mean(self):
+        samples = self._dist.sample(self._samples)
+        return torch.mean(samples, 0)

-  def mean(self):
-    samples = self._dist.sample(self._samples)
-    return torch.mean(samples, 0)
+    def mode(self):
+        sample = self._dist.sample(self._samples)
+        logprob = self._dist.log_prob(sample)
+        return sample[torch.argmax(logprob)][0]

-  def mode(self):
-    sample = self._dist.sample(self._samples)
-    logprob = self._dist.log_prob(sample)
-    return sample[torch.argmax(logprob)][0]
-
-  def entropy(self):
-    sample = self._dist.sample(self._samples)
-    logprob = self.log_prob(sample)
-    return -torch.mean(logprob, 0)
+    def entropy(self):
+        sample = self._dist.sample(self._samples)
+        logprob = self.log_prob(sample)
+        return -torch.mean(logprob, 0)


 class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
+    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
+        if logits is not None and unimix_ratio > 0.0:
+            probs = F.softmax(logits, dim=-1)
+            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
+            logits = torch.log(probs)
+            super().__init__(logits=logits, probs=None)
+        else:
+            super().__init__(logits=logits, probs=probs)

-  def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
-    if logits is not None and unimix_ratio > 0.0:
-      probs = F.softmax(logits, dim=-1)
-      probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
-      logits = torch.log(probs)
-      super().__init__(logits=logits, probs=None)
-    else:
-      super().__init__(logits=logits, probs=probs)
+    def mode(self):
+        _mode = F.one_hot(
+            torch.argmax(super().logits, axis=-1), super().logits.shape[-1]
+        )
+        return _mode.detach() + super().logits - super().logits.detach()

-  def mode(self):
-    _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
-    return _mode.detach() + super().logits - super().logits.detach()
-
-  def sample(self, sample_shape=(), seed=None):
-    if seed is not None:
-      raise ValueError('need to check')
-    sample = super().sample(sample_shape)
-    probs = super().probs
-    while len(probs.shape) < len(sample.shape):
-      probs = probs[None]
-    sample += probs - probs.detach()
-    return sample
+    def sample(self, sample_shape=(), seed=None):
+        if seed is not None:
+            raise ValueError("need to check")
+        sample = super().sample(sample_shape)
+        probs = super().probs
+        while len(probs.shape) < len(sample.shape):
+            probs = probs[None]
+        sample += probs - probs.detach()
+        return sample


-class TwoHotDistSymlog():
+class TwoHotDistSymlog:
+    def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
+        self.logits = logits
+        self.probs = torch.softmax(logits, -1)
+        self.buckets = torch.linspace(low, high, steps=255).to(device)
+        self.width = (self.buckets[-1] - self.buckets[0]) / 255

-  def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
-    self.logits = logits
-    self.probs = torch.softmax(logits, -1)
-    self.buckets = torch.linspace(low, high, steps=255).to(device)
-    self.width = (self.buckets[-1] - self.buckets[0]) / 255
+    def mean(self):
+        print("mean called")
+        _mode = self.probs * self.buckets
+        return symexp(torch.sum(_mode, dim=-1, keepdim=True))

-  def mean(self):
-    print("mean called")
-    _mode = self.probs * self.buckets
-    return symexp(torch.sum(_mode, dim=-1, keepdim=True))
+    def mode(self):
+        _mode = self.probs * self.buckets
+        return symexp(torch.sum(_mode, dim=-1, keepdim=True))

-  def mode(self):
-    _mode = self.probs * self.buckets
-    return symexp(torch.sum(_mode, dim=-1, keepdim=True))
+    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
+    def log_prob(self, x):
+        x = symlog(x)
+        # x(time, batch, 1)
+        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
+        above = len(self.buckets) - torch.sum(
+            (self.buckets > x[..., None]).to(torch.int32), dim=-1
+        )
+        below = torch.clip(below, 0, len(self.buckets) - 1)
+        above = torch.clip(above, 0, len(self.buckets) - 1)
+        equal = below == above

-  # Inside OneHotCategorical, log_prob is calculated using only max element in targets
-  def log_prob(self, x):
-    x = symlog(x)
-    # x(time, batch, 1)
-    below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) -1
-    above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
-    below = torch.clip(below, 0, len(self.buckets)-1)
-    above = torch.clip(above, 0, len(self.buckets)-1)
-    equal = (below == above)
+        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
+        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
+        total = dist_to_below + dist_to_above
+        weight_below = dist_to_above / total
+        weight_above = dist_to_below / total
+        target = (
+            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
+            + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
+        )
+        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+        target = target.squeeze(-2)

-    dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
-    dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
-    total = dist_to_below + dist_to_above
-    weight_below = dist_to_above / total
-    weight_above = dist_to_below / total
-    target = (
-        F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
-        F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
-    log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
-    target = target.squeeze(-2)
+        return (target * log_pred).sum(-1)

-    return (target * log_pred).sum(-1)
+    def log_prob_target(self, target):
+        # This class no longer inherits from OneHotCategorical, so the logits
+        # live on `self`, not on a parent class.
+        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+        return (target * log_pred).sum(-1)

-  def log_prob_target(self, target):
-    log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
-    return (target * log_pred).sum(-1)


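+# SymlogDist below scores predictions by (negative) squared or absolute error
+# in symlog space, so maximizing `log_prob` amounts to symlog-space regression;
+# mode/mean map the prediction back through symexp.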
-class SymlogDist():
-  def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
-    self._mode = mode
-    self._dist = dist
-    self._agg = agg
-    self._tol = tol
-    self._dim_to_reduce = dim_to_reduce
+class SymlogDist:
+    def __init__(
+        self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
+    ):
+        self._mode = mode
+        self._dist = dist
+        self._agg = agg
+        self._tol = tol
+        self._dim_to_reduce = dim_to_reduce

-  def mode(self):
-    return symexp(self._mode)
+    def mode(self):
+        return symexp(self._mode)

-  def mean(self):
-    return symexp(self._mode)
+    def mean(self):
+        return symexp(self._mode)

+    def log_prob(self, value):
+        assert self._mode.shape == value.shape
+        if self._dist == "mse":
+            distance = (self._mode - symlog(value)) ** 2.0
+            distance = torch.where(distance < self._tol, 0, distance)
+        elif self._dist == "abs":
+            distance = torch.abs(self._mode - symlog(value))
+            distance = torch.where(distance < self._tol, 0, distance)
+        else:
+            raise NotImplementedError(self._dist)
+        if self._agg == "mean":
+            loss = distance.mean(self._dim_to_reduce)
+        elif self._agg == "sum":
+            loss = distance.sum(self._dim_to_reduce)
+        else:
+            raise NotImplementedError(self._agg)
+        return -loss

-  def log_prob(self, value):
-    assert self._mode.shape == value.shape
-    if self._dist == 'mse':
-      distance = (self._mode - symlog(value)) ** 2.0
-      distance = torch.where(distance < self._tol, 0, distance)
-    elif self._dist == 'abs':
-      distance = torch.abs(self._mode - symlog(value))
-      distance = torch.where(distance < self._tol, 0, distance)
-    else:
-      raise NotImplementedError(self._dist)
-    if self._agg == 'mean':
-      loss = distance.mean(self._dim_to_reduce)
-    elif self._agg == 'sum':
-      loss = distance.sum(self._dim_to_reduce)
-    else:
-      raise NotImplementedError(self._agg)
-    return -loss


 class ContDist:
+    def __init__(self, dist=None):
+        super().__init__()
+        self._dist = dist
+        self.mean = dist.mean

-  def __init__(self, dist=None):
-    super().__init__()
-    self._dist = dist
-    self.mean = dist.mean
+    def __getattr__(self, name):
+        return getattr(self._dist, name)

-  def __getattr__(self, name):
-    return getattr(self._dist, name)
+    def entropy(self):
+        return self._dist.entropy()

-  def entropy(self):
-    return self._dist.entropy()
+    def mode(self):
+        return self._dist.mean

-  def mode(self):
-    return self._dist.mean
+    def sample(self, sample_shape=()):
+        return self._dist.rsample(sample_shape)

-  def sample(self, sample_shape=()):
-    return self._dist.rsample(sample_shape)
-
-  def log_prob(self, x):
-    return self._dist.log_prob(x)
+    def log_prob(self, x):
+        return self._dist.log_prob(x)


 class Bernoulli:
+    def __init__(self, dist=None):
+        super().__init__()
+        self._dist = dist
+        self.mean = dist.mean

-  def __init__(self, dist=None):
-    super().__init__()
-    self._dist = dist
-    self.mean = dist.mean
+    def __getattr__(self, name):
+        return getattr(self._dist, name)

-  def __getattr__(self, name):
-    return getattr(self._dist, name)
+    def entropy(self):
+        return self._dist.entropy()

-  def entropy(self):
-    return self._dist.entropy()
+    def mode(self):
+        _mode = torch.round(self._dist.mean)
+        return _mode.detach() + self._dist.mean - self._dist.mean.detach()

-  def mode(self):
-    _mode = torch.round(self._dist.mean)
-    return _mode.detach() +self._dist.mean - self._dist.mean.detach()
+    def sample(self, sample_shape=()):
+        # Bernoulli does not implement rsample (no reparameterization), so a
+        # plain sample is drawn here.
+        return self._dist.sample(sample_shape)

-  def sample(self, sample_shape=()):
-    return self._dist.rsample(sample_shape)
+    def log_prob(self, x):
+        _logits = self._dist.base_dist.logits
+        log_probs0 = -F.softplus(_logits)
+        log_probs1 = -F.softplus(-_logits)

-  def log_prob(self, x):
-    _logits = self._dist.base_dist.logits
-    log_probs0 = -F.softplus(_logits)
-    log_probs1 = -F.softplus(-_logits)
-
-    return log_probs0 * (1-x) + log_probs1 * x
+        return log_probs0 * (1 - x) + log_probs1 * x


 class UnnormalizedHuber(torchd.normal.Normal):
+    def __init__(self, loc, scale, threshold=1, **kwargs):
+        super().__init__(loc, scale, **kwargs)
+        self._threshold = threshold

-  def __init__(self, loc, scale, threshold=1, **kwargs):
-    super().__init__(loc, scale, **kwargs)
-    self._threshold = threshold
+    def log_prob(self, event):
+        return -(
+            torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
+            - self._threshold
+        )

-  def log_prob(self, event):
-    return -(torch.sqrt(
-        (event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)
-
-  def mode(self):
-    return self.mean
+    def mode(self):
+        return self.mean


 class SafeTruncatedNormal(torchd.normal.Normal):
+    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
+        super().__init__(loc, scale)
+        self._low = low
+        self._high = high
+        self._clip = clip
+        self._mult = mult

-  def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
-    super().__init__(loc, scale)
-    self._low = low
-    self._high = high
-    self._clip = clip
-    self._mult = mult
-
-  def sample(self, sample_shape):
-    event = super().sample(sample_shape)
-    if self._clip:
-      clipped = torch.clip(event, self._low + self._clip,
-                           self._high - self._clip)
-      event = event - event.detach() + clipped.detach()
-    if self._mult:
-      event *= self._mult
-    return event
+    def sample(self, sample_shape):
+        event = super().sample(sample_shape)
+        if self._clip:
+            clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
+            event = event - event.detach() + clipped.detach()
+        if self._mult:
+            event *= self._mult
+        return event


 class TanhBijector(torchd.Transform):
+    def __init__(self, validate_args=False, name="tanh"):
+        super().__init__()

-  def __init__(self, validate_args=False, name='tanh'):
-    super().__init__()
+    def _forward(self, x):
+        return torch.tanh(x)

-  def _forward(self, x):
-    return torch.tanh(x)
+    def _inverse(self, y):
+        y = torch.where(
+            (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
+        )
+        y = torch.atanh(y)
+        return y

-  def _inverse(self, y):
-    y = torch.where(
-        (torch.abs(y) <= 1.),
-        torch.clamp(y, -0.99999997, 0.99999997), y)
-    y = torch.atanh(y)
-    return y
-
-  def _forward_log_det_jacobian(self, x):
-    log2 = torch.math.log(2.0)
-    return 2.0 * (log2 - x - torch.softplus(-2.0 * x))
+    def _forward_log_det_jacobian(self, x):
+        # torch has no public `math` or `softplus` attributes; use numpy and
+        # torch.nn.functional instead.
+        log2 = np.log(2.0)
+        return 2.0 * (log2 - x - F.softplus(-2.0 * x))


 def static_scan_for_lambda_return(fn, inputs, start):
-  last = start
-  indices = range(inputs[0].shape[0])
-  indices = reversed(indices)
-  flag = True
-  for index in indices:
-    # (inputs, pcont) -> (inputs[index], pcont[index])
-    inp = lambda x: (_input[x] for _input in inputs)
-    last = fn(last, *inp(index))
-    if flag:
-      outputs = last
-      flag = False
-    else:
-      outputs = torch.cat([outputs, last], dim=-1)
-  outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
-  outputs = torch.flip(outputs, [1])
-  outputs = torch.unbind(outputs, dim=0)
-  return outputs
+    last = start
+    indices = range(inputs[0].shape[0])
+    indices = reversed(indices)
+    flag = True
+    for index in indices:
+        # (inputs, pcont) -> (inputs[index], pcont[index])
+        inp = lambda x: (_input[x] for _input in inputs)
+        last = fn(last, *inp(index))
+        if flag:
+            outputs = last
+            flag = False
+        else:
+            outputs = torch.cat([outputs, last], dim=-1)
+    outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
+    outputs = torch.flip(outputs, [1])
+    outputs = torch.unbind(outputs, dim=0)
+    return outputs


-def lambda_return(
-    reward, value, pcont, bootstrap, lambda_, axis):
-  # Setting lambda=1 gives a discounted Monte Carlo return.
-  # Setting lambda=0 gives a fixed 1-step return.
-  #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
-  assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
-  if isinstance(pcont, (int, float)):
-    pcont = pcont * torch.ones_like(reward)
-  dims = list(range(len(reward.shape)))
-  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
-  if axis != 0:
-    reward = reward.permute(dims)
-    value = value.permute(dims)
-    pcont = pcont.permute(dims)
-  if bootstrap is None:
-    bootstrap = torch.zeros_like(value[-1])
-  next_values = torch.cat([value[1:], bootstrap[None]], 0)
-  inputs = reward + pcont * next_values * (1 - lambda_)
-  #returns = static_scan(
-  #    lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
-  #    (inputs, pcont), bootstrap, reverse=True)
-  # reimplement to optimize performance
-  returns = static_scan_for_lambda_return(
-      lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
-      (inputs, pcont), bootstrap)
-  if axis != 0:
-    returns = returns.permute(dims)
-  return returns
+def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
+    # Setting lambda=1 gives a discounted Monte Carlo return.
+    # Setting lambda=0 gives a fixed 1-step return.
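+    # Computed backwards over time as
+    #   R_t = r_t + pcont_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}),
+    # starting from `bootstrap` at the last step.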
+    # assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
+    assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
+    if isinstance(pcont, (int, float)):
+        pcont = pcont * torch.ones_like(reward)
+    dims = list(range(len(reward.shape)))
+    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :]
+    if axis != 0:
+        reward = reward.permute(dims)
+        value = value.permute(dims)
+        pcont = pcont.permute(dims)
+    if bootstrap is None:
+        bootstrap = torch.zeros_like(value[-1])
+    next_values = torch.cat([value[1:], bootstrap[None]], 0)
+    inputs = reward + pcont * next_values * (1 - lambda_)
+    # returns = static_scan(
+    #     lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
+    #     (inputs, pcont), bootstrap, reverse=True)
+    # reimplement to optimize performance
+    returns = static_scan_for_lambda_return(
+        lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, (inputs, pcont), bootstrap
+    )
+    if axis != 0:
+        returns = returns.permute(dims)
+    return returns


-class Optimizer():
+class Optimizer:
+    def __init__(
+        self,
+        name,
+        parameters,
+        lr,
+        eps=1e-4,
+        clip=None,
+        wd=None,
+        wd_pattern=r".*",
+        opt="adam",
+        use_amp=False,
+    ):
+        assert wd is None or 0 <= wd < 1
+        assert not clip or 1 <= clip
+        self._name = name
+        self._parameters = parameters
+        self._clip = clip
+        self._wd = wd
+        self._wd_pattern = wd_pattern
+        assert opt in ("adam", "adamax", "sgd", "momentum"), f"{opt} is not implemented"
+        self._opt = {
+            "adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
+            "adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
+            "sgd": lambda: torch.optim.SGD(parameters, lr=lr),
+            "momentum": lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
+        }[opt]()
+        self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

-  def __init__(
-      self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
-      opt='adam', use_amp=False):
-    assert 0 <= wd < 1
-    assert not clip or 1 <= clip
-    self._name = name
-    self._parameters = parameters
-    self._clip = clip
-    self._wd = wd
-    self._wd_pattern = wd_pattern
-    self._opt = {
-        'adam': lambda: torch.optim.Adam(parameters,
-                                         lr=lr,
-                                         eps=eps),
-        'nadam': lambda: NotImplemented(
-            f'{opt} is not implemented'),
-        'adamax': lambda: torch.optim.Adamax(parameters,
-                                             lr=lr,
-                                             eps=eps),
-        'sgd': lambda: torch.optim.SGD(parameters,
-                                       lr=lr),
-        'momentum': lambda: torch.optim.SGD(parameters,
-                                            lr=lr,
-                                            momentum=0.9),
-    }[opt]()
-    self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

+    def __call__(self, loss, params, retain_graph=False):
+        assert len(loss.shape) == 0, loss.shape
+        metrics = {}
+        metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
+        self._scaler.scale(loss).backward()
+        self._scaler.unscale_(self._opt)
+        # loss.backward(retain_graph=retain_graph)
+        norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
+        if self._wd:
+            self._apply_weight_decay(params)
+        self._scaler.step(self._opt)
+        self._scaler.update()
+        # self._opt.step()
+        self._opt.zero_grad()
+        metrics[f"{self._name}_grad_norm"] = norm.item()
+        return metrics

-  def __call__(self, loss, params, retain_graph=False):
-    assert len(loss.shape) == 0, loss.shape
-    metrics = {}
-    metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
-    self._scaler.scale(loss).backward()
-    self._scaler.unscale_(self._opt)
-    #loss.backward(retain_graph=retain_graph)
-    norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
-    if self._wd:
-      self._apply_weight_decay(params)
-    self._scaler.step(self._opt)
-    self._scaler.update()
-    #self._opt.step()
-    self._opt.zero_grad()
-    metrics[f'{self._name}_grad_norm'] = norm.item()
-    return metrics
-
-  def _apply_weight_decay(self, varibs):
-    nontrivial = (self._wd_pattern != r'.*')
-    if nontrivial:
-      raise NotImplementedError
-    for var in varibs:
-      var.data = (1 - self._wd) * var.data
+    def _apply_weight_decay(self, varibs):
+        nontrivial = self._wd_pattern != r".*"
+        if nontrivial:
+            raise NotImplementedError
+        for var in varibs:
+            var.data = (1 - self._wd) * var.data


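+# Usage sketch (illustrative names, not part of this diff):
+#   opt = Optimizer("model", model.parameters(), lr=3e-4, clip=100, wd=0.0)
+#   with RequiresGrad(model):
+#       metrics = opt(model_loss, model.parameters())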
 def args_type(default):
-  def parse_string(x):
-    if default is None:
-      return x
-    if isinstance(default, bool):
-      return bool(['False', 'True'].index(x))
-    if isinstance(default, int):
-      return float(x) if ('e' in x or '.' in x) else int(x)
-    if isinstance(default, (list, tuple)):
-      return tuple(args_type(default[0])(y) for y in x.split(','))
-    return type(default)(x)
-  def parse_object(x):
-    if isinstance(default, (list, tuple)):
-      return tuple(x)
-    return x
-  return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
+    def parse_string(x):
+        if default is None:
+            return x
+        if isinstance(default, bool):
+            return bool(["False", "True"].index(x))
+        if isinstance(default, int):
+            return float(x) if ("e" in x or "." in x) else int(x)
+        if isinstance(default, (list, tuple)):
+            return tuple(args_type(default[0])(y) for y in x.split(","))
+        return type(default)(x)
+
+    def parse_object(x):
+        if isinstance(default, (list, tuple)):
+            return tuple(x)
+        return x
+
+    return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)


 def static_scan(fn, inputs, start):
-  last = start
-  indices = range(inputs[0].shape[0])
-  flag = True
-  for index in indices:
-    inp = lambda x: (_input[x] for _input in inputs)
-    last = fn(last, *inp(index))
-    if flag:
-      if type(last) == type({}):
-        outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
-      else:
-        outputs = []
-        for _last in last:
-          if type(_last) == type({}):
-            outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
-          else:
-            outputs.append(_last.clone().unsqueeze(0))
-      flag = False
-    else:
-      if type(last) == type({}):
-        for key in last.keys():
-          outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
-      else:
-        for j in range(len(outputs)):
-          if type(last[j]) == type({}):
-            for key in last[j].keys():
-              outputs[j][key] = torch.cat([outputs[j][key],
-                                           last[j][key].unsqueeze(0)], dim=0)
-          else:
-            outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
-  if type(last) == type({}):
-    outputs = [outputs]
-  return outputs
+    last = start
+    indices = range(inputs[0].shape[0])
+    flag = True
+    for index in indices:
+        inp = lambda x: (_input[x] for _input in inputs)
+        last = fn(last, *inp(index))
+        if flag:
+            if type(last) == type({}):
+                outputs = {
+                    key: value.clone().unsqueeze(0) for key, value in last.items()
+                }
+            else:
+                outputs = []
+                for _last in last:
+                    if type(_last) == type({}):
+                        outputs.append(
+                            {
+                                key: value.clone().unsqueeze(0)
+                                for key, value in _last.items()
+                            }
+                        )
+                    else:
+                        outputs.append(_last.clone().unsqueeze(0))
+            flag = False
+        else:
+            if type(last) == type({}):
+                for key in last.keys():
+                    outputs[key] = torch.cat(
+                        [outputs[key], last[key].unsqueeze(0)], dim=0
+                    )
+            else:
+                for j in range(len(outputs)):
+                    if type(last[j]) == type({}):
+                        for key in last[j].keys():
+                            outputs[j][key] = torch.cat(
+                                [outputs[j][key], last[j][key].unsqueeze(0)], dim=0
+                            )
+                    else:
+                        outputs[j] = torch.cat(
+                            [outputs[j], last[j].unsqueeze(0)], dim=0
+                        )
+    if type(last) == type({}):
+        outputs = [outputs]
+    return outputs


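+# (static_scan applies `fn` step by step along the leading time axis of
+# `inputs`, threading the carry `last` and stacking per-step outputs -- a
+# loop-based stand-in for tf.scan / jax.lax.scan.)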
 # Original version
-#def static_scan2(fn, inputs, start, reverse=False):
+# def static_scan2(fn, inputs, start, reverse=False):
 #   last = start
 #   outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
 #   indices = range(inputs[0].shape[0])
@@ -673,119 +689,121 @@ def static_scan(fn, inputs, start):

 class Every:
+    def __init__(self, every):
+        self._every = every
+        self._last = None

-  def __init__(self, every):
-    self._every = every
-    self._last = None
+    def __call__(self, step):
+        if not self._every:
+            return 0
+        if self._last is None:
+            self._last = step
+            return 1
+        count = int((step - self._last) / self._every)
+        self._last += self._every * count
+        return count

-  def __call__(self, step):
-    if not self._every:
-      return 0
-    if self._last is None:
-      self._last = step
-      return 1
-    count = int((step - self._last) / self._every)
-    self._last += self._every * count
-    return count


 class Once:
+    def __init__(self):
+        self._once = True

-  def __init__(self):
-    self._once = True
-
-  def __call__(self):
-    if self._once:
-      self._once = False
-      return True
-    return False
+    def __call__(self):
+        if self._once:
+            self._once = False
+            return True
+        return False


 class Until:
+    def __init__(self, until):
+        self._until = until

-  def __init__(self, until):
-    self._until = until
-
-  def __call__(self, step):
-    if not self._until:
-      return True
-    return step < self._until
+    def __call__(self, step):
+        if not self._until:
+            return True
+        return step < self._until


 def schedule(string, step):
-  try:
-    return float(string)
-  except ValueError:
-    match = re.match(r'linear\((.+),(.+),(.+)\)', string)
-    if match:
-      initial, final, duration = [float(group) for group in match.groups()]
-      mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
-      return (1 - mix) * initial + mix * final
-    match = re.match(r'warmup\((.+),(.+)\)', string)
-    if match:
-      warmup, value = [float(group) for group in match.groups()]
-      scale = torch.clip(step / warmup, 0, 1)
-      return scale * value
-    match = re.match(r'exp\((.+),(.+),(.+)\)', string)
-    if match:
-      initial, final, halflife = [float(group) for group in match.groups()]
-      return (initial - final) * 0.5 ** (step / halflife) + final
-    match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
-    if match:
-      initial, final, duration = [float(group) for group in match.groups()]
-      mix = torch.clip(step / duration, 0, 1)
-      horizon = (1 - mix) * initial + mix * final
-      return 1 - 1 / horizon
-    raise NotImplementedError(string)
+    try:
+        return float(string)
+    except ValueError:
+        match = re.match(r"linear\((.+),(.+),(.+)\)", string)
+        if match:
+            initial, final, duration = [float(group) for group in match.groups()]
+            mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
+            return (1 - mix) * initial + mix * final
+        match = re.match(r"warmup\((.+),(.+)\)", string)
+        if match:
+            warmup, value = [float(group) for group in match.groups()]
+            # torch.clip requires a tensor input; mirror the linear branch.
+            scale = torch.clip(torch.Tensor([step / warmup]), 0, 1)[0]
+            return scale * value
+        match = re.match(r"exp\((.+),(.+),(.+)\)", string)
+        if match:
+            initial, final, halflife = [float(group) for group in match.groups()]
+            return (initial - final) * 0.5 ** (step / halflife) + final
+        match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
+        if match:
+            initial, final, duration = [float(group) for group in match.groups()]
+            mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
+            horizon = (1 - mix) * initial + mix * final
+            return 1 - 1 / horizon
+        raise NotImplementedError(string)


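+# Note: 0.87962566103423978 below is the standard deviation of a unit normal
+# truncated to [-2, 2]; dividing by it restores the requested variance after
+# truncation.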
 def weight_init(m):
     if isinstance(m, nn.Linear):
-        in_num = m.in_features
-        out_num = m.out_features
-        denoms = (in_num + out_num) / 2.0
-        scale = 1.0 / denoms
-        std = np.sqrt(scale) / 0.87962566103423978
-        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0.0)
+        in_num = m.in_features
+        out_num = m.out_features
+        denoms = (in_num + out_num) / 2.0
+        scale = 1.0 / denoms
+        std = np.sqrt(scale) / 0.87962566103423978
+        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+        if hasattr(m.bias, "data"):
+            m.bias.data.fill_(0.0)
     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
-        space = m.kernel_size[0] * m.kernel_size[1]
-        in_num = space * m.in_channels
-        out_num = space * m.out_channels
-        denoms = (in_num + out_num) / 2.0
-        scale = 1.0 / denoms
-        std = np.sqrt(scale) / 0.87962566103423978
-        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0.0)
+        space = m.kernel_size[0] * m.kernel_size[1]
+        in_num = space * m.in_channels
+        out_num = space * m.out_channels
+        denoms = (in_num + out_num) / 2.0
+        scale = 1.0 / denoms
+        std = np.sqrt(scale) / 0.87962566103423978
+        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+        if hasattr(m.bias, "data"):
+            m.bias.data.fill_(0.0)
     elif isinstance(m, nn.LayerNorm):
-        m.weight.data.fill_(1.0)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0.0)
+        m.weight.data.fill_(1.0)
+        if hasattr(m.bias, "data"):
+            m.bias.data.fill_(0.0)

+
 def uniform_weight_init(given_scale):
-  def f(m):
-    if isinstance(m, nn.Linear):
-      in_num = m.in_features
-      out_num = m.out_features
-      denoms = (in_num + out_num) / 2.0
-      scale = given_scale / denoms
-      limit = np.sqrt(3 * scale)
-      nn.init.uniform_(m.weight.data, a=-limit, b=limit)
-      if hasattr(m.bias, 'data'):
-        m.bias.data.fill_(0.0)
-    elif isinstance(m, nn.LayerNorm):
-      m.weight.data.fill_(1.0)
-      if hasattr(m.bias, 'data'):
-        m.bias.data.fill_(0.0)
-  return f
+    def f(m):
+        if isinstance(m, nn.Linear):
+            in_num = m.in_features
+            out_num = m.out_features
+            denoms = (in_num + out_num) / 2.0
+            scale = given_scale / denoms
+            limit = np.sqrt(3 * scale)
+            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
+            if hasattr(m.bias, "data"):
+                m.bias.data.fill_(0.0)
+        elif isinstance(m, nn.LayerNorm):
+            m.weight.data.fill_(1.0)
+            if hasattr(m.bias, "data"):
+                m.bias.data.fill_(0.0)

+    return f

+
 def tensorstats(tensor, prefix=None):
-  metrics = {
-      'mean': to_np(torch.mean(tensor)),
-      'std': to_np(torch.std(tensor)),
-      'min': to_np(torch.min(tensor)),
-      'max': to_np(torch.max(tensor)),
-  }
-  if prefix:
-    metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
-  return metrics
+    metrics = {
+        "mean": to_np(torch.mean(tensor)),
+        "std": to_np(torch.std(tensor)),
+        "min": to_np(torch.min(tensor)),
+        "max": to_np(torch.max(tensor)),
+    }
+    if prefix:
+        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
+    return metrics
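+
+
+# Usage sketch (hypothetical names): the returned dict plugs into Logger, e.g.
+#   for key, value in tensorstats(value_pred, "critic_value").items():
+#       logger.scalar(key, value)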