import datetime
import io
import json
import pathlib
import pickle
import re
import time
import uuid

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter


def to_np(x):
  """Detach a tensor from the graph, move it to host memory and return a numpy array."""
  host = x.detach().cpu()
  return host.numpy()

def symlog(x):
  """Symmetric log squash: sign(x) * log(|x| + 1). Inverse of ``symexp``."""
  magnitude = x.abs() + 1.0
  return x.sign() * magnitude.log()

def symexp(x):
  """Inverse of ``symlog``: sign(x) * (exp(|x|) - 1)."""
  grown = x.abs().exp() - 1.0
  return x.sign() * grown

class RequiresGrad:
  """Context manager that turns on ``requires_grad`` for a module's parameters.

  Gradients are enabled on entry and unconditionally disabled on exit (the
  previous state is not restored). Exceptions raised in the block propagate.
  """

  def __init__(self, model):
    self._model = model

  def _toggle(self, flag):
    self._model.requires_grad_(requires_grad=flag)

  def __enter__(self):
    self._toggle(True)

  def __exit__(self, *args):
    self._toggle(False)


class TimeRecording:
  """Print the CUDA-event elapsed time (in seconds) of a ``with`` block.

  Requires a CUDA device: timing uses ``torch.cuda.Event`` and the exit path
  calls ``torch.cuda.synchronize()`` before reading the elapsed time.
  """

  def __init__(self, comment):
    self._comment = comment

  def __enter__(self):
    self._st, self._nd = (
        torch.cuda.Event(enable_timing=True) for _ in range(2))
    self._st.record()

  def __exit__(self, *args):
    self._nd.record()
    torch.cuda.synchronize()
    # elapsed_time is reported in milliseconds; convert to seconds.
    seconds = self._st.elapsed_time(self._nd) / 1000
    print(self._comment, seconds)


class Logger:
  """Buffer scalars, images and videos and flush them on ``write()``.

  ``write`` prints the scalars, appends them to ``<logdir>/metrics.jsonl`` and
  sends everything to TensorBoard at the current ``self.step``, then clears
  all buffers. ``offline_*`` methods write straight to TensorBoard.
  """

  def __init__(self, logdir, step):
    # logdir is combined with the `/` operator in write(), so it must support
    # pathlib-style joining (e.g. a pathlib.Path).
    self._logdir = logdir
    self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000)
    self._last_step = None
    self._last_time = None
    self._scalars = {}
    self._images = {}
    self._videos = {}
    self.step = step

  def scalar(self, name, value):
    self._scalars[name] = float(value)

  def image(self, name, value):
    # Passed unchanged to SummaryWriter.add_image with its default dataformats.
    self._images[name] = np.array(value)

  def video(self, name, value):
    # Expected layout is (B, T, H, W, C); see _video_grid.
    self._videos[name] = np.array(value)

  def write(self, fps=False):
    """Flush all buffered values at ``self.step`` and reset the buffers.

    Args:
      fps: if true, also log an 'fps' scalar computed since the previous call
        that requested it (0 on the first such call).
    """
    scalars = list(self._scalars.items())
    if fps:
      scalars.append(('fps', self._compute_fps(self.step)))
    print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
    with (self._logdir / 'metrics.jsonl').open('a') as f:
      f.write(json.dumps({'step': self.step, **dict(scalars)}) + '\n')
    for name, value in scalars:
      self._writer.add_scalar('scalars/' + name, value, self.step)
    for name, value in self._images.items():
      self._writer.add_image(name, value, self.step)
    for name, value in self._videos.items():
      self._writer.add_video(
          self._to_str(name), self._video_grid(value), self.step, 16)

    self._writer.flush()
    self._scalars = {}
    self._images = {}
    self._videos = {}

  def _compute_fps(self, step):
    """Return steps per second since the previous call (0 on the first call)."""
    now = time.time()
    if self._last_step is None:
      self._last_time = now
      self._last_step = step
      return 0
    steps = step - self._last_step
    duration = now - self._last_time
    self._last_time = now
    self._last_step = step
    # Guard against a zero interval from a coarse clock.
    return steps / duration if duration > 0 else 0

  def offline_scalar(self, name, value, step):
    self._writer.add_scalar('scalars/' + name, value, step)

  def offline_video(self, name, value, step):
    self._writer.add_video(self._to_str(name), self._video_grid(value), step, 16)

  @staticmethod
  def _to_str(name):
    """Decode a bytes tag name to str (UTF-8); str names pass through."""
    return name if isinstance(name, str) else name.decode('utf-8')

  @staticmethod
  def _video_grid(value):
    """Tile a (B, T, H, W, C) video batch side by side into (1, T, C, H, B*W).

    Floating-point input is scaled by 255, clipped to [0, 255] and cast to
    uint8; integer input is used as-is.
    """
    if np.issubdtype(value.dtype, np.floating):
      value = np.clip(255 * value, 0, 255).astype(np.uint8)
    B, T, H, W, C = value.shape
    return value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B * W))


def simulate(agent, envs, steps=0, episodes=0, state=None):
  """Run ``agent`` in a list of environments until a step or episode budget.

  Args:
    agent: callable ``agent(obs, done, agent_state, reward)`` returning
      ``(action, agent_state)``. ``obs`` is a dict of arrays stacked across
      envs. ``action`` is either a dict of per-env tensors or an array-like
      with one entry per env.
    envs: list of environments exposing ``reset() -> obs_dict`` and
      ``step(action)``, whose result's first three items are
      ``(obs, reward, done)``.
    steps: stop once this many env steps have been counted (0 disables).
      Steps are counted only when episodes finish, as that episode's length.
    episodes: stop once this many episodes have finished (0 disables).
    state: tuple returned by a previous call, used to resume.

  Returns:
    ``(step - steps, episode - episodes, done, length, obs, agent_state,
    reward)``. The first two entries are the counters minus the requested
    budgets. The remaining entries can be passed back as ``state``.
  """
  if state is None:
    step, episode = 0, 0
    # Builtin bool: the `np.bool` alias was removed in NumPy 1.24.
    done = np.ones(len(envs), bool)
    length = np.zeros(len(envs), np.int32)
    obs = [None] * len(envs)
    agent_state = None
    reward = [0] * len(envs)
  else:
    step, episode, done, length, obs, agent_state, reward = state
  while (steps and step < steps) or (episodes and episode < episodes):
    # Reset the envs whose episode ended and zero their carried-over reward.
    if done.any():
      indices = [index for index, d in enumerate(done) if d]
      results = [envs[i].reset() for i in indices]
      for index, result in zip(indices, results):
        obs[index] = result
      reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
    # Step the agent on the batched observations.
    obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
    action, agent_state = agent(obs, done, agent_state, reward)
    if isinstance(action, dict):
      action = [
          {k: np.array(action[k][i].detach().cpu()) for k in action}
          for i in range(len(envs))]
    else:
      action = np.array(action)
    assert len(action) == len(envs)
    # Step the envs.
    results = [e.step(a) for e, a in zip(envs, action)]
    obs, reward, done = zip(*[p[:3] for p in results])
    obs = list(obs)
    reward = list(reward)
    done = np.stack(done)
    episode += int(done.sum())
    length += 1
    # Completed episodes contribute their full length to the step counter.
    step += (done * length).sum()
    length *= (1 - done)

  return (step - steps, episode - episodes, done, length, obs, agent_state, reward)


def save_episodes(directory, episodes):
  """Write every episode dict to its own compressed ``.npz`` file.

  Files are named ``<timestamp>-<uuid hex>-<len(episode['reward'])>.npz``
  inside ``directory``, which is created if needed. The archive is built in
  memory first and then written in one go.

  Returns:
    List of the written ``pathlib.Path`` objects, in input order.
  """
  out_dir = pathlib.Path(directory).expanduser()
  out_dir.mkdir(parents=True, exist_ok=True)
  stamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
  written = []
  for ep in episodes:
    uid = str(uuid.uuid4().hex)
    n = len(ep['reward'])
    target = out_dir / f'{stamp}-{uid}-{n}.npz'
    with io.BytesIO() as buffer:
      np.savez_compressed(buffer, **ep)
      target.write_bytes(buffer.getvalue())
    written.append(target)
  return written


def from_generator(generator, batch_size):
  """Endlessly group dict items from ``generator`` into stacked batches.

  Each yielded batch maps every key of the first item to ``np.stack`` of that
  key's values over ``batch_size`` consecutive items (new leading axis 0).
  """
  while True:
    items = [next(generator) for _ in range(batch_size)]
    yield {
        key: np.stack([item[key] for item in items], 0)
        for key in items[0].keys()}


def sample_episodes(episodes, length=None, balance=False, seed=0):
  """Endlessly yield random episodes, or random fixed-length windows of them.

  Args:
    episodes: dict whose values are episode dicts of equal-length arrays.
      The value list is rebuilt on every draw, so entries added to the dict
      between yields are eligible.
    length: if set, yield a window of exactly ``length`` steps. Episodes whose
      length minus ``length`` is < 1 are skipped with a message.
    balance: if true, draw the start uniformly over the whole episode and clamp
      it to the last valid start. This over-weights windows ending at the
      episode's end.
    seed: seed for the private ``np.random.RandomState``.

  Yields:
    Episode dicts (sliced to ``length`` when given).
  """
  random = np.random.RandomState(seed)
  while True:
    episode = random.choice(list(episodes.values()))
    if length:
      total = len(next(iter(episode.values())))
      available = total - length
      if available < 1:
        # Report the episode's actual length (not the negative slack).
        print(f'Skipped short episode of length {total}.')
        continue
      if balance:
        index = min(random.randint(0, total), available)
      else:
        index = int(random.randint(0, available + 1))
      episode = {k: v[index: index + length] for k, v in episode.items()}
    yield episode


def load_episodes(directory, limit=None, reverse=True):
  """Load ``*.npz`` episode files from ``directory`` into a dict.

  Files are visited in sorted filename order, or reverse-sorted when
  ``reverse`` is true. Files that fail to load are reported and skipped.

  Args:
    directory: path (``~`` is expanded).
    limit: stop once the summed ``len(episode['reward']) - 1`` reaches this
      value (falsy disables the limit).
    reverse: visit files in descending filename order.

  Returns:
    Dict mapping ``str(filename)`` to a dict of numpy arrays.
  """
  directory = pathlib.Path(directory).expanduser()
  filenames = sorted(directory.glob('*.npz'))
  if reverse:
    filenames = reversed(filenames)
  episodes = {}
  total = 0
  for filename in filenames:
    try:
      # The NpzFile is closed as well as the underlying file handle.
      with filename.open('rb') as f, np.load(f) as data:
        episode = {k: data[k] for k in data.keys()}
    except Exception as e:
      print(f'Could not load episode: {e}')
      continue
    episodes[str(filename)] = episode
    total += len(episode['reward']) - 1
    if limit and total >= limit:
      break
  return episodes


class SampleDist:
  """Estimate statistics of a distribution from Monte-Carlo samples.

  Every statistic draws ``samples`` fresh samples from the wrapped
  distribution. Attributes not defined here are forwarded to it.
  """

  def __init__(self, dist, samples=100):
    self._dist = dist
    self._samples = samples

  @property
  def name(self):
    return 'SampleDist'

  def __getattr__(self, name):
    # Avoid infinite recursion when `_dist` is not set yet, e.g. while copy or
    # pickle rebuild the object without running __init__.
    if name == '_dist':
      raise AttributeError(name)
    return getattr(self._dist, name)

  def _draw(self):
    # torch distributions require a shape-like sample_shape, not a bare int.
    return self._dist.sample((self._samples,))

  def mean(self):
    samples = self._draw()
    return torch.mean(samples, 0)

  def mode(self):
    sample = self._draw()
    logprob = self._dist.log_prob(sample)
    # NOTE(review): argmax runs over the flattened log-probs, so this is only
    # well-defined when there is one log-prob per sample (batch size 1).
    return sample[torch.argmax(logprob)][0]

  def entropy(self):
    sample = self._draw()
    logprob = self.log_prob(sample)
    return -torch.mean(logprob, 0)


class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
  """One-hot categorical with optional uniform mixing and straight-through grads.

  With ``logits`` and ``unimix_ratio > 0`` the softmax probabilities are mixed
  with a uniform distribution:
  ``probs * (1 - unimix_ratio) + unimix_ratio / num_classes``.
  ``mode`` and ``sample`` return hard one-hot values whose gradient flows
  through the logits/probs.
  """

  def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
    use_unimix = logits is not None and unimix_ratio > 0.0
    if not use_unimix:
      super().__init__(logits=logits, probs=probs)
      return
    mixed = F.softmax(logits, dim=-1)
    mixed = mixed * (1.0 - unimix_ratio) + unimix_ratio / mixed.shape[-1]
    super().__init__(logits=torch.log(mixed), probs=None)

  def mode(self):
    logits = self.logits
    hard = F.one_hot(logits.argmax(dim=-1), logits.shape[-1])
    # Straight-through: forward value is `hard`, gradient is that of logits.
    return hard.detach() + logits - logits.detach()

  def sample(self, sample_shape=(), seed=None):
    if seed is not None:
      raise ValueError('need to check')
    drawn = super().sample(sample_shape)
    probs = self.probs
    # Left-pad probs with singleton dims so it broadcasts against the sample.
    for _ in range(drawn.dim() - probs.dim()):
      probs = probs.unsqueeze(0)
    drawn += probs - probs.detach()
    return drawn


class TwoHotDistSymlog():
  """Categorical over 255 fixed buckets in symlog space with two-hot targets.

  ``mode``/``mean`` return ``symexp`` of the probability-weighted bucket
  average. ``log_prob`` encodes a value as a two-hot vector over its two
  neighbouring buckets and scores it against the log-softmax of ``logits``.
  """

  def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
    self.logits = logits
    self.probs = torch.softmax(logits, -1)
    self.buckets = torch.linspace(low, high, steps=255).to(device)
    # NOTE(review): adjacent buckets are (high - low) / 254 apart; this value
    # divides by 255 and is not used inside this class — confirm intent.
    self.width = (self.buckets[-1] - self.buckets[0]) / 255

  def mean(self):
    # Identical to mode(): the expectation over buckets, mapped back by symexp.
    return self.mode()

  def mode(self):
    _mode = self.probs * self.buckets
    return symexp(torch.sum(_mode, dim=-1, keepdim=True))

  def log_prob(self, x):
    """Two-hot cross-entropy score of ``x``.

    The original author notes x as (time, batch, 1) — TODO confirm. The
    trailing singleton dim is squeezed away, so the result drops it.
    """
    x = symlog(x)
    # Index of the last bucket <= x and of the first bucket > x.
    below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
    above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
    below = torch.clip(below, 0, len(self.buckets) - 1)
    above = torch.clip(above, 0, len(self.buckets) - 1)
    equal = (below == above)

    # When both indices coincide (x outside the range or on a bucket), split
    # the weight 0.5/0.5 onto the same bucket.
    dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
    dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
    total = dist_to_below + dist_to_above
    weight_below = dist_to_above / total
    weight_above = dist_to_below / total
    target = (
        F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
        F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None])
    log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
    target = target.squeeze(-2)

    return (target * log_pred).sum(-1)

  def log_prob_target(self, target):
    """Score a precomputed bucket-weight ``target`` against the log-softmax."""
    # This class derives from object, so the logits live on self (the former
    # `super().logits` raised AttributeError).
    log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
    return (target * log_pred).sum(-1)

class SymlogDist():
  """Deterministic 'distribution' whose prediction lives in symlog space.

  ``mode``/``mean`` map the stored prediction back with ``symexp``.
  ``log_prob`` is the negative MSE or absolute error between the prediction
  and ``symlog(value)``, with per-element errors below ``tol`` zeroed. Errors
  are reduced over ``dim_to_reduce`` by sum or mean.
  """

  def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=(-1, -2, -3)):
    # Immutable tuple default: a list default would be shared by all instances.
    self._mode = mode
    self._dist = dist
    self._agg = agg
    self._tol = tol
    self._dim_to_reduce = dim_to_reduce

  def mode(self):
    return symexp(self._mode)

  def mean(self):
    return symexp(self._mode)

  def log_prob(self, value):
    """Return the negative reduced distance; ``value`` must match the mode's shape."""
    assert self._mode.shape == value.shape
    if self._dist == 'mse':
      distance = (self._mode - symlog(value)) ** 2.0
      distance = torch.where(distance < self._tol, 0, distance)
    elif self._dist == 'abs':
      distance = torch.abs(self._mode - symlog(value))
      distance = torch.where(distance < self._tol, 0, distance)
    else:
      raise NotImplementedError(self._dist)
    if self._agg == 'mean':
      loss = distance.mean(self._dim_to_reduce)
    elif self._agg == 'sum':
      loss = distance.sum(self._dim_to_reduce)
    else:
      raise NotImplementedError(self._agg)
    return -loss

class ContDist:
  """Adapter around a continuous torch distribution.

  ``sample`` is reparameterized (delegates to ``rsample``), ``mode()`` returns
  the wrapped distribution's mean, and unknown attributes are forwarded to the
  wrapped object.
  """

  def __init__(self, dist=None):
    super().__init__()
    self._dist = dist
    # The wrapped mean, read once at construction and exposed as a plain attribute.
    self.mean = dist.mean

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def sample(self, sample_shape=()):
    return self._dist.rsample(sample_shape)

  def mode(self):
    return self._dist.mean

  def log_prob(self, x):
    return self._dist.log_prob(x)

  def entropy(self):
    return self._dist.entropy()


class Bernoulli:
  """Adapter around a wrapped Bernoulli distribution.

  ``log_prob`` reads ``self._dist.base_dist.logits``, so the wrapped object
  must expose a ``base_dist`` with ``logits`` (e.g. an ``Independent``).
  ``mode()`` rounds the mean with a straight-through gradient.
  """

  def __init__(self, dist=None):
    super().__init__()
    self._dist = dist
    self.mean = dist.mean

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def mode(self):
    mean = self._dist.mean
    hard = torch.round(mean)
    return hard.detach() + mean - mean.detach()

  def sample(self, sample_shape=()):
    return self._dist.rsample(sample_shape)

  def entropy(self):
    return self._dist.entropy()

  def log_prob(self, x):
    logits = self._dist.base_dist.logits
    # log(1 - sigmoid(l)) = -softplus(l); log(sigmoid(l)) = -softplus(-l).
    log_p_zero = -F.softplus(logits)
    log_p_one = -F.softplus(-logits)
    return log_p_zero * (1 - x) + log_p_one * x


class UnnormalizedHuber(torchd.normal.Normal):
  """Normal whose ``log_prob`` is a negated pseudo-Huber distance to the mean.

  ``log_prob(e) = -(sqrt((e - mean)^2 + threshold^2) - threshold)``; ``scale``
  does not enter ``log_prob``, and the value is not normalized.
  """

  def __init__(self, loc, scale, threshold=1, **kwargs):
    super().__init__(loc, scale, **kwargs)
    self._threshold = threshold

  def log_prob(self, event):
    t = self._threshold
    smooth = torch.sqrt((event - self.mean) ** 2 + t ** 2)
    return -(smooth - t)

  def mode(self):
    return self.mean


class SafeTruncatedNormal(torchd.normal.Normal):
  """Normal whose samples are clipped into ``[low + clip, high - clip]``.

  Clipping uses a straight-through form (value of the clipped sample, gradient
  path of the raw sample). A truthy ``mult`` then scales the result. A falsy
  ``clip`` disables clipping.
  """

  def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
    super().__init__(loc, scale)
    self._low = low
    self._high = high
    self._clip = clip
    self._mult = mult

  def sample(self, sample_shape):
    event = super().sample(sample_shape)
    if self._clip:
      lower = self._low + self._clip
      upper = self._high - self._clip
      clipped = torch.clip(event, lower, upper)
      event = event - event.detach() + clipped.detach()
    if self._mult:
      event *= self._mult
    return event


class TanhBijector(torchd.Transform):
  """Bijective tanh transform usable with ``torch.distributions``.

  torch's ``Transform`` dispatches to ``_call``, ``_inverse`` and
  ``log_abs_det_jacobian(x, y)``. The TFP-style ``_forward`` and
  ``_forward_log_det_jacobian`` are kept for existing callers and are wired
  into those hooks. ``domain``/``codomain`` are declared so that
  ``TransformedDistribution`` can use the transform.
  """

  domain = torchd.constraints.real
  codomain = torchd.constraints.interval(-1.0, 1.0)
  bijective = True
  sign = +1

  def __init__(self, validate_args=False, name='tanh'):
    # validate_args and name are accepted for API compatibility; unused.
    super().__init__()

  def _call(self, x):
    return self._forward(x)

  def _forward(self, x):
    return torch.tanh(x)

  def _inverse(self, y):
    # Pull values inside [-1, 1] slightly away from +/-1 so atanh stays finite.
    y = torch.where(
        (torch.abs(y) <= 1.),
        torch.clamp(y, -0.99999997, 0.99999997), y)
    y = torch.atanh(y)
    return y

  def log_abs_det_jacobian(self, x, y):
    return self._forward_log_det_jacobian(x)

  def _forward_log_det_jacobian(self, x):
    """log|d tanh(x)/dx| = 2 * (log 2 - x - softplus(-2x)), numerically stable."""
    import math
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


def static_scan_for_lambda_return(fn, inputs, start):
  """Reverse scan specialised for ``lambda_return``.

  Iterates dim 0 of ``inputs`` from the last index to the first, calling
  ``last = fn(last, *(x[index] for x in inputs))`` from ``last = start``. The
  per-step results are concatenated on their last dim, reshaped to
  ``(d0, d1, 1)``, flipped on dim 1 back into forward order, and unbound
  along dim 0.

  Returns:
    Tuple of tensors from ``torch.unbind(..., dim=0)``.

  Raises:
    ValueError: if ``inputs`` has zero steps along dim 0.
  """
  last = start
  steps = []
  for index in reversed(range(inputs[0].shape[0])):
    last = fn(last, *(_input[index] for _input in inputs))
    steps.append(last)
  if not steps:
    raise ValueError('static_scan_for_lambda_return needs at least one step')
  # One concatenation instead of re-copying the history on every step.
  outputs = torch.cat(steps, dim=-1)
  outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
  outputs = torch.flip(outputs, [1])
  return torch.unbind(outputs, dim=0)


def lambda_return(
    reward, value, pcont, bootstrap, lambda_, axis):
  """Compute lambda-returns with a backward recursion over the time axis.

  Per step t, scanned from the last index of the time axis to the first:
    R_t = reward_t + pcont_t * (1 - lambda_) * value_{t+1}
                   + pcont_t * lambda_ * R_{t+1}
  where value_{t+1} past the last step is ``bootstrap`` and the recursion is
  seeded with R = ``bootstrap``.

  Args:
    reward: tensor with the same rank as ``value`` (asserted).
    value: value estimates; ``value[1:]`` supplies value_{t+1}.
    pcont: continuation factors, as a tensor or a Python int/float that is
      broadcast to ``reward``'s shape.
    bootstrap: value after the last step; zeros like ``value[-1]`` if None.
    lambda_: mixing factor (0 -> one-step targets, 1 -> Monte-Carlo).
    axis: index of the time axis; it is swapped with dim 0 for the scan.

  Returns:
    The result of ``static_scan_for_lambda_return``, a tuple of tensors
    produced by ``torch.unbind``.

  NOTE(review): when ``axis != 0`` the final ``returns.permute(dims)`` is
  applied to that tuple, which has no ``permute`` method. Confirm that
  callers only use axis=0.
  """
  # Setting lambda=1 gives a discounted Monte Carlo return.
  # Setting lambda=0 gives a fixed 1-step return.
  #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
  if isinstance(pcont, (int, float)):
    pcont = pcont * torch.ones_like(reward)
  # Permutation that swaps dim 0 with `axis` (its own inverse).
  dims = list(range(len(reward.shape)))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  if axis != 0:
    reward = reward.permute(dims)
    value = value.permute(dims)
    pcont = pcont.permute(dims)
  if bootstrap is None:
    bootstrap = torch.zeros_like(value[-1])
  # value_{t+1}: shift values forward by one step and append the bootstrap.
  next_values = torch.cat([value[1:], bootstrap[None]], 0)
  # Part of each step's target that does not depend on the next return.
  inputs = reward + pcont * next_values * (1 - lambda_)
  #returns = static_scan(
  #    lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
  #    (inputs, pcont), bootstrap, reverse=True)
  # reimplement to optimize performance
  returns = static_scan_for_lambda_return(
      lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
      (inputs, pcont), bootstrap)
  if axis != 0:
    returns = returns.permute(dims)
  return returns


class Optimizer():
  """Wrap a torch optimizer with AMP scaling, grad-norm clipping and weight decay.

  Calling the instance with a scalar loss runs backward, unscales the
  gradients, clips them to ``clip`` (if set), applies multiplicative weight
  decay (if ``wd``), steps and zeroes the gradients. It returns a dict with
  ``<name>_loss`` and ``<name>_grad_norm``.
  """

  def __init__(
      self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*',
      opt='adam', use_amp=False):
    # wd=None (the default) means no weight decay.
    assert wd is None or 0 <= wd < 1
    assert not clip or 1 <= clip
    self._name = name
    self._parameters = parameters
    self._clip = clip
    self._wd = wd
    self._wd_pattern = wd_pattern
    self._opt = self._build(opt, parameters, lr, eps)
    self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  @staticmethod
  def _build(opt, parameters, lr, eps):
    """Instantiate the torch optimizer named ``opt`` (KeyError if unknown)."""
    if opt == 'nadam':
      raise NotImplementedError(f'{opt} is not implemented')
    factories = {
        'adam': lambda: torch.optim.Adam(parameters, lr=lr, eps=eps),
        'adamax': lambda: torch.optim.Adamax(parameters, lr=lr, eps=eps),
        'sgd': lambda: torch.optim.SGD(parameters, lr=lr),
        'momentum': lambda: torch.optim.SGD(parameters, lr=lr, momentum=0.9),
    }
    return factories[opt]()

  def __call__(self, loss, params, retain_graph=False):
    """Run one optimization step on scalar ``loss`` and return metrics.

    Args:
      loss: 0-dim tensor.
      params: parameters to clip and decay; must be re-iterable (a list),
        since it is used twice when weight decay is enabled.
      retain_graph: forwarded to ``backward``.
    """
    assert len(loss.shape) == 0, loss.shape
    metrics = {}
    metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
    self._scaler.scale(loss).backward(retain_graph=retain_graph)
    self._scaler.unscale_(self._opt)
    # With clipping disabled an infinite max_norm only measures the norm.
    max_norm = self._clip if self._clip else float('inf')
    norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
    if self._wd:
      self._apply_weight_decay(params)
    self._scaler.step(self._opt)
    self._scaler.update()
    self._opt.zero_grad()
    metrics[f'{self._name}_grad_norm'] = norm.item()
    return metrics

  def _apply_weight_decay(self, varibs):
    """Scale every parameter by ``1 - wd`` in place (all parameters only)."""
    nontrivial = (self._wd_pattern != r'.*')
    if nontrivial:
      raise NotImplementedError
    for var in varibs:
      var.data = (1 - self._wd) * var.data


def args_type(default):
  """Return a converter that coerces values to the type of ``default``.

  Strings are parsed:
  - None default: returned unchanged.
  - bool: only 'False'/'True' are accepted (ValueError otherwise).
  - int: parsed as float when the text contains 'e' or '.'.
  - list/tuple: comma-separated, each part parsed like ``default[0]``.
  - otherwise: ``type(default)(text)``.
  Non-strings are turned into tuples for sequence defaults and otherwise
  returned unchanged.
  """
  def from_string(text):
    if default is None:
      return text
    # bool must be tested before int because bool subclasses int.
    if isinstance(default, bool):
      return ['False', 'True'].index(text) == 1
    if isinstance(default, int):
      looks_float = 'e' in text or '.' in text
      return float(text) if looks_float else int(text)
    if isinstance(default, (list, tuple)):
      convert_item = args_type(default[0])
      return tuple(convert_item(part) for part in text.split(','))
    return type(default)(text)

  def from_object(value):
    return tuple(value) if isinstance(default, (list, tuple)) else value

  def convert(value):
    if isinstance(value, str):
      return from_string(value)
    return from_object(value)

  return convert


def static_scan(fn, inputs, start):
  """Apply ``fn`` step by step along dim 0 of ``inputs`` and stack the results.

  At each index i this calls ``last = fn(last, *(x[i] for x in inputs))``,
  starting from ``last = start``, and accumulates every step's output along a
  new leading dim 0.

  Args:
    fn: step function ``fn(carry, *slices) -> carry``. The carry it returns is
      either a dict of tensors or a sequence whose items are tensors or dicts
      of tensors.
    inputs: sequence of tensors indexed along dim 0. The number of steps is
      ``inputs[0].shape[0]``.
    start: initial carry for the first call.

  Returns:
    A list. For a dict carry, a one-element list holding a dict of stacked
    tensors. Otherwise, one stacked entry (a tensor or a dict of tensors) per
    item of the carry.

  NOTE(review): ``outputs`` is bound only inside the loop, so zero steps
  raises UnboundLocalError. Each step re-concatenates the whole history with
  torch.cat, so the cost grows quadratically with the number of steps.
  """
  last = start
  indices = range(inputs[0].shape[0])
  flag = True
  for index in indices:
    # Slice every input at the current step.
    inp = lambda x: (_input[x] for _input in inputs)
    last = fn(last, *inp(index))
    if flag:
      # First step: seed the accumulators with a leading dim of size 1
      # (clone so later in-place changes to `last` do not alias them).
      if type(last) == type({}):
        outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
      else:
        outputs = []
        for _last in last:
          if type(_last) == type({}):
            outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
          else:
            outputs.append(_last.clone().unsqueeze(0))
      flag = False
    else:
      # Later steps: append this step's output along dim 0.
      if type(last) == type({}):
        for key in last.keys():
          outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
      else:
        for j in range(len(outputs)):
          if type(last[j]) == type({}):
            for key in last[j].keys():
              outputs[j][key] = torch.cat([outputs[j][key],
                  last[j][key].unsqueeze(0)], dim=0)
          else:
            outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
  # A dict carry is wrapped so both carry shapes return a list.
  if type(last) == type({}):
    outputs = [outputs]
  return outputs


# Original version
#def static_scan2(fn, inputs, start, reverse=False):
#  last = start
#  outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
#  indices = range(inputs[0].shape[0])
#  if reverse:
#    indices = reversed(indices)
#  for index in indices:
#    inp = lambda x: (_input[x] for _input in inputs)
#    last = fn(last, *inp(index))
#    [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
#  if reverse:
#    outputs = [list(reversed(x)) for x in outputs]
#  res = [[]] * len(outputs)
#  for i in range(len(outputs)):
#    if type(outputs[i][0]) == type({}):
#      _res = {}
#      for key in outputs[i][0].keys():
#        _res[key] = []
#        for j in range(len(outputs[i])):
#          _res[key].append(outputs[i][j][key])
#        #_res[key] = torch.stack(_res[key], 0)
#        _res[key] = faster_stack(_res[key], 0)
#    else:
#      _res = outputs[i]
#      #_res = torch.stack(_res, 0)
#      _res = faster_stack(_res, 0)
#    res[i] = _res
#  return res


class Every:
  """Trigger every ``every`` steps; the first call always triggers.

  A falsy ``every`` never triggers. After a trigger the reference point
  advances by exactly ``every``, so large step jumps fire once per call.
  """

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if not self._every:
      return False
    if self._last is None:
      self._last = step
      return True
    due = step >= self._last + self._every
    if due:
      self._last += self._every
    return due


class Once:
  """Return True on the first call and False on every call after that."""

  def __init__(self):
    self._once = True

  def __call__(self):
    fired, self._once = self._once, False
    return fired


class Until:
  """Return True while ``step < until``; a falsy ``until`` means forever."""

  def __init__(self, until):
    self._until = until

  def __call__(self, step):
    return (not self._until) or step < self._until


def schedule(string, step):
  """Evaluate a hyper-parameter schedule string at ``step``.

  Supported forms:
    '<float>'                    constant.
    'linear(init,final,dur)'     linear interpolation, clamped after ``dur``
                                 (returned as a 0-dim tensor).
    'warmup(dur,value)'          ramp from 0 to ``value`` over ``dur`` steps.
    'exp(init,final,halflife)'   exponential decay towards ``final``.
    'horizon(init,final,dur)'    interpolated horizon H; returns 1 - 1/H.

  Raises:
    NotImplementedError: for an unrecognised string.
  """
  def _clip01(x):
    # torch.clip only accepts tensors; clamp plain numbers with min/max.
    if isinstance(x, torch.Tensor):
      return torch.clip(x, 0, 1)
    return min(max(x, 0.0), 1.0)

  try:
    return float(string)
  except ValueError:
    match = re.match(r'linear\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, duration = [float(group) for group in match.groups()]
      mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
      return (1 - mix) * initial + mix * final
    match = re.match(r'warmup\((.+),(.+)\)', string)
    if match:
      warmup, value = [float(group) for group in match.groups()]
      scale = _clip01(step / warmup)
      return scale * value
    match = re.match(r'exp\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, halflife = [float(group) for group in match.groups()]
      return (initial - final) * 0.5 ** (step / halflife) + final
    match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
    if match:
      initial, final, duration = [float(group) for group in match.groups()]
      mix = _clip01(step / duration)
      horizon = (1 - mix) * initial + mix * final
      return 1 - 1 / horizon
    raise NotImplementedError(string)

def weight_init(m):
  """Initialize a module in place (intended for ``nn.Module.apply``).

  Linear / Conv2d / ConvTranspose2d weights get a truncated normal (cut at
  absolute values -2 and 2) with std ``sqrt(1 / fan_avg) / 0.8796...``,
  where fan_avg averages the fan-in and fan-out (times the kernel area for
  convs). LayerNorm weights are set to 1. Biases, when present, are zeroed.
  Other modules are left untouched.
  """
  def trunc_normal(weight, fan_in, fan_out):
    fan_avg = (fan_in + fan_out) / 2.0
    std = np.sqrt(1.0 / fan_avg) / 0.87962566103423978
    nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2.0, b=2.0)

  def zero_bias(module):
    if hasattr(module.bias, 'data'):
      module.bias.data.fill_(0.0)

  if isinstance(m, nn.Linear):
    trunc_normal(m.weight.data, m.in_features, m.out_features)
    zero_bias(m)
  elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    area = m.kernel_size[0] * m.kernel_size[1]
    trunc_normal(m.weight.data, area * m.in_channels, area * m.out_channels)
    zero_bias(m)
  elif isinstance(m, nn.LayerNorm):
    m.weight.data.fill_(1.0)
    zero_bias(m)

def uniform_weight_init(given_scale):
  """Build an in-place initializer for ``nn.Module.apply``.

  Linear weights are drawn from U(-limit, limit) with
  ``limit = sqrt(3 * given_scale / fan_avg)``, where fan_avg is the mean of
  in/out features. LayerNorm weights are set to 1. For both, a present bias
  is zeroed. Other modules are untouched.
  """
  def init(module):
    if isinstance(module, nn.Linear):
      fan_avg = (module.in_features + module.out_features) / 2.0
      limit = np.sqrt(3 * (given_scale / fan_avg))
      nn.init.uniform_(module.weight.data, a=-limit, b=limit)
    elif isinstance(module, nn.LayerNorm):
      module.weight.data.fill_(1.0)
    else:
      return
    if hasattr(module.bias, 'data'):
      module.bias.data.fill_(0.0)
  return init

def tensorstats(tensor, prefix=None):
  """Summarize ``tensor`` as numpy mean/std/min/max values.

  Keys are 'mean', 'std', 'min' and 'max', prefixed with '<prefix>_' when
  ``prefix`` is truthy.
  """
  reducers = (
      ('mean', torch.mean),
      ('std', torch.std),
      ('min', torch.min),
      ('max', torch.max),
  )
  stats = {}
  for key, reduce in reducers:
    label = f'{prefix}_{key}' if prefix else key
    stats[label] = to_np(reduce(tensor))
  return stats