parent ad395b5235
commit 99a1d40e85
@@ -20,7 +20,8 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
+- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -11,6 +11,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
+* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@@ -193,7 +193,7 @@ The explanation of each Tianshou class/function will be deferred to their first
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.1,
+    parser.add_argument('--gamma', type=float, default=0.9,
                         help='a smaller gamma favors earlier win')
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
examples/acrobot_dualdqn.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Acrobot-v1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.5)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=0)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_dqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape,
              args.action_shape, args.device, dueling=(2, 2)).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    def train_fn(x):
        if x <= int(0.1 * args.epoch):
            policy.set_eps(args.eps_train)
        elif x <= int(0.5 * args.epoch):
            eps = args.eps_train - (x - 0.1 * args.epoch) / \
                (0.4 * args.epoch) * (0.5 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.5 * args.eps_train)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
    test_dqn(get_args())
@@ -58,7 +58,8 @@ def test_dqn(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device).to(args.device)
+              args.action_shape, args.device,
+              dueling=(2, 2)).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
@@ -80,7 +81,15 @@ def test_dqn(args=get_args()):
         return x >= env.spec.reward_threshold

     def train_fn(x):
-        policy.set_eps(args.eps_train)
+        # eps annealing, just a demo
+        if x <= int(0.1 * args.epoch):
+            policy.set_eps(args.eps_train)
+        elif x <= int(0.5 * args.epoch):
+            eps = args.eps_train - (x - 0.1 * args.epoch) / \
+                (0.4 * args.epoch) * (0.9 * args.eps_train)
+            policy.set_eps(eps)
+        else:
+            policy.set_eps(0.1 * args.eps_train)

     def test_fn(x):
         policy.set_eps(args.eps_test)
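A minimal standalone sketch of the schedule encoded by the new train_fn above; the function name anneal_eps and the sample numbers are illustrative only, not part of the commit. The exploration rate stays at eps_train for the first 10% of epochs, decays linearly to 0.1 * eps_train by 50% of epochs, and is held there afterwards.

def anneal_eps(x, epoch=100, eps_train=0.1):
    # x is the current epoch index, as the trainer passes it to train_fn
    if x <= int(0.1 * epoch):
        return eps_train
    elif x <= int(0.5 * epoch):
        return eps_train - (x - 0.1 * epoch) / (0.4 * epoch) * (0.9 * eps_train)
    return 0.1 * eps_train

# e.g. with epoch=100, eps_train=0.1:
#   anneal_eps(10) -> 0.1, anneal_eps(30) -> 0.055, anneal_eps(50) -> 0.01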
@@ -23,7 +23,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.1,
+    parser.add_argument('--gamma', type=float, default=0.9,
                         help='a smaller gamma favors earlier win')
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
@@ -38,7 +38,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--render', type=float, default=0.1)
     parser.add_argument('--board_size', type=int, default=6)
     parser.add_argument('--win_size', type=int, default=4)
-    parser.add_argument('--win-rate', type=float, default=0.8,
+    parser.add_argument('--win_rate', type=float, default=0.9,
                         help='the expected winning rate')
     parser.add_argument('--watch', default=False, action='store_true',
                         help='no training, '
@@ -11,8 +11,12 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \

 class DQNPolicy(BasePolicy):
     """Implementation of Deep Q Network. arXiv:1312.5602

     Implementation of Double Q-Learning. arXiv:1509.06461
+
+    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
+    implemented in the network side, not here)
+
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
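A minimal sketch of what "implemented in the network side" means in practice, mirroring the new acrobot example in this commit; the shapes (Acrobot-v1 has a 6-dimensional observation and 3 discrete actions) and hyper-parameters are illustrative. DQNPolicy itself is untouched; only the network handed to it changes.

import torch
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

# dueling=(2, 2): two hidden layers in the Q head and two in the V head
net = Net(layer_num=1, state_shape=(6,), action_shape=3, dueling=(2, 2))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, 0.95, 3, target_update_freq=320)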
@@ -1,36 +1,77 @@
-import numpy as np
 import torch
+import numpy as np
 from torch import nn
+from typing import Tuple, Union, Optional

 from tianshou.data import to_torch


+def miniblock(inp: int, oup: int, norm_layer: nn.modules.Module):
+    ret = [nn.Linear(inp, oup)]
+    if norm_layer is not None:
+        ret += [norm_layer(oup)]
+    ret += [nn.ReLU(inplace=True)]
+    return ret
+
+
 class Net(nn.Module):
     """Simple MLP backbone. For advanced usage (how to customize the network),
     please refer to :ref:`build_the_network`.

-    :param concat: whether the input shape is concatenated by state_shape
+    :param bool concat: whether the input shape is concatenated by state_shape
         and action_shape. If it is True, ``action_shape`` is not the output
         shape, but affects the input shape.
+    :param bool dueling: whether to use dueling network to calculate Q values
+        (for Dueling DQN), defaults to False.
+    :param nn.modules.Module norm_layer: use which normalization before ReLU,
+        e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
     """

-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
-                 softmax=False, concat=False, hidden_layer_size=128):
+    def __init__(self, layer_num: int, state_shape: tuple,
+                 action_shape: Optional[tuple] = 0,
+                 device: Union[str, torch.device] = 'cpu',
+                 softmax: bool = False,
+                 concat: bool = False,
+                 hidden_layer_size: int = 128,
+                 dueling: Optional[Tuple[int, int]] = None,
+                 norm_layer: Optional[nn.modules.Module] = None):
         super().__init__()
         self.device = device
+        self.dueling = dueling
+        self.softmax = softmax
         input_size = np.prod(state_shape)
         if concat:
             input_size += np.prod(action_shape)
-        self.model = [
-            nn.Linear(input_size, hidden_layer_size),
-            nn.ReLU(inplace=True)]
+        self.model = miniblock(input_size, hidden_layer_size, norm_layer)
         for i in range(layer_num):
-            self.model += [nn.Linear(hidden_layer_size, hidden_layer_size),
-                           nn.ReLU(inplace=True)]
-        if action_shape and not concat:
-            self.model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
-        if softmax:
-            self.model += [nn.Softmax(dim=-1)]
+            self.model += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+        if self.dueling is None:
+            if action_shape and not concat:
+                self.model += [nn.Linear(hidden_layer_size,
+                                         np.prod(action_shape))]
+        else:  # dueling DQN
+            assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
+
+            q_layer_num, v_layer_num = self.dueling
+            self.Q, self.V = [], []
+
+            for i in range(q_layer_num):
+                self.Q += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+            for i in range(v_layer_num):
+                self.V += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+
+            if action_shape and not concat:
+                self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
+                self.V += [nn.Linear(hidden_layer_size, 1)]
+
+            self.Q = nn.Sequential(*self.Q)
+            self.V = nn.Sequential(*self.V)
         self.model = nn.Sequential(*self.model)

     def forward(self, s, state=None, info={}):
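A small usage sketch of the two new constructor options (layer counts and shapes here are illustrative, not from the commit): norm_layer inserts the given normalization module before each ReLU via miniblock, and dueling=(q_layers, v_layers) builds separate Q and V heads on top of the shared trunk.

import torch
from torch import nn
from tianshou.utils.net.common import Net

net_norm = Net(2, (4,), 2, norm_layer=nn.LayerNorm)  # LayerNorm before each ReLU
net_duel = Net(2, (4,), 2, dueling=(2, 2))           # dueling Q/V heads
logits, _ = net_duel(torch.zeros(8, 4))
print(logits.shape)  # torch.Size([8, 2])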
@@ -38,6 +79,11 @@ class Net(nn.Module):
         s = to_torch(s, device=self.device, dtype=torch.float32)
         s = s.reshape(s.size(0), -1)
         logits = self.model(s)
+        if self.dueling is not None:  # Dueling DQN
+            q, v = self.Q(logits), self.V(logits)
+            logits = q - q.mean(dim=1, keepdim=True) + v
+        if self.softmax:
+            logits = torch.softmax(logits, dim=-1)
         return logits, state

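For reference, the aggregation added above is the mean-subtracted dueling combination from arXiv:1511.06581, Q(s, a) = V(s) + A(s, a) - (1/|A|) * sum over a' of A(s, a'), with the Q stream playing the advantage role; e.g. per-action outputs [1, 2, 3] combined with v = 10 yield [9, 11, 11 - 1] = [9, 10, 11]. Subtracting the mean keeps the value/advantage decomposition identifiable without changing the greedy action.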
@@ -41,25 +41,33 @@ class Critic(nn.Module):
 class DQN(nn.Module):
     """For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
+
+    Reference paper: "Human-level control through deep reinforcement learning".
     """

     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
         self.device = device

-        self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-        self.bn3 = nn.BatchNorm2d(32)
+        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
+        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

         def conv2d_size_out(size, kernel_size=5, stride=2):
             return (size - (kernel_size - 1) - 1) // stride + 1

-        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-        linear_input_size = convw * convh * 32
+        def conv2d_layers_size_out(size,
+                                   kernel_size_1=8, stride_1=4,
+                                   kernel_size_2=4, stride_2=2,
+                                   kernel_size_3=3, stride_3=1):
+            size = conv2d_size_out(size, kernel_size_1, stride_1)
+            size = conv2d_size_out(size, kernel_size_2, stride_2)
+            size = conv2d_size_out(size, kernel_size_3, stride_3)
+            return size
+
+        convw = conv2d_layers_size_out(w)
+        convh = conv2d_layers_size_out(h)
+        linear_input_size = convw * convh * 64
         self.fc = nn.Linear(linear_input_size, 512)
         self.head = nn.Linear(512, action_shape)

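As a worked size check (assuming the usual 84x84 Atari input, which this diff does not fix): conv2d_size_out(84, 8, 4) = 20, conv2d_size_out(20, 4, 2) = 9 and conv2d_size_out(9, 3, 1) = 7, so linear_input_size = 7 * 7 * 64 = 3136, the standard Nature-DQN feature size feeding the 512-unit fully connected layer.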
@@ -68,8 +76,8 @@ class DQN(nn.Module):
         if not isinstance(x, torch.Tensor):
             x = torch.tensor(x, device=self.device, dtype=torch.float32)
         x = x.permute(0, 3, 1, 2)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = F.relu(self.bn3(self.conv3(x)))
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
         x = self.fc(x.reshape(x.size(0), -1))
         return self.head(x), state