use nn.Sequential in DQN (#176)
This commit is contained in:
parent
99a1d40e85
commit
32df0567bb
@ -49,10 +49,6 @@ class DQN(nn.Module):
|
||||
super(DQN, self).__init__()
|
||||
self.device = device
|
||||
|
||||
self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
|
||||
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
|
||||
|
||||
def conv2d_size_out(size, kernel_size=5, stride=2):
|
||||
return (size - (kernel_size - 1) - 1) // stride + 1
|
||||
|
||||
@ -68,16 +64,22 @@ class DQN(nn.Module):
|
||||
convw = conv2d_layers_size_out(w)
|
||||
convh = conv2d_layers_size_out(h)
|
||||
linear_input_size = convw * convh * 64
|
||||
self.fc = nn.Linear(linear_input_size, 512)
|
||||
self.head = nn.Linear(512, action_shape)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(4, 32, kernel_size=8, stride=4),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Flatten(),
|
||||
nn.Linear(linear_input_size, 512),
|
||||
nn.Linear(512, action_shape)
|
||||
)
|
||||
|
||||
def forward(self, x, state=None, info={}):
|
||||
r"""x -> Q(x, \*)"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.fc(x.reshape(x.size(0), -1))
|
||||
return self.head(x), state
|
||||
return self.net(x), state
|
||||
|
Loading…
x
Reference in New Issue
Block a user