import argparse
import os
import random
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, which
    is ``True`` for every non-empty string (including ``'False'``). This
    parser understands the usual spellings instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as False because ``bool('')`` was False before.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args():
    """Parse the command-line hyperparameters of the training script.

    Returns:
        argparse.Namespace: all parsed options plus two derived fields,
        ``batch_size`` (``num_envs * num_steps``) and ``minibatch_size``
        (``batch_size // num_mini_batches``).
    """
    # splitext removes only the extension; ``rstrip('.py')`` would strip any
    # trailing run of the characters '.', 'p' and 'y' from the script name.
    default_exp_name = os.path.splitext(os.path.basename(__file__))[0]

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--torch_deterministic', type=_str_to_bool, default=True)
    parser.add_argument('--cuda', type=_str_to_bool, default=True)
    parser.add_argument('--env_id', type=str, default='Humanoid-v4')
    parser.add_argument('--total_time_steps', type=int, default=int(1e7))
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--num_envs', type=int, default=8)
    parser.add_argument('--num_steps', type=int, default=256)
    parser.add_argument('--anneal_lr', type=_str_to_bool, default=True)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=0.95)
    parser.add_argument('--num_mini_batches', type=int, default=4)
    parser.add_argument('--update_epochs', type=int, default=10)
    parser.add_argument('--norm_adv', type=_str_to_bool, default=True)
    parser.add_argument('--clip_value_loss', type=_str_to_bool, default=True)
    parser.add_argument('--c_1', type=float, default=0.5)
    parser.add_argument('--c_2', type=float, default=0.0)
    parser.add_argument('--max_grad_norm', type=float, default=0.5)
    parser.add_argument('--kld_max', type=float, default=0.02)
    a = parser.parse_args()

    # Derived sizes: one rollout yields num_envs * num_steps transitions,
    # which are split into num_mini_batches chunks per epoch.
    a.batch_size = int(a.num_envs * a.num_steps)
    a.minibatch_size = int(a.batch_size // a.num_mini_batches)
    return a


def make_env(env_id, gamma):
    """Return a zero-argument factory that builds one wrapped environment.

    The factory creates ``env_id`` and applies, in order: observation
    flattening, episode statistics recording, action clipping, observation
    normalization, observation clipping to [-10, 10], reward normalization
    (using ``gamma``) and reward clipping to [-10, 10].
    """
    def clip_observation(observation):
        return np.clip(observation, -10, 10)

    def clip_reward(reward):
        return float(np.clip(reward, -10, 10))

    def thunk():
        # Wrappers are listed innermost first; the order matters because the
        # clipping steps must see the already-normalized values.
        pipeline = (
            gym.wrappers.FlattenObservation,
            gym.wrappers.RecordEpisodeStatistics,
            gym.wrappers.ClipAction,
            gym.wrappers.NormalizeObservation,
            lambda inner: gym.wrappers.TransformObservation(inner, clip_observation),
            lambda inner: gym.wrappers.NormalizeReward(inner, gamma=gamma),
            lambda inner: gym.wrappers.TransformReward(inner, clip_reward),
        )
        env = gym.make(env_id)
        for wrap in pipeline:
            env = wrap(env)
        return env

    return thunk


def layer_init(layer, s=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place and hand it back.

    The weight matrix receives an orthogonal initialization scaled by the
    gain ``s``; every bias entry is set to ``bias_const``. The same layer
    object is returned so the call can be nested inside ``nn.Sequential``.
    """
    initializers = torch.nn.init
    initializers.orthogonal_(layer.weight, gain=s)
    initializers.constant_(layer.bias, val=bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic model with separate value and policy-mean networks.

    Both networks are two-hidden-layer Tanh MLPs of width 64. The policy is a
    diagonal Gaussian whose mean comes from ``actor_mean`` and whose log
    standard deviation is a learned, observation-independent parameter
    (``actor_log_std``), initialized to zero.
    """

    def __init__(self, e):
        """Build the networks from the observation/action spaces exposed by ``e``.

        ``e`` must provide ``single_observation_space.shape`` and
        ``single_action_space.shape``.
        """
        super().__init__()
        obs_dim = np.array(e.single_observation_space.shape).prod()
        act_dim = np.array(e.single_action_space.shape).prod()

        # The critic is created before the actor so the random-number stream
        # is consumed in the same order on every run.
        self.critic = self._build_mlp(obs_dim, 1, 1.0)
        self.actor_mean = self._build_mlp(obs_dim, act_dim, 0.01)
        self.actor_log_std = nn.Parameter(torch.zeros(1, act_dim))

    @staticmethod
    def _build_mlp(in_dim, out_dim, out_gain):
        """Return a 64-64 Tanh MLP whose output layer is initialized with ``out_gain``."""
        return nn.Sequential(
            layer_init(nn.Linear(in_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, out_dim), s=out_gain),
        )

    def get_value(self, x):
        """Return the critic's output for the observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, a=None, show_all=False):
        """Evaluate the policy and the critic on the observations ``x``.

        When ``a`` is ``None`` an action is sampled from the current policy;
        otherwise the supplied action is scored. Returns
        ``(action, log_prob, entropy, value)``, where log-probability and
        entropy are summed over dimension 1. With ``show_all=True`` the
        ``Normal`` distribution object is appended as a fifth element.
        """
        mu = self.actor_mean(x)
        sigma = self.actor_log_std.expand_as(mu).exp()
        dist = Normal(mu, sigma)

        action = dist.sample() if a is None else a
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        value = self.critic(x)

        outputs = (action, log_prob, entropy, value)
        if show_all:
            return outputs + (dist,)
        return outputs


def compute_kld(mu_1, sigma_1, mu_2, sigma_2):
    """Element-wise KL divergence KL(N(mu_1, sigma_1) || N(mu_2, sigma_2)).

    Each argument is a tensor of means or standard deviations; the result has
    the broadcast shape of the inputs and is not reduced over any dimension.
    """
    log_scale_ratio = torch.log(sigma_2 / sigma_1)
    squared_mean_gap = (mu_1 - mu_2) ** 2
    variance_gap = sigma_1 ** 2 - sigma_2 ** 2
    denominator = 2 * sigma_2 ** 2
    return log_scale_ratio + (squared_mean_gap + variance_gap) / denominator


def _compute_gae(rewards, values, dones, next_value, next_done, gamma, gae_lambda):
    """Estimate advantages and returns with Generalized Advantage Estimation.

    ``rewards``, ``values`` and ``dones`` are rollout buffers indexed by step
    along the first axis. ``dones[t]`` holds the done flag that was current
    when step ``t`` was recorded, so the flag at ``t + 1`` tells whether the
    transition taken at ``t`` ended an episode. ``next_value`` and
    ``next_done`` play the same role for the step after the last stored one.

    Returns:
        tuple: ``(advantages, returns)`` where ``returns = advantages + values``.
    """
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae_lam = 0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            next_non_terminal = 1.0 - next_done
            next_values = next_value
        else:
            next_non_terminal = 1.0 - dones[t + 1]
            next_values = values[t + 1]
        # One-step TD error; bootstrapping is cut where an episode ended.
        delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
        advantages[t] = last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
    return advantages, advantages + values


def _train(args, envs, agent, optimizer, writer, device):
    """Alternate rollout collection and SPO updates, logging to ``writer``.

    Runs ``args.total_time_steps // args.batch_size`` updates. Each update
    collects ``args.num_steps`` steps from every environment and then performs
    ``args.update_epochs`` passes over the shuffled batch in minibatches.
    """
    obs_shape = envs.single_observation_space.shape
    action_shape = envs.single_action_space.shape
    rollout_shape = (args.num_steps, args.num_envs)

    # Rollout buffers, indexed as [step, env, ...]
    obs = torch.zeros(rollout_shape + obs_shape).to(device)
    actions = torch.zeros(rollout_shape + action_shape).to(device)
    log_probs = torch.zeros(rollout_shape).to(device)
    rewards = torch.zeros(rollout_shape).to(device)
    dones = torch.zeros(rollout_shape).to(device)
    values = torch.zeros(rollout_shape).to(device)
    # Mean / std of the policy at collection time, needed later for the KL term
    mean = torch.zeros(rollout_shape + action_shape).to(device)
    std = torch.zeros(rollout_shape + action_shape).to(device)

    # Data collection
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_time_steps // args.batch_size

    for update in tqdm(range(1, num_updates + 1)):

        # Linear decay of learning rate
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lr_now = frac * args.learning_rate
            optimizer.param_groups[0]['lr'] = lr_now

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Sample from the current (soon to be "old") policy without tracking gradients
            with torch.no_grad():
                action, log_prob, _, value, mean_std = agent.get_action_and_value(next_obs, show_all=True)
                values[step] = value.flatten()
            actions[step] = action
            log_probs[step] = log_prob

            # Mean and standard deviation, each (num_envs, action_dim) per step
            mean[step] = mean_std.loc
            std[step] = mean_std.scale

            # Update the environments
            next_obs, reward, terminations, truncations, info = envs.step(action.cpu().numpy())
            done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            if 'final_info' not in info:
                continue

            # Entries are None for environments whose episode did not end on this step
            for item in info['final_info']:
                if item is None:
                    continue
                writer.add_scalar('charts/episodic_return', item['episode']['r'][0], global_step)

        # Use GAE (Generalized Advantage Estimation) technique to estimate the advantage function
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages, returns = _compute_gae(
                rewards, values, dones, next_value, next_done, args.gamma, args.gae_lambda
            )

        # ---------------------- We have collected enough data, now let's start training ---------------------- #
        # Flatten each batch
        b_obs = obs.reshape((-1,) + obs_shape)
        b_log_probs = log_probs.reshape(-1)
        b_actions = actions.reshape((-1,) + action_shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Obtain the mean and the standard deviation of a batch
        b_mean = mean.reshape(args.batch_size, -1)
        b_std = std.reshape(args.batch_size, -1)

        # Update the policy network and value network
        b_index = np.arange(args.batch_size)
        for epoch in range(1, args.update_epochs + 1):
            np.random.shuffle(b_index)
            for t, start in enumerate(range(0, args.batch_size, args.minibatch_size), start=1):
                end = start + args.minibatch_size
                mb_index = b_index[start:end]
                # The last minibatch is shorter when batch_size is not a
                # multiple of minibatch_size, so use the real row count.
                mb_rows = len(mb_index)

                # The latest outputs of the policy network and value network
                _, new_log_prob, entropy, new_value, new_mean_std = agent.get_action_and_value(
                    b_obs[mb_index],
                    b_actions[mb_index],
                    show_all=True,
                )

                # Compute KL divergence between the old and the current policy, summed over action dims
                new_mean = new_mean_std.loc.reshape(mb_rows, -1)
                new_std = new_mean_std.scale.reshape(mb_rows, -1)
                d = compute_kld(b_mean[mb_index], b_std[mb_index], new_mean, new_std).sum(1)

                writer.add_scalar('charts/average_kld', d.mean().item(), global_step)
                writer.add_scalar('others/min_kld', d.min().item(), global_step)
                writer.add_scalar('others/max_kld', d.max().item(), global_step)

                log_ratio = new_log_prob - b_log_probs[mb_index]
                ratios = log_ratio.exp()
                mb_advantages = b_advantages[mb_index]

                # Advantage normalization
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)

                # Policy loss (main code of SPO)
                if epoch == 1 and t == 1:
                    # Before the first gradient step the policy is unchanged,
                    # so the plain surrogate objective is used.
                    pg_loss = (-mb_advantages * ratios).mean()
                else:
                    # d_clip
                    d_clip = torch.clamp(input=d, min=0, max=args.kld_max)
                    # d_clip / d
                    ratio = d_clip / (d + 1e-12)
                    # sign_a
                    sign_a = torch.sign(mb_advantages)
                    # (d_clip / d + sign_a - 1) * sign_a
                    # -> d_clip / d for positive advantages, 2 - d_clip / d for negative ones
                    result = (ratio + sign_a - 1) * sign_a
                    pg_loss = (-mb_advantages * ratios * result).mean()

                # Value loss
                new_value = new_value.view(-1)
                if args.clip_value_loss:
                    v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -0.2,
                        0.2,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()

                # Policy entropy
                entropy_loss = entropy.mean()

                # Total loss
                loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2

                # Save the data during the training process
                writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
                writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
                writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)

                # Update network parameters
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

        # Fraction of the return variance explained by the rollout-time value estimates
        y_pre, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pre) / var_y
        writer.add_scalar('others/explained_variance', explained_var, global_step)

        # Save the data during the training process
        writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)


def main(env_id, seed):
    """Train one SPO agent on ``env_id`` with the given random ``seed``.

    Hyperparameters come from the command line (see ``get_args``); ``env_id``
    and ``seed`` override the corresponding parsed values. TensorBoard logs
    are written to ``<env_id>/spo_epoch_<update_epochs>_seed_<seed>``.

    The environments and the summary writer are closed even when training
    raises, so a failed run does not leak them into the next one.
    """
    args = get_args()
    args.env_id = env_id
    args.seed = seed
    run_name = (
            'spo' +
            '_epoch_' + str(args.update_epochs) +
            '_seed_' + str(args.seed)
    )

    # Save training logs
    path_string = str(args.env_id) + '/' + run_name
    writer = SummaryWriter(path_string)
    envs = None
    try:
        writer.add_text(
            'Hyperparameter',
            '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
        )

        # Random seed
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = args.torch_deterministic

        device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

        # Initialize environments
        envs = gym.vector.SyncVectorEnv(
            [make_env(args.env_id, args.gamma) for _ in range(args.num_envs)]
        )
        assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

        agent = Agent(envs).to(device)
        optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

        _train(args, envs, agent, optimizer, writer, device)
    finally:
        # Same order as a successful run: environments first, then the writer.
        if envs is not None:
            envs.close()
        writer.close()


def run():
    """Train sequentially on every configured environment for seeds 1 through 5."""
    env_ids = ['Humanoid-v4']
    seeds = range(1, 6)
    # Environment-major order: all seeds of one environment before the next.
    jobs = [(env_id, seed) for env_id in env_ids for seed in seeds]
    for env_id, seed in jobs:
        print(env_id, 'seed:', seed)
        main(env_id, seed)


# Script entry point: start the training sweep only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run()