Fix type annotations of dist_fn

Dominik Jain 2023-10-09 17:48:43 +02:00
parent a161a9cf58
commit 22dfc4ed2e
8 changed files with 16 additions and 24 deletions
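
In short: the ad-hoc annotation Callable[[TDistParams], torch.distributions.Distribution] is replaced everywhere by a shared alias, TDistributionFunction, now defined in tianshou.policy.modelfree.pg and imported by the other modules. A minimal sketch of the idea follows; it writes the alias as Callable[..., ...] so that it is valid typing syntax, which is an assumption about intent rather than the literal committed line.

from collections.abc import Callable
from typing import TypeAlias

import torch

# A dist_fn maps raw actor output (logits for discrete actions, a (mu, sigma) pair
# for continuous ones, ...) to a torch Distribution the policy can sample and score.
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]

def categorical_dist_fn(logits: torch.Tensor) -> torch.distributions.Distribution:
    # Mirrors DistributionFunctionFactoryCategorical._dist_fn in the first file below.
    return torch.distributions.Categorical(logits=logits)

dist_fn: TDistributionFunction = categorical_dist_fn  # type-checks against the alias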

View File

@@ -1,15 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeAlias
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.string import ToStringMixin
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
class DistributionFunctionFactory(ToStringMixin, ABC):
@abstractmethod
@@ -23,7 +19,7 @@ class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
return self._dist_fn
@staticmethod
def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Categorical(logits=p)
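
A quick usage sketch of the categorical dist_fn above, standalone (plain torch, no tianshou imports); the sampled action value is illustrative.

import torch

def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
    # Same body as DistributionFunctionFactoryCategorical._dist_fn above.
    return torch.distributions.Categorical(logits=p)

logits = torch.tensor([[1.0, 0.5, -0.5]])  # batch of one observation, three actions
dist = _dist_fn(logits)
action = dist.sample()                     # e.g. tensor([0])
log_prob = dist.log_prob(action)           # what the policy-gradient losses consume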

View File

@@ -14,8 +14,8 @@ from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactory,
DistributionFunctionFactoryDefault,
TDistributionFunction,
)
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory

View File

@@ -10,7 +10,7 @@ from tianshou.data import ReplayBuffer, to_numpy, to_torch
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import PPOPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
class GAILPolicy(PPOPolicy):
@@ -62,7 +62,7 @@ class GAILPolicy(PPOPolicy):
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module,

View File

@@ -11,7 +11,7 @@ from tianshou.data import ReplayBuffer, to_torch_as
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import PGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.net.common import ActorCritic
@@ -50,7 +50,7 @@ class A2CPolicy(PGPolicy):
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
vf_coef: float = 0.5,
ent_coef: float = 0.01,

View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from typing import Any, Literal
import gymnasium as gym
@@ -12,7 +11,7 @@ from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
class NPGPolicy(A2CPolicy):
@@ -47,7 +46,7 @@ class NPGPolicy(A2CPolicy):
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
optim_critic_iters: int = 5,
actor_step_size: float = 0.5,

View File

@@ -1,6 +1,5 @@
import warnings
from collections.abc import Callable
from typing import Any, Literal, TypeAlias, cast
from typing import Any, Literal, cast, TypeAlias, Callable
import gymnasium as gym
import numpy as np
@@ -17,7 +16,7 @@ from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils import RunningMeanStd
TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution]
class PGPolicy(BasePolicy):
@@ -56,7 +55,7 @@ class PGPolicy(BasePolicy):
*,
actor: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
discount_factor: float = 0.99,
# TODO: rename to return_normalization?
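
The broadened alias matters because different actors hand dist_fn different arguments: a single logits tensor for discrete actions, or a (mu, sigma) pair for continuous ones. Below is a hedged sketch of the dispatch a policy's forward pass can perform; this helper is illustrative and not code from this commit.

from collections.abc import Callable

import torch

def make_dist(
    dist_fn: Callable[..., torch.distributions.Distribution],
    actor_output: torch.Tensor | tuple[torch.Tensor, ...],
) -> torch.distributions.Distribution:
    # Unpack tuple outputs (e.g. (mu, sigma)); pass plain tensors (logits) through as-is.
    if isinstance(actor_output, tuple):
        return dist_fn(*actor_output)
    return dist_fn(actor_output)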

View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from typing import Any, Literal
import gymnasium as gym
@@ -10,7 +9,7 @@ from tianshou.data import ReplayBuffer, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.net.common import ActorCritic
@@ -58,7 +57,7 @@ class PPOPolicy(A2CPolicy):
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
eps_clip: float = 0.2,
dual_clip: float | None = None,

View File

@@ -1,5 +1,4 @@
import warnings
from collections.abc import Callable
from typing import Any, Literal
import gymnasium as gym
@@ -10,7 +9,7 @@ from torch.distributions import kl_divergence
from tianshou.data import Batch
from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.policy.modelfree.pg import TDistributionFunction
class TRPOPolicy(NPGPolicy):
@@ -47,7 +46,7 @@ class TRPOPolicy(NPGPolicy):
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
dist_fn: TDistributionFunction,
action_space: gym.Space,
max_kl: float = 0.01,
backtrack_coeff: float = 0.8,
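
All six policy constructors above now share the same dist_fn annotation. For completeness, here is a hedged sketch of a continuous-control dist_fn, the case the old single-argument TDistParams annotation handled awkwardly; the function and values are illustrative and not part of this commit.

import torch

def normal_dist_fn(mu: torch.Tensor, sigma: torch.Tensor) -> torch.distributions.Distribution:
    # Wrap Normal in Independent so log_prob reduces over the action dimension,
    # yielding one log-probability per batch element.
    return torch.distributions.Independent(torch.distributions.Normal(mu, sigma), 1)

mu = torch.zeros(1, 3)
sigma = torch.ones(1, 3)
dist = normal_dist_fn(mu, sigma)
action = dist.sample()            # shape (1, 3)
log_prob = dist.log_prob(action)  # shape (1,)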