hotfix: fix test failure in CUDA environment (#289)
parent e3ee415b1a
commit f528131da1
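Background for the fix: these actor/critic heads build an internal MLP that converts raw (numpy) inputs to tensors on a stored self.device inside forward. Calling .to(args.device) moves the parameters but does not update that stored attribute, so a head built without device= keeps casting inputs to CPU while its weights sit on the GPU, and the first linear layer raises a device-mismatch error. The commit threads device= from the tests down to every head MLP. A minimal sketch of the mechanism (the Head class below is an illustrative stand-in, not Tianshou code):

import numpy as np
import torch
import torch.nn as nn


class Head(nn.Module):
    """Illustrative stand-in for a head that casts raw inputs itself."""

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device  # used to cast inputs; .to() will not update it
        self.fc = nn.Linear(4, 2)

    def forward(self, obs):
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.fc(obs)  # fails if weights and obs live on different devices


if torch.cuda.is_available():
    obs = np.zeros((1, 4), dtype=np.float32)
    Head(device="cuda").to("cuda")(obs)  # ok: cast target matches the weights
    Head().to("cuda")(obs)  # RuntimeError: tensors on cpu and cuda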
@@ -65,8 +65,8 @@ def test_a2c(args=get_args()):
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    actor = Actor(net, args.action_shape).to(args.device)
-    critic = Critic(net).to(args.device)
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     optim = torch.optim.Adam(set(
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
@@ -65,8 +65,8 @@ def test_ppo(args=get_args()):
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    actor = Actor(net, args.action_shape).to(args.device)
-    critic = Critic(net).to(args.device)
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     optim = torch.optim.Adam(set(
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
@@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()):
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    actor = Actor(net, args.action_shape).to(args.device)
-    critic = Critic(net).to(args.device)
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     optim = torch.optim.Adam(set(
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
@@ -113,7 +113,7 @@ def test_a2c_with_il(args=get_args()):
     env.spec.reward_threshold = 190  # lower the goal
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    net = Actor(net, args.action_shape).to(args.device)
+    net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='discrete')
     il_test_collector = Collector(
@@ -68,8 +68,8 @@ def test_ppo(args=get_args()):
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    actor = Actor(net, args.action_shape).to(args.device)
-    critic = Critic(net).to(args.device)
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     # orthogonal initialization
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):
@@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()):
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
               device=args.device)
-    actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
+    actor = Actor(net, args.action_shape,
+                  softmax_output=False, device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                  device=args.device)
-    critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
+    critic1 = Critic(net_c1, last_size=args.action_shape,
+                     device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
     net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                  device=args.device)
-    critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
+    critic2 = Critic(net_c2, last_size=args.action_shape,
+                     device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

     # better not to use auto alpha in CartPole
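The test-side edits above only work because the network classes (changed below) now accept and store a device and hand it to their head MLP. Note also that each discrete SAC critic outputs one Q-value per action, which is why it is built with last_size=args.action_shape. A small shape check under that reading (import paths and shapes are assumptions based on Tianshou's utils.net layout, not part of this diff):

import numpy as np
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic

device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative shapes: 4-dim observations, 2 discrete actions.
net = Net((4,), hidden_sizes=(64,), device=device)
critic = Critic(net, last_size=2, device=device).to(device)
q = critic(np.zeros((8, 4), dtype=np.float32))
print(q.shape)  # torch.Size([8, 2]) -- one Q-value per action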
@@ -49,7 +49,8 @@ class Actor(nn.Module):
         self.output_dim = np.prod(action_shape)
         input_dim = getattr(preprocess_net, "output_dim",
                             preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes)
+        self.last = MLP(input_dim, self.output_dim,
+                        hidden_sizes, device=self.device)
         self._max = max_action

     def forward(
@@ -98,7 +99,7 @@ class Critic(nn.Module):
         self.output_dim = 1
         input_dim = getattr(preprocess_net, "output_dim",
                             preprocess_net_output_dim)
-        self.last = MLP(input_dim, 1, hidden_sizes)
+        self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)

     def forward(
         self,
@@ -164,10 +165,12 @@ class ActorProb(nn.Module):
         self.output_dim = np.prod(action_shape)
         input_dim = getattr(preprocess_net, "output_dim",
                             preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes)
+        self.mu = MLP(input_dim, self.output_dim,
+                      hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
-            self.sigma = MLP(input_dim, self.output_dim, hidden_sizes)
+            self.sigma = MLP(input_dim, self.output_dim,
+                             hidden_sizes, device=self.device)
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
         self._max = max_action
@@ -40,13 +40,16 @@ class Actor(nn.Module):
         hidden_sizes: Sequence[int] = (),
         softmax_output: bool = True,
         preprocess_net_output_dim: Optional[int] = None,
+        device: Union[str, int, torch.device] = "cpu",
     ) -> None:
         super().__init__()
+        self.device = device
         self.preprocess = preprocess_net
         self.output_dim = np.prod(action_shape)
         input_dim = getattr(preprocess_net, "output_dim",
                             preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes)
+        self.last = MLP(input_dim, self.output_dim,
+                        hidden_sizes, device=self.device)
         self.softmax_output = softmax_output

     def forward(
@@ -91,13 +94,16 @@ class Critic(nn.Module):
         hidden_sizes: Sequence[int] = (),
         last_size: int = 1,
         preprocess_net_output_dim: Optional[int] = None,
+        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
         super().__init__()
+        self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
         input_dim = getattr(preprocess_net, "output_dim",
                             preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes)
+        self.last = MLP(input_dim, last_size,
+                        hidden_sizes, device=self.device)

     def forward(
         self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any