refactor test code
This commit is contained in:
parent
d64d78d769
commit
8bd8246b16
0
test/base/__init__.py
Normal file
0
test/base/__init__.py
Normal file
30
test/base/env.py
Normal file
30
test/base/env.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import gym
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class MyTestEnv(gym.Env):
    """Deterministic chain environment used by the tests.

    The observation is an integer index that starts at 0.  Action ``1``
    increases the index by one, action ``0`` decreases it by one (never
    below 0).  The episode is done once the index equals ``size``.

    :param size: index value at which the episode finishes.
    :param sleep: number of seconds to sleep inside every ``step`` call;
        no sleep happens unless the value is positive.
    """

    def __init__(self, size, sleep=0):
        self.size = size
        self.sleep = sleep
        self.reset()

    def reset(self):
        """Start a new episode and return the initial observation (0)."""
        self.done = False
        self.index = 0
        return self.index

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        The reward is 1 on the step that reaches ``size`` and 0 otherwise;
        ``info`` is always an empty dict.

        :param action: ``0`` (move back) or ``1`` (move forward).
        :raises ValueError: if called after the episode is done, or if
            ``action`` is neither 0 nor 1.
        """
        if self.done:
            raise ValueError('step after done !!!')
        if self.sleep > 0:
            time.sleep(self.sleep)
        # Covers the case where the index already equals size before any
        # move is applied (e.g. size == 0 right after reset).
        if self.index == self.size:
            self.done = True
            return self.index, 0, True, {}
        if action == 0:
            self.index = max(self.index - 1, 0)
            return self.index, 0, False, {}
        if action == 1:
            self.index += 1
            self.done = self.index == self.size
            return self.index, int(self.done), self.done, {}
        # Previously an unknown action fell through and returned None,
        # which only surfaced later as an unpacking error in the caller.
        raise ValueError(f'invalid action {action!r}: expected 0 or 1')
|
@ -1,8 +1,8 @@
|
|||||||
from tianshou.data import ReplayBuffer
|
from tianshou.data import ReplayBuffer
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from test_env import MyTestEnv
|
from env import MyTestEnv
|
||||||
else: # pytest
|
else: # pytest
|
||||||
from test.test_env import MyTestEnv
|
from test.base.env import MyTestEnv
|
||||||
|
|
||||||
|
|
||||||
def test_replaybuffer(size=10, bufsize=20):
|
def test_replaybuffer(size=10, bufsize=20):
|
@ -1,36 +1,12 @@
|
|||||||
import gym
|
|
||||||
import time
|
import time
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
|
from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
class MyTestEnv(gym.Env):
|
from env import MyTestEnv
|
||||||
def __init__(self, size, sleep=0):
|
else: # pytest
|
||||||
self.size = size
|
from test.base.env import MyTestEnv
|
||||||
self.sleep = sleep
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.done = False
|
|
||||||
self.index = 0
|
|
||||||
return self.index
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
if self.done:
|
|
||||||
raise ValueError('step after done !!!')
|
|
||||||
if self.sleep > 0:
|
|
||||||
time.sleep(self.sleep)
|
|
||||||
if self.index == self.size:
|
|
||||||
self.done = True
|
|
||||||
return self.index, 0, True, {}
|
|
||||||
if action == 0:
|
|
||||||
self.index = max(self.index - 1, 0)
|
|
||||||
return self.index, 0, False, {}
|
|
||||||
elif action == 1:
|
|
||||||
self.index += 1
|
|
||||||
self.done = self.index == self.size
|
|
||||||
return self.index, int(self.done), self.done, {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_framestack(k=4, size=10):
|
def test_framestack(k=4, size=10):
|
0
test/continuous/__init__.py
Normal file
0
test/continuous/__init__.py
Normal file
49
test/continuous/net.py
Normal file
49
test/continuous/net.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(nn.Module):
    """MLP actor producing actions bounded to ``[-max_action, max_action]``.

    The network is ``Linear(prod(state_shape), 128) -> ReLU``, followed by
    ``layer_num`` blocks of ``Linear(128, 128) -> ReLU`` and a final
    ``Linear(128, prod(action_shape))``.

    :param layer_num: number of hidden ``Linear(128, 128)`` blocks.
    :param state_shape: shape of one observation; it is flattened.
    :param action_shape: shape of one action; it is flattened.
    :param max_action: scale applied to the ``tanh`` of the network output.
    :param device: device on which the input tensor is created.
    """

    def __init__(self, layer_num, state_shape, action_shape,
                 max_action, device='cpu'):
        super().__init__()
        self.device = device
        # np.prod returns a NumPy scalar (a float for an empty shape such
        # as ``()``); nn.Linear needs plain integer sizes.
        in_dim = int(np.prod(state_shape))
        out_dim = int(np.prod(action_shape))
        layers = [nn.Linear(in_dim, 128), nn.ReLU(inplace=True)]
        for _ in range(layer_num):
            layers += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        layers.append(nn.Linear(128, out_dim))
        self.model = nn.Sequential(*layers)
        self._max = max_action

    def forward(self, s, **kwargs):
        """Return ``(actions, None)`` for a batch of observations ``s``.

        ``s`` is converted to a float tensor on ``self.device`` and
        flattened to ``(batch, -1)``, where ``batch`` is its first
        dimension.  Extra keyword arguments are accepted and ignored.
        """
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.model(s)
        # tanh squashes to [-1, 1]; scaling gives the action bound.
        logits = self._max * torch.tanh(logits)
        return logits, None
|
||||||
|
|
||||||
|
|
||||||
|
class Critic(nn.Module):
    """MLP critic scoring a (state, action) pair with a single value.

    The network is ``Linear(prod(state_shape) + prod(action_shape), 128)
    -> ReLU``, followed by ``layer_num`` blocks of ``Linear(128, 128) ->
    ReLU`` and a final ``Linear(128, 1)``.

    :param layer_num: number of hidden ``Linear(128, 128)`` blocks.
    :param state_shape: shape of one observation; it is flattened.
    :param action_shape: shape of one action; it is flattened.
    :param device: device on which input tensors are created.
    """

    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        # np.prod returns a NumPy scalar (a float for an empty shape such
        # as ``()``); nn.Linear needs plain integer sizes.
        in_dim = int(np.prod(state_shape)) + int(np.prod(action_shape))
        layers = [nn.Linear(in_dim, 128), nn.ReLU(inplace=True)]
        for _ in range(layer_num):
            layers += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        layers.append(nn.Linear(128, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, s, a):
        """Return a ``(batch, 1)`` tensor of values for states and actions.

        ``s`` is always converted to a float tensor on ``self.device``.
        ``a`` is converted the same way unless it already is a tensor, in
        which case it is used as given.  Both are flattened to
        ``(batch, -1)`` and concatenated along dimension 1.
        """
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        # Leave tensors untouched; convert anything else (arrays, nested
        # lists, ...) so the ``.view`` call below is always valid.
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        a = a.view(batch, -1)
        logits = self.model(torch.cat([s, a], dim=1))
        return logits
|
@ -3,7 +3,6 @@ import torch
|
|||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DDPGPolicy
|
from tianshou.policy import DDPGPolicy
|
||||||
@ -11,51 +10,10 @@ from tianshou.trainer import offpolicy_trainer
|
|||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
class Actor(nn.Module):
|
from net import Actor, Critic
|
||||||
def __init__(self, layer_num, state_shape, action_shape,
|
else: # pytest
|
||||||
max_action, device='cpu'):
|
from test.continuous.net import Actor, Critic
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model += [nn.Linear(128, np.prod(action_shape))]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
self._max = max_action
|
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
s = s.view(batch, -1)
|
|
||||||
logits = self.model(s)
|
|
||||||
logits = self._max * torch.tanh(logits)
|
|
||||||
return logits, None
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
|
||||||
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model += [nn.Linear(128, 1)]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
|
|
||||||
def forward(self, s, a):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
if isinstance(a, np.ndarray):
|
|
||||||
a = torch.tensor(a, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
s = s.view(batch, -1)
|
|
||||||
a = a.view(batch, -1)
|
|
||||||
logits = self.model(torch.cat([s, a], dim=1))
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
0
test/discrete/__init__.py
Normal file
0
test/discrete/__init__.py
Normal file
49
test/discrete/net.py
Normal file
49
test/discrete/net.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Module):
    """Shared MLP used by the discrete-action tests.

    The network is ``Linear(prod(state_shape), 128) -> ReLU`` followed by
    ``layer_num`` blocks of ``Linear(128, 128) -> ReLU``.  When
    ``action_shape`` is truthy a final ``Linear(128, prod(action_shape))``
    is appended; otherwise the output is the 128-wide hidden feature.

    :param layer_num: number of hidden ``Linear(128, 128)`` blocks.
    :param state_shape: shape of one observation; it is flattened.
    :param action_shape: shape (or count) of the outputs; ``0`` (the
        default) omits the output layer.
    :param device: device on which the input tensor is created.
    """

    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
        super().__init__()
        self.device = device
        # np.prod returns a NumPy scalar (a float for an empty shape such
        # as ``()``); nn.Linear needs plain integer sizes.
        layers = [
            nn.Linear(int(np.prod(state_shape)), 128),
            nn.ReLU(inplace=True)]
        for _ in range(layer_num):
            layers += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        if action_shape:
            layers.append(nn.Linear(128, int(np.prod(action_shape))))
        self.model = nn.Sequential(*layers)

    def forward(self, s, state=None, info=None):
        """Return ``(logits, state)`` for a batch of observations ``s``.

        ``s`` is converted to a float tensor on ``self.device`` and
        flattened to ``(batch, -1)``.  ``state`` is passed through
        unchanged and ``info`` is accepted but not used.
        """
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.model(s)
        return logits, state
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(nn.Module):
    """Policy head: a preprocessing net followed by a softmax layer.

    :param preprocess_net: callable invoked as ``preprocess_net(s, state)``
        and returning ``(features, hidden)``; the features must be 128
        wide to match the final linear layer.
    :param action_shape: shape (or count) of the discrete actions; its
        product is the number of output probabilities.
    """

    def __init__(self, preprocess_net, action_shape):
        super().__init__()
        self.preprocess = preprocess_net
        # np.prod returns a NumPy scalar (a float for an empty shape such
        # as ``()``); nn.Linear needs a plain integer size.
        self.last = nn.Linear(128, int(np.prod(action_shape)))

    def forward(self, s, state=None, info=None):
        """Return ``(probabilities, hidden)`` for observations ``s``.

        The probabilities are a softmax over the last dimension; ``hidden``
        is whatever the preprocessing net returned as its second value.
        ``info`` is accepted but not used.
        """
        features, h = self.preprocess(s, state)
        probs = F.softmax(self.last(features), dim=-1)
        return probs, h
|
||||||
|
|
||||||
|
|
||||||
|
class Critic(nn.Module):
    """Value head: a preprocessing net followed by a single linear output.

    :param preprocess_net: callable invoked as ``preprocess_net(s, None)``
        and returning ``(features, hidden)``; the features must be 128
        wide to match the final linear layer.
    """

    def __init__(self, preprocess_net):
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(128, 1)

    def forward(self, s):
        """Return the value estimate for observations ``s``.

        The preprocessing net is always called with ``None`` as its state
        and its second return value is discarded.
        """
        features, _ = self.preprocess(s, None)
        return self.last(features)
|
@ -3,8 +3,6 @@ import torch
|
|||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import A2CPolicy
|
from tianshou.policy import A2CPolicy
|
||||||
@ -12,50 +10,10 @@ from tianshou.env import SubprocVectorEnv
|
|||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
class Net(nn.Module):
|
from net import Net, Actor, Critic
|
||||||
def __init__(self, layer_num, state_shape, device='cpu'):
|
else: # pytest
|
||||||
super().__init__()
|
from test.discrete.net import Net, Actor, Critic
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
|
|
||||||
def forward(self, s):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
s = s.view(batch, -1)
|
|
||||||
logits = self.model(s)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
|
||||||
def __init__(self, preprocess_net, action_shape):
|
|
||||||
super().__init__()
|
|
||||||
self.model = nn.Sequential(*[
|
|
||||||
preprocess_net,
|
|
||||||
nn.Linear(128, np.prod(action_shape)),
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
|
||||||
logits = F.softmax(self.model(s), dim=-1)
|
|
||||||
return logits, None
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
|
||||||
def __init__(self, preprocess_net):
|
|
||||||
super().__init__()
|
|
||||||
self.model = nn.Sequential(*[
|
|
||||||
preprocess_net,
|
|
||||||
nn.Linear(128, 1),
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, s):
|
|
||||||
logits = self.model(s)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -103,7 +61,7 @@ def test_a2c(args=get_args()):
|
|||||||
train_envs.seed(args.seed)
|
train_envs.seed(args.seed)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
# model
|
# model
|
||||||
net = Net(args.layer_num, args.state_shape, args.device)
|
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net).to(args.device)
|
||||||
optim = torch.optim.Adam(list(
|
optim = torch.optim.Adam(list(
|
@ -3,7 +3,6 @@ import torch
|
|||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
@ -11,24 +10,10 @@ from tianshou.env import SubprocVectorEnv
|
|||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
class Net(nn.Module):
|
from net import Net
|
||||||
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
|
else: # pytest
|
||||||
super().__init__()
|
from test.discrete.net import Net
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model += [nn.Linear(128, np.prod(action_shape))]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
q = self.model(s.view(batch, -1))
|
|
||||||
return q, None
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
@ -4,7 +4,6 @@ import torch
|
|||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import PGPolicy
|
from tianshou.policy import PGPolicy
|
||||||
@ -12,6 +11,11 @@ from tianshou.env import SubprocVectorEnv
|
|||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Batch, Collector, ReplayBuffer
|
from tianshou.data import Batch, Collector, ReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from net import Net
|
||||||
|
else: # pytest
|
||||||
|
from test.discrete.net import Net
|
||||||
|
|
||||||
|
|
||||||
def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
|
def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
|
||||||
returns = np.zeros_like(batch.rew)
|
returns = np.zeros_like(batch.rew)
|
||||||
@ -66,25 +70,6 @@ def test_fn(size=2560):
|
|||||||
print(f'policy: {(time.time() - t) / cnt}')
|
print(f'policy: {(time.time() - t) / cnt}')
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
|
||||||
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model += [nn.Linear(128, np.prod(action_shape))]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
logits = self.model(s.view(batch, -1))
|
|
||||||
return logits, None
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
@ -126,7 +111,9 @@ def test_pg(args=get_args()):
|
|||||||
train_envs.seed(args.seed)
|
train_envs.seed(args.seed)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
# model
|
# model
|
||||||
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
|
net = Net(
|
||||||
|
args.layer_num, args.state_shape, args.action_shape,
|
||||||
|
device=args.device)
|
||||||
net = net.to(args.device)
|
net = net.to(args.device)
|
||||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
dist = torch.distributions.Categorical
|
dist = torch.distributions.Categorical
|
@ -3,8 +3,6 @@ import torch
|
|||||||
import pprint
|
import pprint
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import PPOPolicy
|
from tianshou.policy import PPOPolicy
|
||||||
@ -12,50 +10,10 @@ from tianshou.env import SubprocVectorEnv
|
|||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
class Net(nn.Module):
|
from net import Net, Actor, Critic
|
||||||
def __init__(self, layer_num, state_shape, device='cpu'):
|
else: # pytest
|
||||||
super().__init__()
|
from test.discrete.net import Net, Actor, Critic
|
||||||
self.device = device
|
|
||||||
self.model = [
|
|
||||||
nn.Linear(np.prod(state_shape), 128),
|
|
||||||
nn.ReLU(inplace=True)]
|
|
||||||
for i in range(layer_num):
|
|
||||||
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
|
|
||||||
self.model = nn.Sequential(*self.model)
|
|
||||||
|
|
||||||
def forward(self, s):
|
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
|
||||||
batch = s.shape[0]
|
|
||||||
s = s.view(batch, -1)
|
|
||||||
logits = self.model(s)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
|
||||||
def __init__(self, preprocess_net, action_shape):
|
|
||||||
super().__init__()
|
|
||||||
self.model = nn.Sequential(*[
|
|
||||||
preprocess_net,
|
|
||||||
nn.Linear(128, np.prod(action_shape)),
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
|
||||||
logits = F.softmax(self.model(s), dim=-1)
|
|
||||||
return logits, None
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(nn.Module):
|
|
||||||
def __init__(self, preprocess_net):
|
|
||||||
super().__init__()
|
|
||||||
self.model = nn.Sequential(*[
|
|
||||||
preprocess_net,
|
|
||||||
nn.Linear(128, 1),
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, s):
|
|
||||||
logits = self.model(s)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -104,7 +62,7 @@ def test_ppo(args=get_args()):
|
|||||||
train_envs.seed(args.seed)
|
train_envs.seed(args.seed)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
# model
|
# model
|
||||||
net = Net(args.layer_num, args.state_shape, args.device)
|
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net).to(args.device)
|
||||||
optim = torch.optim.Adam(list(
|
optim = torch.optim.Adam(list(
|
Loading…
x
Reference in New Issue
Block a user