from abc import abstractmethod, ABC
from typing import Sequence

import torch
from torch import nn
import numpy as np

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic

# Type alias for the device argument accepted by the factories' `create_module` methods:
# a device string, an integer, or a `torch.device` (all of which are passed on to tianshou/torch).
TDevice = str | int | torch.device


def init_linear_orthogonal(m: torch.nn.Module, gain: float = np.sqrt(2)) -> None:
    """
    Applies orthogonal initialization to the linear layers of the given module and sets their biases to 0

    :param m: the module whose submodules (including `m` itself) are to be processed
    :param gain: the scaling factor passed to `torch.nn.init.orthogonal_`; defaults to sqrt(2)
    """
    for submodule in m.modules():
        if isinstance(submodule, torch.nn.Linear):
            torch.nn.init.orthogonal_(submodule.weight, gain=gain)
            # layers constructed with bias=False have `bias` set to None; there is nothing to zero then
            if submodule.bias is not None:
                torch.nn.init.zeros_(submodule.bias)


class ActorFactory(ABC):
    """Abstract base class for factories which create actor modules."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """
        Creates the actor module

        :param envs: the environments for which the actor is to be created
        :param device: the device to use for the created module
        :return: the actor module
        """
        pass

    @staticmethod
    def _init_linear(actor: torch.nn.Module) -> None:
        """
        Initializes linear layers of an actor module using default mechanisms

        All linear layers get orthogonal weights and zero biases (via `init_linear_orthogonal`).
        If the actor has a `mu` submodule, the weights of the linear layers within it are
        additionally scaled by 0.01.

        :param actor: the actor module
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies, do last policy layer scaling:
            # this makes the initial action means (close to) 0 and helps boost performance,
            # see https://arxiv.org/abs/2006.05990, Fig. 24 for details
            with torch.no_grad():
                for layer in actor.mu.modules():
                    if isinstance(layer, torch.nn.Linear):
                        layer.weight.mul_(0.01)


class ContinuousActorFactory(ActorFactory, ABC):
    """Abstract base class for actor factories targeting continuous action spaces."""


class ContinuousActorProbFactory(ContinuousActorFactory):
    """
    Creates `ActorProb` actors for continuous action spaces on top of a Tanh-activated `Net`.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        """
        :param hidden_sizes: the hidden layer sizes of the underlying `Net`
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """
        Creates the actor for the given environments

        :param envs: the environments providing the state and action shapes
        :param device: the device to use for the created module
        :return: the initialized actor
        """
        preprocess_net = Net(
            envs.get_state_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        prob_actor = ActorProb(
            preprocess_net,
            envs.get_action_shape(),
            unbounded=True,
            device=device,
        )
        prob_actor = prob_actor.to(device)

        # parameter initialization: constant sigma parameter, then the default linear-layer scheme
        torch.nn.init.constant_(prob_actor.sigma_param, -0.5)
        self._init_linear(prob_actor)
        return prob_actor


class CriticFactory(ABC):
    """Abstract base class for factories which create critic modules."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """
        Creates the critic module

        :param envs: the environments for which the critic is to be created
        :param device: the device to use for the created module
        :return: the critic module
        """


class ContinuousCriticFactory(CriticFactory, ABC):
    """Abstract base class for critic factories targeting continuous action spaces."""


class ContinuousNetCriticFactory(ContinuousCriticFactory):
    """
    Creates continuous critics on top of a Tanh-activated `Net` over the state.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        """
        :param hidden_sizes: the hidden layer sizes of the underlying `Net`
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """
        Creates the critic for the given environments

        :param envs: the environments providing the state shape
        :param device: the device to use for the created module
        :return: the orthogonally initialized critic
        """
        state_shape = envs.get_state_shape()
        preprocess_net = Net(
            state_shape,
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        value_critic = ContinuousCritic(preprocess_net, device=device)
        value_critic = value_critic.to(device)
        init_linear_orthogonal(value_critic)
        return value_critic