Fix type annotations of dist_fn
This commit is contained in:
parent a161a9cf58
commit 22dfc4ed2e
@@ -1,15 +1,11 @@
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from typing import TypeAlias
 
 import torch
 
 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.utils.string import ToStringMixin
 
-TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
-
 
 class DistributionFunctionFactory(ToStringMixin, ABC):
     @abstractmethod
@@ -23,7 +19,7 @@ class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
         return self._dist_fn
 
     @staticmethod
-    def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
+    def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
         return torch.distributions.Categorical(logits=p)
 
 
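For comparison, a dist_fn for continuous actions under the same updated signature would take the actor's raw output tensors rather than a TDistParams object. A minimal illustrative sketch; gaussian_dist_fn and the (loc, scale) convention are assumptions for illustration, not part of this commit:

import torch


def gaussian_dist_fn(loc: torch.Tensor, scale: torch.Tensor) -> torch.distributions.Distribution:
    # Illustration only: build the action distribution directly from the
    # tensors produced by the actor, mirroring the categorical case above.
    return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)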
@@ -14,8 +14,8 @@ from tianshou.highlevel.params.alpha import AutoAlphaFactory
 from tianshou.highlevel.params.dist_fn import (
     DistributionFunctionFactory,
     DistributionFunctionFactoryDefault,
-    TDistributionFunction,
 )
+from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
@@ -10,7 +10,7 @@ from tianshou.data import ReplayBuffer, to_numpy, to_torch
 from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
 from tianshou.policy import PPOPolicy
 from tianshou.policy.base import TLearningRateScheduler
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 
 
 class GAILPolicy(PPOPolicy):
@@ -62,7 +62,7 @@ class GAILPolicy(PPOPolicy):
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         expert_buffer: ReplayBuffer,
         disc_net: torch.nn.Module,
@@ -11,7 +11,7 @@ from tianshou.data import ReplayBuffer, to_torch_as
 from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
 from tianshou.policy import PGPolicy
 from tianshou.policy.base import TLearningRateScheduler
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.utils.net.common import ActorCritic
 
 
@@ -50,7 +50,7 @@ class A2CPolicy(PGPolicy):
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from typing import Any, Literal
 
 import gymnasium as gym
@@ -12,7 +11,7 @@ from tianshou.data import Batch, ReplayBuffer
 from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
 from tianshou.policy import A2CPolicy
 from tianshou.policy.base import TLearningRateScheduler
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 
 
 class NPGPolicy(A2CPolicy):
@@ -47,7 +46,7 @@ class NPGPolicy(A2CPolicy):
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         optim_critic_iters: int = 5,
         actor_step_size: float = 0.5,
@@ -1,6 +1,5 @@
 import warnings
-from collections.abc import Callable
-from typing import Any, Literal, TypeAlias, cast
+from typing import Any, Literal, cast, TypeAlias, Callable
 
 import gymnasium as gym
 import numpy as np
@@ -17,7 +16,7 @@ from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler
 from tianshou.utils import RunningMeanStd
 
-TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution]
 
 
 class PGPolicy(BasePolicy):
@@ -56,7 +55,7 @@ class PGPolicy(BasePolicy):
         *,
         actor: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         discount_factor: float = 0.99,
         # TODO: rename to return_normalization?
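Since the alias now lives in tianshou.policy.modelfree.pg and is written over plain tensors, concrete dist_fn callables can be annotated with it directly. A rough sketch, assuming a single-logits-tensor discrete case and a (loc, scale) continuous case; the names below are placeholders, not from this commit:

import torch

from tianshou.policy.modelfree.pg import TDistributionFunction

# Discrete actions: one logits tensor in, Categorical out.
categorical_dist: TDistributionFunction = lambda logits: torch.distributions.Categorical(logits=logits)

# Continuous actions: (loc, scale) tensors in, Normal out.
normal_dist: TDistributionFunction = lambda loc, scale: torch.distributions.Normal(loc, scale)

Either callable then matches the dist_fn parameter of PGPolicy and of the subclasses updated below.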
@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from typing import Any, Literal
 
 import gymnasium as gym
@@ -10,7 +9,7 @@ from tianshou.data import ReplayBuffer, to_torch_as
 from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
 from tianshou.policy import A2CPolicy
 from tianshou.policy.base import TLearningRateScheduler
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.utils.net.common import ActorCritic
 
 
@@ -58,7 +57,7 @@ class PPOPolicy(A2CPolicy):
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         eps_clip: float = 0.2,
         dual_clip: float | None = None,
@@ -1,5 +1,4 @@
 import warnings
-from collections.abc import Callable
 from typing import Any, Literal
 
 import gymnasium as gym
@@ -10,7 +9,7 @@ from torch.distributions import kl_divergence
 from tianshou.data import Batch
 from tianshou.policy import NPGPolicy
 from tianshou.policy.base import TLearningRateScheduler
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy.modelfree.pg import TDistributionFunction
 
 
 class TRPOPolicy(NPGPolicy):
@@ -47,7 +46,7 @@ class TRPOPolicy(NPGPolicy):
         actor: torch.nn.Module,
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        dist_fn: TDistributionFunction,
         action_space: gym.Space,
         max_kl: float = 0.01,
         backtrack_coeff: float = 0.8,