hotfix:fix test failure in cuda environment (#289)
This commit is contained in:
parent
e3ee415b1a
commit
f528131da1
@ -65,8 +65,8 @@ def test_a2c(args=get_args()):
|
|||||||
# model
|
# model
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net, device=args.device).to(args.device)
|
||||||
optim = torch.optim.Adam(set(
|
optim = torch.optim.Adam(set(
|
||||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Categorical
|
dist = torch.distributions.Categorical
|
||||||
|
@ -65,8 +65,8 @@ def test_ppo(args=get_args()):
|
|||||||
# model
|
# model
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net, device=args.device).to(args.device)
|
||||||
optim = torch.optim.Adam(set(
|
optim = torch.optim.Adam(set(
|
||||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Categorical
|
dist = torch.distributions.Categorical
|
||||||
|
@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()):
|
|||||||
# model
|
# model
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net, device=args.device).to(args.device)
|
||||||
optim = torch.optim.Adam(set(
|
optim = torch.optim.Adam(set(
|
||||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Categorical
|
dist = torch.distributions.Categorical
|
||||||
@ -113,7 +113,7 @@ def test_a2c_with_il(args=get_args()):
|
|||||||
env.spec.reward_threshold = 190 # lower the goal
|
env.spec.reward_threshold = 190 # lower the goal
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
net = Actor(net, args.action_shape).to(args.device)
|
net = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||||
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||||
il_policy = ImitationPolicy(net, optim, mode='discrete')
|
il_policy = ImitationPolicy(net, optim, mode='discrete')
|
||||||
il_test_collector = Collector(
|
il_test_collector = Collector(
|
||||||
|
@ -68,8 +68,8 @@ def test_ppo(args=get_args()):
|
|||||||
# model
|
# model
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net, device=args.device).to(args.device)
|
||||||
# orthogonal initialization
|
# orthogonal initialization
|
||||||
for m in list(actor.modules()) + list(critic.modules()):
|
for m in list(actor.modules()) + list(critic.modules()):
|
||||||
if isinstance(m, torch.nn.Linear):
|
if isinstance(m, torch.nn.Linear):
|
||||||
|
@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()):
|
|||||||
# model
|
# model
|
||||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
|
actor = Actor(net, args.action_shape,
|
||||||
|
softmax_output=False, device=args.device).to(args.device)
|
||||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
|
critic1 = Critic(net_c1, last_size=args.action_shape,
|
||||||
|
device=args.device).to(args.device)
|
||||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||||
device=args.device)
|
device=args.device)
|
||||||
critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
|
critic2 = Critic(net_c2, last_size=args.action_shape,
|
||||||
|
device=args.device).to(args.device)
|
||||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
# better not to use auto alpha in CartPole
|
# better not to use auto alpha in CartPole
|
||||||
|
@ -49,7 +49,8 @@ class Actor(nn.Module):
|
|||||||
self.output_dim = np.prod(action_shape)
|
self.output_dim = np.prod(action_shape)
|
||||||
input_dim = getattr(preprocess_net, "output_dim",
|
input_dim = getattr(preprocess_net, "output_dim",
|
||||||
preprocess_net_output_dim)
|
preprocess_net_output_dim)
|
||||||
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
|
self.last = MLP(input_dim, self.output_dim,
|
||||||
|
hidden_sizes, device=self.device)
|
||||||
self._max = max_action
|
self._max = max_action
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -98,7 +99,7 @@ class Critic(nn.Module):
|
|||||||
self.output_dim = 1
|
self.output_dim = 1
|
||||||
input_dim = getattr(preprocess_net, "output_dim",
|
input_dim = getattr(preprocess_net, "output_dim",
|
||||||
preprocess_net_output_dim)
|
preprocess_net_output_dim)
|
||||||
self.last = MLP(input_dim, 1, hidden_sizes)
|
self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -164,10 +165,12 @@ class ActorProb(nn.Module):
|
|||||||
self.output_dim = np.prod(action_shape)
|
self.output_dim = np.prod(action_shape)
|
||||||
input_dim = getattr(preprocess_net, "output_dim",
|
input_dim = getattr(preprocess_net, "output_dim",
|
||||||
preprocess_net_output_dim)
|
preprocess_net_output_dim)
|
||||||
self.mu = MLP(input_dim, self.output_dim, hidden_sizes)
|
self.mu = MLP(input_dim, self.output_dim,
|
||||||
|
hidden_sizes, device=self.device)
|
||||||
self._c_sigma = conditioned_sigma
|
self._c_sigma = conditioned_sigma
|
||||||
if conditioned_sigma:
|
if conditioned_sigma:
|
||||||
self.sigma = MLP(input_dim, self.output_dim, hidden_sizes)
|
self.sigma = MLP(input_dim, self.output_dim,
|
||||||
|
hidden_sizes, device=self.device)
|
||||||
else:
|
else:
|
||||||
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
|
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
|
||||||
self._max = max_action
|
self._max = max_action
|
||||||
|
@ -40,13 +40,16 @@ class Actor(nn.Module):
|
|||||||
hidden_sizes: Sequence[int] = (),
|
hidden_sizes: Sequence[int] = (),
|
||||||
softmax_output: bool = True,
|
softmax_output: bool = True,
|
||||||
preprocess_net_output_dim: Optional[int] = None,
|
preprocess_net_output_dim: Optional[int] = None,
|
||||||
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
self.preprocess = preprocess_net
|
self.preprocess = preprocess_net
|
||||||
self.output_dim = np.prod(action_shape)
|
self.output_dim = np.prod(action_shape)
|
||||||
input_dim = getattr(preprocess_net, "output_dim",
|
input_dim = getattr(preprocess_net, "output_dim",
|
||||||
preprocess_net_output_dim)
|
preprocess_net_output_dim)
|
||||||
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
|
self.last = MLP(input_dim, self.output_dim,
|
||||||
|
hidden_sizes, device=self.device)
|
||||||
self.softmax_output = softmax_output
|
self.softmax_output = softmax_output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -91,13 +94,16 @@ class Critic(nn.Module):
|
|||||||
hidden_sizes: Sequence[int] = (),
|
hidden_sizes: Sequence[int] = (),
|
||||||
last_size: int = 1,
|
last_size: int = 1,
|
||||||
preprocess_net_output_dim: Optional[int] = None,
|
preprocess_net_output_dim: Optional[int] = None,
|
||||||
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
self.preprocess = preprocess_net
|
self.preprocess = preprocess_net
|
||||||
self.output_dim = last_size
|
self.output_dim = last_size
|
||||||
input_dim = getattr(preprocess_net, "output_dim",
|
input_dim = getattr(preprocess_net, "output_dim",
|
||||||
preprocess_net_output_dim)
|
preprocess_net_output_dim)
|
||||||
self.last = MLP(input_dim, last_size, hidden_sizes)
|
self.last = MLP(input_dim, last_size,
|
||||||
|
hidden_sizes, device=self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
|
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
|
||||||
|
Loading…
x
Reference in New Issue
Block a user