diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py
index 8890c4b..6749c78 100644
--- a/tianshou/data/utils.py
+++ b/tianshou/data/utils.py
@@ -21,7 +21,7 @@ def to_numpy(x: Union[

 def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
              dtype: Optional[torch.dtype] = None,
-             device: Union[str, int] = 'cpu'
+             device: Union[str, int, torch.device] = 'cpu'
              ) -> Union[dict, Batch, torch.Tensor]:
     """Return an object without np.ndarray."""
     if isinstance(x, np.ndarray):
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index ec56123..0944e8d 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -81,17 +81,12 @@ class DDPGPolicy(BasePolicy):
         """Set the exploration noise."""
         self._noise = noise

-    def train(self) -> None:
+    def train(self, mode=True) -> torch.nn.Module:
         """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.actor.train()
-        self.critic.train()
-
-    def eval(self) -> None:
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.actor.eval()
-        self.critic.eval()
+        self.training = mode
+        self.actor.train(mode)
+        self.critic.train(mode)
+        return self

     def sync_weight(self) -> None:
         """Soft-update the weight for the target network."""
@@ -127,8 +122,6 @@ class DDPGPolicy(BasePolicy):
                 **kwargs) -> Batch:
         """Compute action over the given batch data.

-        :param float eps: in [0, 1], for exploration use.
-
         :return: A :class:`~tianshou.data.Batch` which has 2 keys:

             * ``act`` the action.
@@ -141,12 +134,12 @@
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
-        logits, h = model(obs, state=state, info=batch.info)
-        logits += self._action_bias
+        actions, h = model(obs, state=state, info=batch.info)
+        actions += self._action_bias
         if self.training and explorating:
-            logits += to_torch_as(self._noise(logits.shape), logits)
-        logits = logits.clamp(self._range[0], self._range[1])
-        return Batch(act=logits, state=h)
+            actions += to_torch_as(self._noise(actions.shape), actions)
+        actions = actions.clamp(self._range[0], self._range[1])
+        return Batch(act=actions, state=h)

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         current_q = self.critic(batch.obs, batch.act)
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 0df4e0d..c34ba4e 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -54,15 +54,11 @@ class DQNPolicy(BasePolicy):
         """Set the eps for epsilon-greedy exploration."""
         self.eps = eps

-    def train(self) -> None:
+    def train(self, mode=True) -> torch.nn.Module:
         """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.model.train()
-
-    def eval(self) -> None:
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.model.eval()
+        self.training = mode
+        self.model.train(mode)
+        return self

     def sync_weight(self) -> None:
         """Synchronize the weight for the target network."""
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 8ddb2bf..ab0a633 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -89,17 +89,12 @@ class SACPolicy(DDPGPolicy):

         self.__eps = np.finfo(np.float32).eps.item()

-    def train(self) -> None:
-        self.training = True
-        self.actor.train()
-        self.critic1.train()
-        self.critic2.train()
-
-    def eval(self) -> None:
-        self.training = False
-        self.actor.eval()
-        self.critic1.eval()
-        self.critic2.eval()
+    def train(self, mode=True) -> torch.nn.Module:
+        self.training = mode
+        self.actor.train(mode)
+        self.critic1.train(mode)
+        self.critic2.train(mode)
+        return self

     def sync_weight(self) -> None:
         for o, n in zip(
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 2223e37..5d9bc37 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -82,17 +82,12 @@ class TD3Policy(DDPGPolicy):
         self._cnt = 0
         self._last = 0

-    def train(self) -> None:
-        self.training = True
-        self.actor.train()
-        self.critic1.train()
-        self.critic2.train()
-
-    def eval(self) -> None:
-        self.training = False
-        self.actor.eval()
-        self.critic1.eval()
-        self.critic2.eval()
+    def train(self, mode=True) -> torch.nn.Module:
+        self.training = mode
+        self.actor.train(mode)
+        self.critic1.train(mode)
+        self.critic2.train(mode)
+        return self

     def sync_weight(self) -> None:
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
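A minimal, self-contained sketch of the pattern this diff adopts (not part of the diff; ToyPolicy, actor and actor_old are illustrative names, not tianshou classes): override train(mode) once, and the eval() inherited from torch.nn.Module, which simply calls train(False), covers the evaluation case, so the separate eval() overrides can be dropped. Returning self matches the nn.Module.train convention and allows chaining.

import torch.nn as nn


class ToyPolicy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.actor = nn.Linear(4, 2)       # trainable network: follows train/eval mode
        self.actor_old = nn.Linear(4, 2)   # target network: kept in eval mode
        self.actor_old.eval()

    def train(self, mode: bool = True) -> nn.Module:
        """Switch mode for everything except the target network."""
        self.training = mode
        self.actor.train(mode)
        return self


policy = ToyPolicy()
policy.train()                          # training mode
assert policy.training and policy.actor.training

policy.eval()                           # inherited: nn.Module.eval() calls self.train(False)
assert not policy.training and not policy.actor.training
assert not policy.actor_old.training    # target network stays in eval mode throughout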