hotfix: fix test failure in CUDA environment (#289)

This commit is contained in:
ChenDRAG 2021-02-09 17:13:40 +08:00 committed by GitHub
parent e3ee415b1a
commit f528131da1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 30 additions and 18 deletions

View File

@@ -65,8 +65,8 @@ def test_a2c(args=get_args()):
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical

View File

@@ -65,8 +65,8 @@ def test_ppo(args=get_args()):
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical

View File

@@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()):
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
@@ -113,7 +113,7 @@ def test_a2c_with_il(args=get_args()):
env.spec.reward_threshold = 190 # lower the goal
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
net = Actor(net, args.action_shape).to(args.device)
net = Actor(net, args.action_shape, device=args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='discrete')
il_test_collector = Collector(

View File

@@ -68,8 +68,8 @@ def test_ppo(args=get_args()):
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):

View File

@@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()):
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
actor = Actor(net, args.action_shape,
softmax_output=False, device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
critic1 = Critic(net_c1, last_size=args.action_shape,
device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
critic2 = Critic(net_c2, last_size=args.action_shape,
device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
# better not to use auto alpha in CartPole

View File

@@ -49,7 +49,8 @@ class Actor(nn.Module):
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
self.last = MLP(input_dim, self.output_dim,
hidden_sizes, device=self.device)
self._max = max_action
def forward(
@@ -98,7 +99,7 @@ class Critic(nn.Module):
self.output_dim = 1
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, 1, hidden_sizes)
self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
def forward(
self,
@@ -164,10 +165,12 @@ class ActorProb(nn.Module):
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim, hidden_sizes)
self.mu = MLP(input_dim, self.output_dim,
hidden_sizes, device=self.device)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = MLP(input_dim, self.output_dim, hidden_sizes)
self.sigma = MLP(input_dim, self.output_dim,
hidden_sizes, device=self.device)
else:
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
self._max = max_action

View File

@@ -40,13 +40,16 @@ class Actor(nn.Module):
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
preprocess_net_output_dim: Optional[int] = None,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
self.last = MLP(input_dim, self.output_dim,
hidden_sizes, device=self.device)
self.softmax_output = softmax_output
def forward(
@@ -91,13 +94,16 @@ class Critic(nn.Module):
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: Optional[int] = None,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = last_size
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes)
self.last = MLP(input_dim, last_size,
hidden_sizes, device=self.device)
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any