fix docs build failure and a bug in a2c/ppo optimizer (#428)
* fix rtfd build * list + list -> set.union * change seed of test_qrdqn * add py39 test
This commit is contained in:
parent
291be08d43
commit
e4f4f0e144
2
.github/workflows/pytest.yml
vendored
2
.github/workflows/pytest.yml
vendored
@ -8,7 +8,7 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
@ -1,7 +1,8 @@
|
||||
gym
|
||||
tqdm
|
||||
torch
|
||||
numba
|
||||
tensorboard
|
||||
numpy>=1.20
|
||||
sphinx<4
|
||||
sphinxcontrib-bibtex
|
||||
tensorboard
|
||||
torch
|
||||
tqdm
|
||||
|
@ -81,12 +81,12 @@ def test_ppo(args=get_args()):
|
||||
args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
|
||||
), device=args.device).to(args.device)
|
||||
# orthogonal initialization
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
for m in set(actor.modules()).union(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
|
@ -75,7 +75,7 @@ def test_a2c_with_il(args=get_args()):
|
||||
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||
critic = Critic(net, device=args.device).to(args.device)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = A2CPolicy(
|
||||
actor, critic, optim, dist,
|
||||
|
@ -72,12 +72,12 @@ def test_ppo(args=get_args()):
|
||||
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||
critic = Critic(net, device=args.device).to(args.device)
|
||||
# orthogonal initialization
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
for m in set(actor.modules()).union(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = PPOPolicy(
|
||||
actor, critic, optim, dist,
|
||||
|
@ -18,7 +18,7 @@ from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplay
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
|
@ -129,7 +129,7 @@ class A2CPolicy(PGPolicy):
|
||||
loss.backward()
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
set(self.actor.parameters()).union(self.critic.parameters()),
|
||||
max_norm=self._grad_norm)
|
||||
self.optim.step()
|
||||
actor_losses.append(actor_loss.item())
|
||||
|
@ -139,7 +139,7 @@ class PPOPolicy(A2CPolicy):
|
||||
loss.backward()
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
set(self.actor.parameters()).union(self.critic.parameters()),
|
||||
max_norm=self._grad_norm)
|
||||
self.optim.step()
|
||||
clip_losses.append(clip_loss.item())
|
||||
|
Loading…
x
Reference in New Issue
Block a user