fix docs build failure and a bug in a2c/ppo optimizer (#428)
* fix rtfd build
* list + list -> set.union
* change seed of test_qrdqn
* add py39 test
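The optimizer fix targets the case where the actor and the critic wrap the same preprocessing network (as in the discrete tests below): concatenating the two parameter lists hands the shared tensors to `torch.optim.Adam` twice, while the set union keeps exactly one copy of each. A minimal sketch of the difference, with illustrative layer shapes that are not taken from the repository:

```python
import torch
import torch.nn as nn

# Stand-ins for a shared preprocessing net plus actor/critic heads;
# the layer sizes here are illustrative only.
shared = nn.Linear(4, 16)
actor = nn.Sequential(shared, nn.Linear(16, 2))
critic = nn.Sequential(shared, nn.Linear(16, 1))

dup = list(actor.parameters()) + list(critic.parameters())
uniq = set(actor.parameters()).union(critic.parameters())

print(len(dup))   # 8 -- the shared weight and bias are listed twice
print(len(uniq))  # 6 -- each tensor appears exactly once

# Feeding the deduplicated collection to the optimizer avoids registering
# the shared tensors twice in the same parameter group.
optim = torch.optim.Adam(uniq, lr=1e-3)
```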
parent 291be08d43
commit e4f4f0e144

.github/workflows/pytest.yml
@@ -8,7 +8,7 @@ jobs:
     if: "!contains(github.event.head_commit.message, 'ci skip')"
     strategy:
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.6, 3.7, 3.8, 3.9]
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}
@@ -1,7 +1,8 @@
 gym
-tqdm
-torch
 numba
-tensorboard
+numpy>=1.20
 sphinx<4
 sphinxcontrib-bibtex
+tensorboard
+torch
+tqdm
@@ -81,12 +81,12 @@ def test_ppo(args=get_args()):
         args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
     ), device=args.device).to(args.device)
     # orthogonal initialization
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in set(actor.modules()).union(critic.modules()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
+        set(actor.parameters()).union(critic.parameters()), lr=args.lr)
 
     # replace DiagGuassian with Independent(Normal) which is equivalent
     # pass *logits to be consistent with policy.forward
@@ -75,7 +75,7 @@ def test_a2c_with_il(args=get_args()):
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
     optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
+        set(actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor, critic, optim, dist,
@@ -72,12 +72,12 @@ def test_ppo(args=get_args()):
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
     # orthogonal initialization
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in set(actor.modules()).union(critic.modules()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
+        set(actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor, critic, optim, dist,
@@ -18,7 +18,7 @@ from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplay
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -129,7 +129,7 @@ class A2CPolicy(PGPolicy):
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        list(self.actor.parameters()) + list(self.critic.parameters()),
+                        set(self.actor.parameters()).union(self.critic.parameters()),
                         max_norm=self._grad_norm)
                 self.optim.step()
                 actor_losses.append(actor_loss.item())
@@ -139,7 +139,7 @@ class PPOPolicy(A2CPolicy):
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        list(self.actor.parameters()) + list(self.critic.parameters()),
+                        set(self.actor.parameters()).union(self.critic.parameters()),
                         max_norm=self._grad_norm)
                 self.optim.step()
                 clip_losses.append(clip_loss.item())
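The same substitution is used in the gradient-clipping calls above; `torch.nn.utils.clip_grad_norm_` accepts any iterable of tensors, so the deduplicated set can be passed directly. A small sketch with placeholder modules (not the repository's actor/critic classes):

```python
import torch
import torch.nn as nn

# Placeholder modules standing in for self.actor and self.critic.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

x = torch.randn(8, 4)
loss = actor(x).sum() + critic(x).sum()
loss.backward()

# clip_grad_norm_ takes an iterable of tensors; passing the union avoids
# counting a shared tensor's gradient twice in the total norm.
nn.utils.clip_grad_norm_(
    set(actor.parameters()).union(critic.parameters()), max_norm=0.5)
```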