doc fix; policy train/eval signiture fix (#109)

* doc fix; policy train/eval signiture fix

* change train/eval behavior according to pytorch

* change train/eval behavior according to pytorch
This commit is contained in:
youkaichao 2020-07-06 10:44:34 +08:00 committed by GitHub
parent db0e2e5cd2
commit 5b1373924e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 27 additions and 48 deletions

View File

@ -21,7 +21,7 @@ def to_numpy(x: Union[
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
dtype: Optional[torch.dtype] = None,
device: Union[str, int] = 'cpu'
device: Union[str, int, torch.device] = 'cpu'
) -> Union[dict, Batch, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray):

View File

@ -81,17 +81,12 @@ class DDPGPolicy(BasePolicy):
"""Set the exploration noise."""
self._noise = noise
def train(self) -> None:
def train(self, mode=True) -> torch.nn.Module:
"""Set the module in training mode, except for the target network."""
self.training = True
self.actor.train()
self.critic.train()
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.actor.eval()
self.critic.eval()
self.training = mode
self.actor.train(mode)
self.critic.train(mode)
return self
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
@ -127,8 +122,6 @@ class DDPGPolicy(BasePolicy):
**kwargs) -> Batch:
"""Compute action over the given batch data.
:param float eps: in [0, 1], for exploration use.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
* ``act`` the action.
@ -141,12 +134,12 @@ class DDPGPolicy(BasePolicy):
"""
model = getattr(self, model)
obs = getattr(batch, input)
logits, h = model(obs, state=state, info=batch.info)
logits += self._action_bias
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
if self.training and explorating:
logits += to_torch_as(self._noise(logits.shape), logits)
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
current_q = self.critic(batch.obs, batch.act)

View File

@ -54,15 +54,11 @@ class DQNPolicy(BasePolicy):
"""Set the eps for epsilon-greedy exploration."""
self.eps = eps
def train(self) -> None:
def train(self, mode=True) -> torch.nn.Module:
"""Set the module in training mode, except for the target network."""
self.training = True
self.model.train()
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.model.eval()
self.training = mode
self.model.train(mode)
return self
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""

View File

@ -89,17 +89,12 @@ class SACPolicy(DDPGPolicy):
self.__eps = np.finfo(np.float32).eps.item()
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()
def train(self, mode=True) -> torch.nn.Module:
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
self.critic2.train(mode)
return self
def sync_weight(self) -> None:
for o, n in zip(

View File

@ -82,17 +82,12 @@ class TD3Policy(DDPGPolicy):
self._cnt = 0
self._last = 0
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()
def train(self, mode=True) -> torch.nn.Module:
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
self.critic2.train(mode)
return self
def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):