applied formatter to tools
This commit is contained in:
parent
55ed69bdf7
commit
fba87a33e0
268
tools.py
268
tools.py
@ -20,14 +20,16 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
to_np = lambda x: x.detach().cpu().numpy()
|
to_np = lambda x: x.detach().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
def symlog(x):
|
def symlog(x):
|
||||||
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
|
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
|
||||||
|
|
||||||
|
|
||||||
def symexp(x):
|
def symexp(x):
|
||||||
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
|
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
|
||||||
|
|
||||||
class RequiresGrad:
|
|
||||||
|
|
||||||
|
class RequiresGrad:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
||||||
@ -39,7 +41,6 @@ class RequiresGrad:
|
|||||||
|
|
||||||
|
|
||||||
class TimeRecording:
|
class TimeRecording:
|
||||||
|
|
||||||
def __init__(self, comment):
|
def __init__(self, comment):
|
||||||
self._comment = comment
|
self._comment = comment
|
||||||
|
|
||||||
@ -51,11 +52,10 @@ class TimeRecording:
|
|||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self._nd.record()
|
self._nd.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print(self._comment, self._st.elapsed_time(self._nd)/1000)
|
print(self._comment, self._st.elapsed_time(self._nd) / 1000)
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
class Logger:
|
||||||
|
|
||||||
def __init__(self, logdir, step):
|
def __init__(self, logdir, step):
|
||||||
self._logdir = logdir
|
self._logdir = logdir
|
||||||
self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
|
self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
|
||||||
@ -80,20 +80,20 @@ class Logger:
|
|||||||
step = self.step
|
step = self.step
|
||||||
scalars = list(self._scalars.items())
|
scalars = list(self._scalars.items())
|
||||||
if fps:
|
if fps:
|
||||||
scalars.append(('fps', self._compute_fps(step)))
|
scalars.append(("fps", self._compute_fps(step)))
|
||||||
print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
|
print(f"[{step}]", " / ".join(f"{k} {v:.1f}" for k, v in scalars))
|
||||||
with (self._logdir / 'metrics.jsonl').open('a') as f:
|
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||||
f.write(json.dumps({'step': step, ** dict(scalars)}) + '\n')
|
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||||
for name, value in scalars:
|
for name, value in scalars:
|
||||||
self._writer.add_scalar('scalars/' + name, value, step)
|
self._writer.add_scalar("scalars/" + name, value, step)
|
||||||
for name, value in self._images.items():
|
for name, value in self._images.items():
|
||||||
self._writer.add_image(name, value, step)
|
self._writer.add_image(name, value, step)
|
||||||
for name, value in self._videos.items():
|
for name, value in self._videos.items():
|
||||||
name = name if isinstance(name, str) else name.decode('utf-8')
|
name = name if isinstance(name, str) else name.decode("utf-8")
|
||||||
if np.issubdtype(value.dtype, np.floating):
|
if np.issubdtype(value.dtype, np.floating):
|
||||||
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
||||||
B, T, H, W, C = value.shape
|
B, T, H, W, C = value.shape
|
||||||
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
|
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
|
||||||
self._writer.add_video(name, value, step, 16)
|
self._writer.add_video(name, value, step, 16)
|
||||||
|
|
||||||
self._writer.flush()
|
self._writer.flush()
|
||||||
@ -113,13 +113,13 @@ class Logger:
|
|||||||
return steps / duration
|
return steps / duration
|
||||||
|
|
||||||
def offline_scalar(self, name, value, step):
|
def offline_scalar(self, name, value, step):
|
||||||
self._writer.add_scalar('scalars/'+name, value, step)
|
self._writer.add_scalar("scalars/" + name, value, step)
|
||||||
|
|
||||||
def offline_video(self, name, value, step):
|
def offline_video(self, name, value, step):
|
||||||
if np.issubdtype(value.dtype, np.floating):
|
if np.issubdtype(value.dtype, np.floating):
|
||||||
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
||||||
B, T, H, W, C = value.shape
|
B, T, H, W, C = value.shape
|
||||||
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
|
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))
|
||||||
self._writer.add_video(name, value, step, 16)
|
self._writer.add_video(name, value, step, 16)
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
length = np.zeros(len(envs), np.int32)
|
length = np.zeros(len(envs), np.int32)
|
||||||
obs = [None] * len(envs)
|
obs = [None] * len(envs)
|
||||||
agent_state = None
|
agent_state = None
|
||||||
reward = [0]*len(envs)
|
reward = [0] * len(envs)
|
||||||
else:
|
else:
|
||||||
step, episode, done, length, obs, agent_state, reward = state
|
step, episode, done, length, obs, agent_state, reward = state
|
||||||
while (steps and step < steps) or (episodes and episode < episodes):
|
while (steps and step < steps) or (episodes and episode < episodes):
|
||||||
@ -141,14 +141,15 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
results = [envs[i].reset() for i in indices]
|
results = [envs[i].reset() for i in indices]
|
||||||
for index, result in zip(indices, results):
|
for index, result in zip(indices, results):
|
||||||
obs[index] = result
|
obs[index] = result
|
||||||
reward = [reward[i]*(1-done[i]) for i in range(len(envs))]
|
reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
|
||||||
# Step agents.
|
# Step agents.
|
||||||
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
|
||||||
action, agent_state = agent(obs, done, agent_state, reward)
|
action, agent_state = agent(obs, done, agent_state, reward)
|
||||||
if isinstance(action, dict):
|
if isinstance(action, dict):
|
||||||
action = [
|
action = [
|
||||||
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
{k: np.array(action[k][i].detach().cpu()) for k in action}
|
||||||
for i in range(len(envs))]
|
for i in range(len(envs))
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
action = np.array(action)
|
action = np.array(action)
|
||||||
assert len(action) == len(envs)
|
assert len(action) == len(envs)
|
||||||
@ -161,7 +162,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
episode += int(done.sum())
|
episode += int(done.sum())
|
||||||
length += 1
|
length += 1
|
||||||
step += (done * length).sum()
|
step += (done * length).sum()
|
||||||
length *= (1 - done)
|
length *= 1 - done
|
||||||
|
|
||||||
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
||||||
|
|
||||||
@ -169,16 +170,16 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
def save_episodes(directory, episodes):
|
def save_episodes(directory, episodes):
|
||||||
directory = pathlib.Path(directory).expanduser()
|
directory = pathlib.Path(directory).expanduser()
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
|
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||||
filenames = []
|
filenames = []
|
||||||
for episode in episodes:
|
for episode in episodes:
|
||||||
identifier = str(uuid.uuid4().hex)
|
identifier = str(uuid.uuid4().hex)
|
||||||
length = len(episode['reward'])
|
length = len(episode["reward"])
|
||||||
filename = directory / f'{timestamp}-{identifier}-{length}.npz'
|
filename = directory / f"{timestamp}-{identifier}-{length}.npz"
|
||||||
with io.BytesIO() as f1:
|
with io.BytesIO() as f1:
|
||||||
np.savez_compressed(f1, **episode)
|
np.savez_compressed(f1, **episode)
|
||||||
f1.seek(0)
|
f1.seek(0)
|
||||||
with filename.open('wb') as f2:
|
with filename.open("wb") as f2:
|
||||||
f2.write(f1.read())
|
f2.write(f1.read())
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
return filenames
|
return filenames
|
||||||
@ -206,13 +207,13 @@ def sample_episodes(episodes, length=None, balance=False, seed=0):
|
|||||||
total = len(next(iter(episode.values())))
|
total = len(next(iter(episode.values())))
|
||||||
available = total - length
|
available = total - length
|
||||||
if available < 1:
|
if available < 1:
|
||||||
print(f'Skipped short episode of length {available}.')
|
print(f"Skipped short episode of length {available}.")
|
||||||
continue
|
continue
|
||||||
if balance:
|
if balance:
|
||||||
index = min(random.randint(0, total), available)
|
index = min(random.randint(0, total), available)
|
||||||
else:
|
else:
|
||||||
index = int(random.randint(0, available + 1))
|
index = int(random.randint(0, available + 1))
|
||||||
episode = {k: v[index: index + length] for k, v in episode.items()}
|
episode = {k: v[index : index + length] for k, v in episode.items()}
|
||||||
yield episode
|
yield episode
|
||||||
|
|
||||||
|
|
||||||
@ -221,43 +222,42 @@ def load_episodes(directory, limit=None, reverse=True):
|
|||||||
episodes = collections.OrderedDict()
|
episodes = collections.OrderedDict()
|
||||||
total = 0
|
total = 0
|
||||||
if reverse:
|
if reverse:
|
||||||
for filename in reversed(sorted(directory.glob('*.npz'))):
|
for filename in reversed(sorted(directory.glob("*.npz"))):
|
||||||
try:
|
try:
|
||||||
with filename.open('rb') as f:
|
with filename.open("rb") as f:
|
||||||
episode = np.load(f)
|
episode = np.load(f)
|
||||||
episode = {k: episode[k] for k in episode.keys()}
|
episode = {k: episode[k] for k in episode.keys()}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Could not load episode: {e}')
|
print(f"Could not load episode: {e}")
|
||||||
continue
|
continue
|
||||||
episodes[str(filename)] = episode
|
episodes[str(filename)] = episode
|
||||||
total += len(episode['reward']) - 1
|
total += len(episode["reward"]) - 1
|
||||||
if limit and total >= limit:
|
if limit and total >= limit:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for filename in sorted(directory.glob('*.npz')):
|
for filename in sorted(directory.glob("*.npz")):
|
||||||
try:
|
try:
|
||||||
with filename.open('rb') as f:
|
with filename.open("rb") as f:
|
||||||
episode = np.load(f)
|
episode = np.load(f)
|
||||||
episode = {k: episode[k] for k in episode.keys()}
|
episode = {k: episode[k] for k in episode.keys()}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Could not load episode: {e}')
|
print(f"Could not load episode: {e}")
|
||||||
continue
|
continue
|
||||||
episodes[str(filename)] = episode
|
episodes[str(filename)] = episode
|
||||||
total += len(episode['reward']) - 1
|
total += len(episode["reward"]) - 1
|
||||||
if limit and total >= limit:
|
if limit and total >= limit:
|
||||||
break
|
break
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
|
|
||||||
class SampleDist:
|
class SampleDist:
|
||||||
|
|
||||||
def __init__(self, dist, samples=100):
|
def __init__(self, dist, samples=100):
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
self._samples = samples
|
self._samples = samples
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return 'SampleDist'
|
return "SampleDist"
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self._dist, name)
|
return getattr(self._dist, name)
|
||||||
@ -278,23 +278,24 @@ class SampleDist:
|
|||||||
|
|
||||||
|
|
||||||
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
||||||
|
|
||||||
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
|
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
|
||||||
if logits is not None and unimix_ratio > 0.0:
|
if logits is not None and unimix_ratio > 0.0:
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
|
probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
|
||||||
logits = torch.log(probs)
|
logits = torch.log(probs)
|
||||||
super().__init__(logits=logits, probs=None)
|
super().__init__(logits=logits, probs=None)
|
||||||
else:
|
else:
|
||||||
super().__init__(logits=logits, probs=probs)
|
super().__init__(logits=logits, probs=probs)
|
||||||
|
|
||||||
def mode(self):
|
def mode(self):
|
||||||
_mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
|
_mode = F.one_hot(
|
||||||
|
torch.argmax(super().logits, axis=-1), super().logits.shape[-1]
|
||||||
|
)
|
||||||
return _mode.detach() + super().logits - super().logits.detach()
|
return _mode.detach() + super().logits - super().logits.detach()
|
||||||
|
|
||||||
def sample(self, sample_shape=(), seed=None):
|
def sample(self, sample_shape=(), seed=None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
raise ValueError('need to check')
|
raise ValueError("need to check")
|
||||||
sample = super().sample(sample_shape)
|
sample = super().sample(sample_shape)
|
||||||
probs = super().probs
|
probs = super().probs
|
||||||
while len(probs.shape) < len(sample.shape):
|
while len(probs.shape) < len(sample.shape):
|
||||||
@ -303,9 +304,8 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
class TwoHotDistSymlog():
|
class TwoHotDistSymlog:
|
||||||
|
def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
|
||||||
def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
|
|
||||||
self.logits = logits
|
self.logits = logits
|
||||||
self.probs = torch.softmax(logits, -1)
|
self.probs = torch.softmax(logits, -1)
|
||||||
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
||||||
@ -324,11 +324,13 @@ class TwoHotDistSymlog():
|
|||||||
def log_prob(self, x):
|
def log_prob(self, x):
|
||||||
x = symlog(x)
|
x = symlog(x)
|
||||||
# x(time, batch, 1)
|
# x(time, batch, 1)
|
||||||
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) -1
|
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
|
||||||
above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
|
above = len(self.buckets) - torch.sum(
|
||||||
below = torch.clip(below, 0, len(self.buckets)-1)
|
(self.buckets > x[..., None]).to(torch.int32), dim=-1
|
||||||
above = torch.clip(above, 0, len(self.buckets)-1)
|
)
|
||||||
equal = (below == above)
|
below = torch.clip(below, 0, len(self.buckets) - 1)
|
||||||
|
above = torch.clip(above, 0, len(self.buckets) - 1)
|
||||||
|
equal = below == above
|
||||||
|
|
||||||
dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
|
dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
|
||||||
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
|
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
|
||||||
@ -336,8 +338,9 @@ class TwoHotDistSymlog():
|
|||||||
weight_below = dist_to_above / total
|
weight_below = dist_to_above / total
|
||||||
weight_above = dist_to_below / total
|
weight_above = dist_to_below / total
|
||||||
target = (
|
target = (
|
||||||
F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
|
F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
|
||||||
F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
|
+ F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
|
||||||
|
)
|
||||||
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
|
log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
|
||||||
target = target.squeeze(-2)
|
target = target.squeeze(-2)
|
||||||
|
|
||||||
@ -347,8 +350,11 @@ class TwoHotDistSymlog():
|
|||||||
log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
|
log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
|
||||||
return (target * log_pred).sum(-1)
|
return (target * log_pred).sum(-1)
|
||||||
|
|
||||||
class SymlogDist():
|
|
||||||
def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
|
class SymlogDist:
|
||||||
|
def __init__(
|
||||||
|
self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
|
||||||
|
):
|
||||||
self._mode = mode
|
self._mode = mode
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
self._agg = agg
|
self._agg = agg
|
||||||
@ -363,24 +369,24 @@ class SymlogDist():
|
|||||||
|
|
||||||
def log_prob(self, value):
|
def log_prob(self, value):
|
||||||
assert self._mode.shape == value.shape
|
assert self._mode.shape == value.shape
|
||||||
if self._dist == 'mse':
|
if self._dist == "mse":
|
||||||
distance = (self._mode - symlog(value)) ** 2.0
|
distance = (self._mode - symlog(value)) ** 2.0
|
||||||
distance = torch.where(distance < self._tol, 0, distance)
|
distance = torch.where(distance < self._tol, 0, distance)
|
||||||
elif self._dist == 'abs':
|
elif self._dist == "abs":
|
||||||
distance = torch.abs(self._mode - symlog(value))
|
distance = torch.abs(self._mode - symlog(value))
|
||||||
distance = torch.where(distance < self._tol, 0, distance)
|
distance = torch.where(distance < self._tol, 0, distance)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self._dist)
|
raise NotImplementedError(self._dist)
|
||||||
if self._agg == 'mean':
|
if self._agg == "mean":
|
||||||
loss = distance.mean(self._dim_to_reduce)
|
loss = distance.mean(self._dim_to_reduce)
|
||||||
elif self._agg == 'sum':
|
elif self._agg == "sum":
|
||||||
loss = distance.sum(self._dim_to_reduce)
|
loss = distance.sum(self._dim_to_reduce)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self._agg)
|
raise NotImplementedError(self._agg)
|
||||||
return -loss
|
return -loss
|
||||||
|
|
||||||
class ContDist:
|
|
||||||
|
|
||||||
|
class ContDist:
|
||||||
def __init__(self, dist=None):
|
def __init__(self, dist=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
@ -403,7 +409,6 @@ class ContDist:
|
|||||||
|
|
||||||
|
|
||||||
class Bernoulli:
|
class Bernoulli:
|
||||||
|
|
||||||
def __init__(self, dist=None):
|
def __init__(self, dist=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
@ -417,7 +422,7 @@ class Bernoulli:
|
|||||||
|
|
||||||
def mode(self):
|
def mode(self):
|
||||||
_mode = torch.round(self._dist.mean)
|
_mode = torch.round(self._dist.mean)
|
||||||
return _mode.detach() +self._dist.mean - self._dist.mean.detach()
|
return _mode.detach() + self._dist.mean - self._dist.mean.detach()
|
||||||
|
|
||||||
def sample(self, sample_shape=()):
|
def sample(self, sample_shape=()):
|
||||||
return self._dist.rsample(sample_shape)
|
return self._dist.rsample(sample_shape)
|
||||||
@ -427,25 +432,25 @@ class Bernoulli:
|
|||||||
log_probs0 = -F.softplus(_logits)
|
log_probs0 = -F.softplus(_logits)
|
||||||
log_probs1 = -F.softplus(-_logits)
|
log_probs1 = -F.softplus(-_logits)
|
||||||
|
|
||||||
return log_probs0 * (1-x) + log_probs1 * x
|
return log_probs0 * (1 - x) + log_probs1 * x
|
||||||
|
|
||||||
|
|
||||||
class UnnormalizedHuber(torchd.normal.Normal):
|
class UnnormalizedHuber(torchd.normal.Normal):
|
||||||
|
|
||||||
def __init__(self, loc, scale, threshold=1, **kwargs):
|
def __init__(self, loc, scale, threshold=1, **kwargs):
|
||||||
super().__init__(loc, scale, **kwargs)
|
super().__init__(loc, scale, **kwargs)
|
||||||
self._threshold = threshold
|
self._threshold = threshold
|
||||||
|
|
||||||
def log_prob(self, event):
|
def log_prob(self, event):
|
||||||
return -(torch.sqrt(
|
return -(
|
||||||
(event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)
|
torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
|
||||||
|
- self._threshold
|
||||||
|
)
|
||||||
|
|
||||||
def mode(self):
|
def mode(self):
|
||||||
return self.mean
|
return self.mean
|
||||||
|
|
||||||
|
|
||||||
class SafeTruncatedNormal(torchd.normal.Normal):
|
class SafeTruncatedNormal(torchd.normal.Normal):
|
||||||
|
|
||||||
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
|
def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
|
||||||
super().__init__(loc, scale)
|
super().__init__(loc, scale)
|
||||||
self._low = low
|
self._low = low
|
||||||
@ -456,8 +461,7 @@ class SafeTruncatedNormal(torchd.normal.Normal):
|
|||||||
def sample(self, sample_shape):
|
def sample(self, sample_shape):
|
||||||
event = super().sample(sample_shape)
|
event = super().sample(sample_shape)
|
||||||
if self._clip:
|
if self._clip:
|
||||||
clipped = torch.clip(event, self._low + self._clip,
|
clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
|
||||||
self._high - self._clip)
|
|
||||||
event = event - event.detach() + clipped.detach()
|
event = event - event.detach() + clipped.detach()
|
||||||
if self._mult:
|
if self._mult:
|
||||||
event *= self._mult
|
event *= self._mult
|
||||||
@ -465,8 +469,7 @@ class SafeTruncatedNormal(torchd.normal.Normal):
|
|||||||
|
|
||||||
|
|
||||||
class TanhBijector(torchd.Transform):
|
class TanhBijector(torchd.Transform):
|
||||||
|
def __init__(self, validate_args=False, name="tanh"):
|
||||||
def __init__(self, validate_args=False, name='tanh'):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
@ -474,8 +477,8 @@ class TanhBijector(torchd.Transform):
|
|||||||
|
|
||||||
def _inverse(self, y):
|
def _inverse(self, y):
|
||||||
y = torch.where(
|
y = torch.where(
|
||||||
(torch.abs(y) <= 1.),
|
(torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
|
||||||
torch.clamp(y, -0.99999997, 0.99999997), y)
|
)
|
||||||
y = torch.atanh(y)
|
y = torch.atanh(y)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -504,16 +507,15 @@ def static_scan_for_lambda_return(fn, inputs, start):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def lambda_return(
|
def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
|
||||||
reward, value, pcont, bootstrap, lambda_, axis):
|
|
||||||
# Setting lambda=1 gives a discounted Monte Carlo return.
|
# Setting lambda=1 gives a discounted Monte Carlo return.
|
||||||
# Setting lambda=0 gives a fixed 1-step return.
|
# Setting lambda=0 gives a fixed 1-step return.
|
||||||
#assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
|
# assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
|
||||||
assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
|
assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
|
||||||
if isinstance(pcont, (int, float)):
|
if isinstance(pcont, (int, float)):
|
||||||
pcont = pcont * torch.ones_like(reward)
|
pcont = pcont * torch.ones_like(reward)
|
||||||
dims = list(range(len(reward.shape)))
|
dims = list(range(len(reward.shape)))
|
||||||
dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
|
dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :]
|
||||||
if axis != 0:
|
if axis != 0:
|
||||||
reward = reward.permute(dims)
|
reward = reward.permute(dims)
|
||||||
value = value.permute(dims)
|
value = value.permute(dims)
|
||||||
@ -522,23 +524,31 @@ def lambda_return(
|
|||||||
bootstrap = torch.zeros_like(value[-1])
|
bootstrap = torch.zeros_like(value[-1])
|
||||||
next_values = torch.cat([value[1:], bootstrap[None]], 0)
|
next_values = torch.cat([value[1:], bootstrap[None]], 0)
|
||||||
inputs = reward + pcont * next_values * (1 - lambda_)
|
inputs = reward + pcont * next_values * (1 - lambda_)
|
||||||
#returns = static_scan(
|
# returns = static_scan(
|
||||||
# lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
|
# lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
|
||||||
# (inputs, pcont), bootstrap, reverse=True)
|
# (inputs, pcont), bootstrap, reverse=True)
|
||||||
# reimplement to optimize performance
|
# reimplement to optimize performance
|
||||||
returns = static_scan_for_lambda_return(
|
returns = static_scan_for_lambda_return(
|
||||||
lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
|
lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, (inputs, pcont), bootstrap
|
||||||
(inputs, pcont), bootstrap)
|
)
|
||||||
if axis != 0:
|
if axis != 0:
|
||||||
returns = returns.permute(dims)
|
returns = returns.permute(dims)
|
||||||
return returns
|
return returns
|
||||||
|
|
||||||
|
|
||||||
class Optimizer():
|
class Optimizer:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
|
self,
|
||||||
opt='adam', use_amp=False):
|
name,
|
||||||
|
parameters,
|
||||||
|
lr,
|
||||||
|
eps=1e-4,
|
||||||
|
clip=None,
|
||||||
|
wd=None,
|
||||||
|
wd_pattern=r".*",
|
||||||
|
opt="adam",
|
||||||
|
use_amp=False,
|
||||||
|
):
|
||||||
assert 0 <= wd < 1
|
assert 0 <= wd < 1
|
||||||
assert not clip or 1 <= clip
|
assert not clip or 1 <= clip
|
||||||
self._name = name
|
self._name = name
|
||||||
@ -547,41 +557,33 @@ class Optimizer():
|
|||||||
self._wd = wd
|
self._wd = wd
|
||||||
self._wd_pattern = wd_pattern
|
self._wd_pattern = wd_pattern
|
||||||
self._opt = {
|
self._opt = {
|
||||||
'adam': lambda: torch.optim.Adam(parameters,
|
"adam": lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
|
||||||
lr=lr,
|
"nadam": lambda: NotImplemented(f"{opt} is not implemented"),
|
||||||
eps=eps),
|
"adamax": lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
|
||||||
'nadam': lambda: NotImplemented(
|
"sgd": lambda: torch.optim.SGD(parameters, lr=lr),
|
||||||
f'{opt} is not implemented'),
|
"momentum": lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
|
||||||
'adamax': lambda: torch.optim.Adamax(parameters,
|
|
||||||
lr=lr,
|
|
||||||
eps=eps),
|
|
||||||
'sgd': lambda: torch.optim.SGD(parameters,
|
|
||||||
lr=lr),
|
|
||||||
'momentum': lambda: torch.optim.SGD(parameters,
|
|
||||||
lr=lr,
|
|
||||||
momentum=0.9),
|
|
||||||
}[opt]()
|
}[opt]()
|
||||||
self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
|
self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
|
||||||
|
|
||||||
def __call__(self, loss, params, retain_graph=False):
|
def __call__(self, loss, params, retain_graph=False):
|
||||||
assert len(loss.shape) == 0, loss.shape
|
assert len(loss.shape) == 0, loss.shape
|
||||||
metrics = {}
|
metrics = {}
|
||||||
metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
|
metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
|
||||||
self._scaler.scale(loss).backward()
|
self._scaler.scale(loss).backward()
|
||||||
self._scaler.unscale_(self._opt)
|
self._scaler.unscale_(self._opt)
|
||||||
#loss.backward(retain_graph=retain_graph)
|
# loss.backward(retain_graph=retain_graph)
|
||||||
norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
|
norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
|
||||||
if self._wd:
|
if self._wd:
|
||||||
self._apply_weight_decay(params)
|
self._apply_weight_decay(params)
|
||||||
self._scaler.step(self._opt)
|
self._scaler.step(self._opt)
|
||||||
self._scaler.update()
|
self._scaler.update()
|
||||||
#self._opt.step()
|
# self._opt.step()
|
||||||
self._opt.zero_grad()
|
self._opt.zero_grad()
|
||||||
metrics[f'{self._name}_grad_norm'] = norm.item()
|
metrics[f"{self._name}_grad_norm"] = norm.item()
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def _apply_weight_decay(self, varibs):
|
def _apply_weight_decay(self, varibs):
|
||||||
nontrivial = (self._wd_pattern != r'.*')
|
nontrivial = self._wd_pattern != r".*"
|
||||||
if nontrivial:
|
if nontrivial:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
for var in varibs:
|
for var in varibs:
|
||||||
@ -593,16 +595,18 @@ def args_type(default):
|
|||||||
if default is None:
|
if default is None:
|
||||||
return x
|
return x
|
||||||
if isinstance(default, bool):
|
if isinstance(default, bool):
|
||||||
return bool(['False', 'True'].index(x))
|
return bool(["False", "True"].index(x))
|
||||||
if isinstance(default, int):
|
if isinstance(default, int):
|
||||||
return float(x) if ('e' in x or '.' in x) else int(x)
|
return float(x) if ("e" in x or "." in x) else int(x)
|
||||||
if isinstance(default, (list, tuple)):
|
if isinstance(default, (list, tuple)):
|
||||||
return tuple(args_type(default[0])(y) for y in x.split(','))
|
return tuple(args_type(default[0])(y) for y in x.split(","))
|
||||||
return type(default)(x)
|
return type(default)(x)
|
||||||
|
|
||||||
def parse_object(x):
|
def parse_object(x):
|
||||||
if isinstance(default, (list, tuple)):
|
if isinstance(default, (list, tuple)):
|
||||||
return tuple(x)
|
return tuple(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
|
return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x)
|
||||||
|
|
||||||
|
|
||||||
@ -615,34 +619,46 @@ def static_scan(fn, inputs, start):
|
|||||||
last = fn(last, *inp(index))
|
last = fn(last, *inp(index))
|
||||||
if flag:
|
if flag:
|
||||||
if type(last) == type({}):
|
if type(last) == type({}):
|
||||||
outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
|
outputs = {
|
||||||
|
key: value.clone().unsqueeze(0) for key, value in last.items()
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
outputs = []
|
outputs = []
|
||||||
for _last in last:
|
for _last in last:
|
||||||
if type(_last) == type({}):
|
if type(_last) == type({}):
|
||||||
outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
|
outputs.append(
|
||||||
|
{
|
||||||
|
key: value.clone().unsqueeze(0)
|
||||||
|
for key, value in _last.items()
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
outputs.append(_last.clone().unsqueeze(0))
|
outputs.append(_last.clone().unsqueeze(0))
|
||||||
flag = False
|
flag = False
|
||||||
else:
|
else:
|
||||||
if type(last) == type({}):
|
if type(last) == type({}):
|
||||||
for key in last.keys():
|
for key in last.keys():
|
||||||
outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
|
outputs[key] = torch.cat(
|
||||||
|
[outputs[key], last[key].unsqueeze(0)], dim=0
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for j in range(len(outputs)):
|
for j in range(len(outputs)):
|
||||||
if type(last[j]) == type({}):
|
if type(last[j]) == type({}):
|
||||||
for key in last[j].keys():
|
for key in last[j].keys():
|
||||||
outputs[j][key] = torch.cat([outputs[j][key],
|
outputs[j][key] = torch.cat(
|
||||||
last[j][key].unsqueeze(0)], dim=0)
|
[outputs[j][key], last[j][key].unsqueeze(0)], dim=0
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
|
outputs[j] = torch.cat(
|
||||||
|
[outputs[j], last[j].unsqueeze(0)], dim=0
|
||||||
|
)
|
||||||
if type(last) == type({}):
|
if type(last) == type({}):
|
||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Original version
|
# Original version
|
||||||
#def static_scan2(fn, inputs, start, reverse=False):
|
# def static_scan2(fn, inputs, start, reverse=False):
|
||||||
# last = start
|
# last = start
|
||||||
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
|
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
|
||||||
# indices = range(inputs[0].shape[0])
|
# indices = range(inputs[0].shape[0])
|
||||||
@ -673,7 +689,6 @@ def static_scan(fn, inputs, start):
|
|||||||
|
|
||||||
|
|
||||||
class Every:
|
class Every:
|
||||||
|
|
||||||
def __init__(self, every):
|
def __init__(self, every):
|
||||||
self._every = every
|
self._every = every
|
||||||
self._last = None
|
self._last = None
|
||||||
@ -688,8 +703,8 @@ class Every:
|
|||||||
self._last += self._every * count
|
self._last += self._every * count
|
||||||
return count
|
return count
|
||||||
|
|
||||||
class Once:
|
|
||||||
|
|
||||||
|
class Once:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._once = True
|
self._once = True
|
||||||
|
|
||||||
@ -701,7 +716,6 @@ class Once:
|
|||||||
|
|
||||||
|
|
||||||
class Until:
|
class Until:
|
||||||
|
|
||||||
def __init__(self, until):
|
def __init__(self, until):
|
||||||
self._until = until
|
self._until = until
|
||||||
|
|
||||||
@ -715,21 +729,21 @@ def schedule(string, step):
|
|||||||
try:
|
try:
|
||||||
return float(string)
|
return float(string)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
match = re.match(r'linear\((.+),(.+),(.+)\)', string)
|
match = re.match(r"linear\((.+),(.+),(.+)\)", string)
|
||||||
if match:
|
if match:
|
||||||
initial, final, duration = [float(group) for group in match.groups()]
|
initial, final, duration = [float(group) for group in match.groups()]
|
||||||
mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
|
mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
|
||||||
return (1 - mix) * initial + mix * final
|
return (1 - mix) * initial + mix * final
|
||||||
match = re.match(r'warmup\((.+),(.+)\)', string)
|
match = re.match(r"warmup\((.+),(.+)\)", string)
|
||||||
if match:
|
if match:
|
||||||
warmup, value = [float(group) for group in match.groups()]
|
warmup, value = [float(group) for group in match.groups()]
|
||||||
scale = torch.clip(step / warmup, 0, 1)
|
scale = torch.clip(step / warmup, 0, 1)
|
||||||
return scale * value
|
return scale * value
|
||||||
match = re.match(r'exp\((.+),(.+),(.+)\)', string)
|
match = re.match(r"exp\((.+),(.+),(.+)\)", string)
|
||||||
if match:
|
if match:
|
||||||
initial, final, halflife = [float(group) for group in match.groups()]
|
initial, final, halflife = [float(group) for group in match.groups()]
|
||||||
return (initial - final) * 0.5 ** (step / halflife) + final
|
return (initial - final) * 0.5 ** (step / halflife) + final
|
||||||
match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
|
match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
|
||||||
if match:
|
if match:
|
||||||
initial, final, duration = [float(group) for group in match.groups()]
|
initial, final, duration = [float(group) for group in match.groups()]
|
||||||
mix = torch.clip(step / duration, 0, 1)
|
mix = torch.clip(step / duration, 0, 1)
|
||||||
@ -737,6 +751,7 @@ def schedule(string, step):
|
|||||||
return 1 - 1 / horizon
|
return 1 - 1 / horizon
|
||||||
raise NotImplementedError(string)
|
raise NotImplementedError(string)
|
||||||
|
|
||||||
|
|
||||||
def weight_init(m):
|
def weight_init(m):
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
in_num = m.in_features
|
in_num = m.in_features
|
||||||
@ -744,8 +759,8 @@ def weight_init(m):
|
|||||||
denoms = (in_num + out_num) / 2.0
|
denoms = (in_num + out_num) / 2.0
|
||||||
scale = 1.0 / denoms
|
scale = 1.0 / denoms
|
||||||
std = np.sqrt(scale) / 0.87962566103423978
|
std = np.sqrt(scale) / 0.87962566103423978
|
||||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
|
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
||||||
if hasattr(m.bias, 'data'):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||||
space = m.kernel_size[0] * m.kernel_size[1]
|
space = m.kernel_size[0] * m.kernel_size[1]
|
||||||
@ -754,14 +769,15 @@ def weight_init(m):
|
|||||||
denoms = (in_num + out_num) / 2.0
|
denoms = (in_num + out_num) / 2.0
|
||||||
scale = 1.0 / denoms
|
scale = 1.0 / denoms
|
||||||
std = np.sqrt(scale) / 0.87962566103423978
|
std = np.sqrt(scale) / 0.87962566103423978
|
||||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0)
|
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
||||||
if hasattr(m.bias, 'data'):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
m.weight.data.fill_(1.0)
|
m.weight.data.fill_(1.0)
|
||||||
if hasattr(m.bias, 'data'):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
|
|
||||||
|
|
||||||
def uniform_weight_init(given_scale):
|
def uniform_weight_init(given_scale):
|
||||||
def f(m):
|
def f(m):
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
@ -771,21 +787,23 @@ def uniform_weight_init(given_scale):
|
|||||||
scale = given_scale / denoms
|
scale = given_scale / denoms
|
||||||
limit = np.sqrt(3 * scale)
|
limit = np.sqrt(3 * scale)
|
||||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||||
if hasattr(m.bias, 'data'):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
m.weight.data.fill_(1.0)
|
m.weight.data.fill_(1.0)
|
||||||
if hasattr(m.bias, 'data'):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def tensorstats(tensor, prefix=None):
|
def tensorstats(tensor, prefix=None):
|
||||||
metrics = {
|
metrics = {
|
||||||
'mean': to_np(torch.mean(tensor)),
|
"mean": to_np(torch.mean(tensor)),
|
||||||
'std': to_np(torch.std(tensor)),
|
"std": to_np(torch.std(tensor)),
|
||||||
'min': to_np(torch.min(tensor)),
|
"min": to_np(torch.min(tensor)),
|
||||||
'max': to_np(torch.max(tensor)),
|
"max": to_np(torch.max(tensor)),
|
||||||
}
|
}
|
||||||
if prefix:
|
if prefix:
|
||||||
metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
|
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
|
||||||
return metrics
|
return metrics
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user