implement REDQ based on original contribution by @Jimenius (#623)
Co-authored-by: Minhui Li <limh@lamda.nju.edu.cn>
This commit is contained in:
parent 41afc2584a
commit dd16818ce4
README.md
@@ -34,6 +34,7 @@
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
+- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf)
 - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
 - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
docs/api/tianshou.policy.rst
@@ -96,6 +96,11 @@ Off-policy
    :undoc-members:
    :show-inheritance:

+.. autoclass:: tianshou.policy.REDQPolicy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. autoclass:: tianshou.policy.DiscreteSACPolicy
    :members:
    :undoc-members:
docs/index.rst
@@ -25,6 +25,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
+* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning <https://arxiv.org/pdf/2101.05982.pdf>`_
 * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
docs/spelling_wordlist.txt
@@ -157,3 +157,4 @@ Nvidia
 Enduro
 Qbert
 Seaquest
+subnets
examples/mujoco/mujoco_redq.py (new executable file, 192 lines)
@@ -0,0 +1,192 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import REDQPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Ant-v3')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=1000000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
    parser.add_argument('--ensemble-size', type=int, default=10)
    parser.add_argument('--subset-size', type=int, default=2)
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--auto-alpha', default=False, action='store_true')
    parser.add_argument('--alpha-lr', type=float, default=3e-4)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--step-per-collect', type=int, default=1)
    parser.add_argument('--update-per-step', type=int, default=20)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument(
        '--target-mode', type=str, choices=('min', 'mean'), default='min'
    )
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    return parser.parse_args()


def test_redq(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    if args.training_num > 1:
        train_envs = SubprocVectorEnv(
            [lambda: gym.make(args.task) for _ in range(args.training_num)]
        )
    else:
        train_envs = gym.make(args.task)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    def linear(x, y):
        return EnsembleLinear(args.ensemble_size, x, y)

    net_c = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
        linear_layer=linear,
    )
    critics = Critic(
        net_c,
        device=args.device,
        linear_layer=linear,
        flatten_input=False,
    ).to(args.device)
    critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = REDQPolicy(
        actor,
        actor_optim,
        critics,
        critics_optim,
        args.ensemble_size,
        args.subset_size,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        actor_delay=args.update_per_step,
        target_mode=args.target_mode,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_redq'
    log_path = os.path.join(args.logdir, args.task, 'redq', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = offpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.step_per_collect,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_redq()
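Note (illustrative sketch, not part of the commit): the `linear` closure above swaps every nn.Linear inside Net and Critic for an EnsembleLinear, and `flatten_input=False` keeps the ensemble dimension intact, so a single Critic forward pass returns one Q-value per ensemble member. A minimal shape check under those assumptions (dimensions are made up):

    import torch
    from tianshou.utils.net.common import EnsembleLinear, Net
    from tianshou.utils.net.continuous import Critic

    ensemble_size, obs_dim, act_dim = 10, 17, 6

    def linear(x, y):
        return EnsembleLinear(ensemble_size, x, y)

    net_c = Net(obs_dim, act_dim, hidden_sizes=[256, 256], concat=True,
                linear_layer=linear)
    critics = Critic(net_c, linear_layer=linear, flatten_input=False)
    obs, act = torch.rand(32, obs_dim), torch.rand(32, act_dim)
    print(critics(obs, act).shape)  # expected: torch.Size([10, 32, 1])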
							
								
								
									
test/continuous/test_redq.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import REDQPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v1')
    parser.add_argument('--reward-threshold', type=float, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--ensemble-size', type=int, default=4)
    parser.add_argument('--subset-size', type=int, default=2)
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--auto-alpha', action='store_true', default=False)
    parser.add_argument('--alpha-lr', type=float, default=3e-4)
    parser.add_argument("--start-timesteps", type=int, default=1000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--step-per-collect', type=int, default=1)
    parser.add_argument('--update-per-step', type=int, default=3)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--target-mode', type=str, choices=('min', 'mean'), default='min'
    )
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args


def test_redq(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    def linear(x, y):
        return EnsembleLinear(args.ensemble_size, x, y)

    net_c = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
        linear_layer=linear,
    )
    critic = Critic(
        net_c, device=args.device, linear_layer=linear, flatten_input=False
    ).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = REDQPolicy(
        actor,
        actor_optim,
        critic,
        critic_optim,
        args.ensemble_size,
        args.subset_size,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        actor_delay=args.update_per_step,
        target_mode=args.target_mode,
        action_space=env.action_space,
    )
    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True
    )
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'redq')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_redq()
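Note (illustrative, not part of the commit): with --auto-alpha, this test and the mujoco example replace the fixed entropy coefficient by the (target_entropy, log_alpha, alpha_optim) tuple that REDQPolicy's docstring describes, after which alpha is tuned inside learn(). Extracted as a standalone sketch, with a hypothetical action shape:

    import numpy as np
    import torch

    action_shape = (6,)  # hypothetical; env.action_space.shape in the scripts
    target_entropy = -np.prod(action_shape)
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    alpha = (target_entropy, log_alpha, alpha_optim)  # REDQPolicy(..., alpha=alpha)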
							
								
								
									
tianshou/env/pettingzoo_env.py (vendored, 9 lines changed)
@@ -55,8 +55,8 @@ class PettingZooEnv(AECEnv, ABC):
         self.reset()

-    def reset(self) -> dict:
-        self.env.reset()
+    def reset(self, *args: Any, **kwargs: Any) -> dict:
+        self.env.reset(*args, **kwargs)
         observation = self.env.observe(self.env.agent_selection)
         if isinstance(observation, dict) and 'action_mask' in observation:
             return {

@@ -103,7 +103,10 @@ class PettingZooEnv(AECEnv, ABC):
         self.env.close()

     def seed(self, seed: Any = None) -> None:
-        self.env.seed(seed)
+        try:
+            self.env.seed(seed)
+        except NotImplementedError:
+            self.env.reset(seed=seed)

     def render(self, mode: str = "human") -> Any:
         return self.env.render(mode)
tianshou/policy/__init__.py
@@ -17,6 +17,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
 from tianshou.policy.modelfree.trpo import TRPOPolicy
 from tianshou.policy.modelfree.td3 import TD3Policy
 from tianshou.policy.modelfree.sac import SACPolicy
+from tianshou.policy.modelfree.redq import REDQPolicy
 from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
 from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.imitation.bcq import BCQPolicy
@@ -46,6 +47,7 @@ __all__ = [
     "TRPOPolicy",
     "TD3Policy",
     "SACPolicy",
+    "REDQPolicy",
     "DiscreteSACPolicy",
     "ImitationPolicy",
     "BCQPolicy",
							
								
								
									
tianshou/policy/modelfree/redq.py (new file, 200 lines)
@@ -0,0 +1,200 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch.distributions import Independent, Normal

from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy


class REDQPolicy(DDPGPolicy):
    """Implementation of REDQ. arXiv:2101.05982.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critics: critic ensemble networks.
    :param torch.optim.Optimizer critics_optim: the optimizer for the critic networks.
    :param int ensemble_size: number of sub-networks in the critic ensemble.
        Default to 10.
    :param int subset_size: number of networks in the subset. Default to 2.
    :param float tau: param for soft update of the target network. Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int actor_delay: number of critic updates before an actor update.
        Default to 20.
    :param BaseNoise exploration_noise: add noise to actions for exploration.
        Default to None. This is useful when solving hard-exploration problems.
    :param bool deterministic_eval: whether to use deterministic action (mean of
        the Gaussian policy) instead of a stochastic action sampled by the policy.
        Default to True.
    :param str target_mode: method to integrate critic values in the subset;
        currently supports minimum ("min") and average ("mean"). Default to "min".
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action) or empty string for no bounding.
        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critics: torch.nn.Module,
        critics_optim: torch.optim.Optimizer,
        ensemble_size: int = 10,
        subset_size: int = 2,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        estimation_step: int = 1,
        actor_delay: int = 20,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        target_mode: str = "min",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            None, None, None, None, tau, gamma, exploration_noise,
            reward_normalization, estimation_step, **kwargs
        )
        self.actor, self.actor_optim = actor, actor_optim
        self.critics, self.critics_old = critics, deepcopy(critics)
        self.critics_old.eval()
        self.critics_optim = critics_optim
        assert 0 < subset_size <= ensemble_size, \
            "Invalid choice of ensemble size or subset size."
        self.ensemble_size = ensemble_size
        self.subset_size = subset_size

        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        if target_mode in ("min", "mean"):
            self.target_mode = target_mode
        else:
            raise ValueError("Unsupported mode of Q target computing.")

        self.critic_gradient_step = 0
        self.actor_delay = actor_delay
        self._deterministic_eval = deterministic_eval
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self, mode: bool = True) -> "REDQPolicy":
        self.training = mode
        self.actor.train(mode)
        self.critics.train(mode)
        return self

    def sync_weight(self) -> None:
        for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
            o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau)

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            act = logits[0]
        else:
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
        # apply correction for tanh squashing when computing logprob from Gaussian
        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
        # in appendix C to get some understanding of this equation.
        squashed_action = torch.tanh(act)
        log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
                                        self.__eps).sum(-1, keepdim=True)
        return Batch(
            logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
        )

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs: s_{t+n}
        obs_next_result = self(batch, input="obs_next")
        a_ = obs_next_result.act
        sample_ensemble_idx = np.random.choice(
            self.ensemble_size, self.subset_size, replace=False
        )
        qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
        if self.target_mode == "min":
            target_q, _ = torch.min(qs, dim=0)
        elif self.target_mode == "mean":
            target_q = torch.mean(qs, dim=0)
        target_q -= self._alpha * obs_next_result.log_prob

        return target_q

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic ensemble
        weight = getattr(batch, "weight", 1.0)
        current_qs = self.critics(batch.obs, batch.act).flatten(1)
        target_q = batch.returns.flatten()
        td = current_qs - target_q
        critic_loss = (td.pow(2) * weight).mean()
        self.critics_optim.zero_grad()
        critic_loss.backward()
        self.critics_optim.step()
        batch.weight = torch.mean(td, dim=0)  # prio-buffer
        self.critic_gradient_step += 1

        # actor
        if self.critic_gradient_step % self.actor_delay == 0:
            obs_result = self(batch)
            a = obs_result.act
            current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
            actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                          current_qa).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            if self._is_auto_alpha:
                log_prob = obs_result.log_prob.detach() + self._target_entropy
                alpha_loss = -(self._log_alpha * log_prob).mean()
                self._alpha_optim.zero_grad()
                alpha_loss.backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {"loss/critics": critic_loss.item()}
        if self.critic_gradient_step % self.actor_delay == 0:
            result["loss/actor"] = actor_loss.item()
            if self._is_auto_alpha:
                result["loss/alpha"] = alpha_loss.item()
                result["alpha"] = self._alpha.item()  # type: ignore

        return result
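Note (illustrative sketch, not part of the commit): the core REDQ trick in _target_q above is to draw a random subset of the critic ensemble per update and aggregate its Q-values. A minimal standalone rendering of that step, assuming qs has shape (ensemble_size, batch_size, 1) as produced by the ensemble critic:

    import numpy as np
    import torch

    ensemble_size, subset_size, batch_size = 10, 2, 4
    qs = torch.randn(ensemble_size, batch_size, 1)  # stand-in for critics_old output
    idx = np.random.choice(ensemble_size, subset_size, replace=False)
    target_min, _ = torch.min(qs[idx, ...], dim=0)   # target_mode="min" (default)
    target_mean = torch.mean(qs[idx, ...], dim=0)    # target_mode="mean"
    assert target_min.shape == (batch_size, 1)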
tianshou/utils/net/common.py

@@ -1,4 +1,14 @@
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    no_type_check,
+)

 import numpy as np
 import torch
@@ -46,6 +56,7 @@ class MLP(nn.Module):
         nn.ReLU.
     :param device: which device to create this model on. Default to None.
     :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param bool flatten_input: whether to flatten input data. Default to True.
     """

     def __init__(
@@ -57,6 +68,7 @@ class MLP(nn.Module):
         activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
         device: Optional[Union[str, int, torch.device]] = None,
         linear_layer: Type[nn.Linear] = nn.Linear,
+        flatten_input: bool = True,
     ) -> None:
         super().__init__()
         self.device = device
@@ -86,15 +98,15 @@ class MLP(nn.Module):
             model += [linear_layer(hidden_sizes[-1], output_dim)]
         self.output_dim = output_dim or hidden_sizes[-1]
         self.model = nn.Sequential(*model)
+        self.flatten_input = flatten_input

     @no_type_check
     def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
         if self.device is not None:
-            obs = torch.as_tensor(
-                obs,
-                device=self.device,  # type: ignore
-                dtype=torch.float32,
-            )
-        return self.model(obs.flatten(1))  # type: ignore
+            obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
+        if self.flatten_input:
+            obs = obs.flatten(1)
+        return self.model(obs)


 class Net(nn.Module):
@@ -129,6 +141,7 @@ class Net(nn.Module):
         pass a tuple of two dict (first for Q and second for V) stating
         self-defined arguments as stated in
         :class:`~tianshou.utils.net.common.MLP`. Default to None.
+    :param linear_layer: use this module as linear layer. Default to nn.Linear.

     .. seealso::
@@ -152,6 +165,7 @@ class Net(nn.Module):
         concat: bool = False,
         num_atoms: int = 1,
         dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
+        linear_layer: Type[nn.Linear] = nn.Linear,
     ) -> None:
         super().__init__()
         self.device = device
@@ -164,7 +178,8 @@ class Net(nn.Module):
         self.use_dueling = dueling_param is not None
         output_dim = action_dim if not self.use_dueling and not concat else 0
         self.model = MLP(
-            input_dim, output_dim, hidden_sizes, norm_layer, activation, device
+            input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
+            linear_layer
         )
         self.output_dim = self.model.output_dim
         if self.use_dueling:  # dueling DQN
@@ -311,3 +326,40 @@ class DataParallelNet(nn.Module):
         if not isinstance(obs, torch.Tensor):
             obs = torch.as_tensor(obs, dtype=torch.float32)
         return self.net(obs=obs.cuda(), *args, **kwargs)
+
+
+class EnsembleLinear(nn.Module):
+    """Linear Layer of Ensemble network.
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param int in_feature: dimension of the input vector.
+    :param int out_feature: dimension of the output vector.
+    :param bool bias: whether to include an additive bias. Default to True.
+    """
+
+    def __init__(
+        self,
+        ensemble_size: int,
+        in_feature: int,
+        out_feature: int,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # To be consistent with PyTorch default initializer
+        k = np.sqrt(1. / in_feature)
+        weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k
+        self.weight = nn.Parameter(weight_data, requires_grad=True)
+
+        self.bias: Union[nn.Parameter, None]
+        if bias:
+            bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k
+            self.bias = nn.Parameter(bias_data, requires_grad=True)
+        else:
+            self.bias = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.matmul(x, self.weight)
+        if self.bias is not None:
+            x = x + self.bias
+        return x
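Note (illustrative, not part of the commit): EnsembleLinear keeps one weight matrix per subnet, shape (ensemble_size, in_feature, out_feature), and its forward is a broadcasting batched matmul, so a shared 2-D input fans out to every subnet while a 3-D input stays per-subnet:

    import torch
    from tianshou.utils.net.common import EnsembleLinear

    layer = EnsembleLinear(ensemble_size=10, in_feature=5, out_feature=3)
    x2d = torch.rand(32, 5)        # (batch, in_feature), shared across subnets
    print(layer(x2d).shape)        # torch.Size([10, 32, 3])
    x3d = torch.rand(10, 32, 5)    # already one input per subnet
    print(layer(x3d).shape)        # torch.Size([10, 32, 3])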
tianshou/utils/net/continuous.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
 import torch
@@ -79,6 +79,9 @@ class Critic(nn.Module):
         only a single linear layer).
     :param int preprocess_net_output_dim: the output dimension of
         preprocess_net.
+    :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param bool flatten_input: whether to flatten input data for the last layer.
+        Default to True.

     For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
@@ -95,6 +98,8 @@ class Critic(nn.Module):
         hidden_sizes: Sequence[int] = (),
         device: Union[str, int, torch.device] = "cpu",
         preprocess_net_output_dim: Optional[int] = None,
+        linear_layer: Type[nn.Linear] = nn.Linear,
+        flatten_input: bool = True,
     ) -> None:
         super().__init__()
         self.device = device
@@ -105,7 +110,9 @@ class Critic(nn.Module):
             input_dim,  # type: ignore
             1,
             hidden_sizes,
-            device=self.device
+            device=self.device,
+            linear_layer=linear_layer,
+            flatten_input=flatten_input,
         )

     def forward(