diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index ae6a1ef..afed6df 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -49,10 +49,6 @@ class DQN(nn.Module):
         super(DQN, self).__init__()
         self.device = device
 
-        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
-        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
-
         def conv2d_size_out(size, kernel_size=5, stride=2):
             return (size - (kernel_size - 1) - 1) // stride + 1
 
@@ -68,16 +64,22 @@ class DQN(nn.Module):
         convw = conv2d_layers_size_out(w)
         convh = conv2d_layers_size_out(h)
         linear_input_size = convw * convh * 64
-        self.fc = nn.Linear(linear_input_size, 512)
-        self.head = nn.Linear(512, action_shape)
+
+        self.net = nn.Sequential(
+            nn.Conv2d(4, 32, kernel_size=8, stride=4),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.ReLU(inplace=True),
+            nn.Flatten(),
+            nn.Linear(linear_input_size, 512),
+            nn.Linear(512, action_shape)
+        )
 
     def forward(self, x, state=None, info={}):
         r"""x -> Q(x, \*)"""
         if not isinstance(x, torch.Tensor):
             x = torch.tensor(x, device=self.device, dtype=torch.float32)
         x = x.permute(0, 3, 1, 2)
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = self.fc(x.reshape(x.size(0), -1))
-        return self.head(x), state
+        return self.net(x), state
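
A quick smoke test for the refactor (a minimal sketch, not part of the patch: the constructor signature DQN(h, w, action_shape) and the 84x84 Atari frame size are assumptions from context, not taken from this diff). With kernel sizes 8/4/3 and strides 4/2/1, an 84x84 input shrinks to 20x20 -> 9x9 -> 7x7, so linear_input_size = 7 * 7 * 64 = 3136, and the fused nn.Sequential should return the same Q-value shape as the old layer-by-layer forward:

import numpy as np

from tianshou.utils.net.discrete import DQN

h, w, action_shape = 84, 84, 6  # typical Atari setup; values are illustrative
net = DQN(h, w, action_shape)

# Observations arrive as NHWC arrays with a 4-frame stack; forward()
# converts to a float tensor and permutes to NCHW before the conv stack.
obs = np.zeros((1, h, w, 4), dtype=np.float32)
q_values, state = net(obs)
assert q_values.shape == (1, action_shape)  # one Q-value per action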