doc fix; policy train/eval signature fix (#109)
* doc fix; policy train/eval signature fix
* change train/eval behavior according to PyTorch
parent db0e2e5cd2
commit 5b1373924e
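The hunks below make every policy's train()/eval() pair follow the standard torch.nn.Module protocol instead of a custom one. For reference, a short standalone check of that protocol (plain PyTorch, not Tianshou code):

import torch
from torch import nn

m = nn.Linear(2, 2)
assert m.train() is m and m.training            # train(mode=True) returns the module
assert m.eval() is m and not m.training         # eval() is just train(False)
assert m.train(False) is m and not m.training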
@@ -21,7 +21,7 @@ def to_numpy(x: Union[
 
 def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
              dtype: Optional[torch.dtype] = None,
-             device: Union[str, int] = 'cpu'
+             device: Union[str, int, torch.device] = 'cpu'
              ) -> Union[dict, Batch, torch.Tensor]:
     """Return an object without np.ndarray."""
     if isinstance(x, np.ndarray):
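A hedged usage sketch of the widened signature above: a torch.device now type-checks as the device argument alongside str and int. The import path tianshou.data is assumed here; adjust it if to_torch lives elsewhere in your version.

import numpy as np
import torch
from tianshou.data import to_torch  # assumed export location

arr = np.zeros((4, 3), dtype=np.float64)
# 'cpu' (str) still works; a torch.device is now an accepted type as well
tensor = to_torch(arr, dtype=torch.float32, device=torch.device('cpu'))
assert isinstance(tensor, torch.Tensor) and tensor.dtype == torch.float32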
@@ -81,17 +81,12 @@ class DDPGPolicy(BasePolicy):
         """Set the exploration noise."""
         self._noise = noise
 
-    def train(self) -> None:
+    def train(self, mode=True) -> torch.nn.Module:
         """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.actor.train()
-        self.critic.train()
-
-    def eval(self) -> None:
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.actor.eval()
-        self.critic.eval()
+        self.training = mode
+        self.actor.train(mode)
+        self.critic.train(mode)
+        return self
 
     def sync_weight(self) -> None:
         """Soft-update the weight for the target network."""
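Why the custom eval() can be deleted: nn.Module.eval() simply calls self.train(False), so a single train(mode) override that skips the target network covers both modes. A minimal sketch of that pattern, using a hypothetical PolicyLike module rather than the actual DDPGPolicy:

import torch
from torch import nn

class PolicyLike(nn.Module):
    """Toy stand-in for a policy with a frozen target network."""
    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(4, 2)       # live network
        self.actor_old = nn.Linear(4, 2)   # target network
        self.actor_old.eval()              # keep the target in eval mode

    def train(self, mode: bool = True) -> "PolicyLike":
        # Toggle only the live network; actor_old is deliberately skipped.
        self.training = mode
        self.actor.train(mode)
        return self                        # same return convention as nn.Module.train

policy = PolicyLike()
policy.eval()                              # inherited eval() -> train(False)
assert not policy.actor.training and not policy.actor_old.training
policy.train()
assert policy.actor.training and not policy.actor_old.training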
@@ -127,8 +122,6 @@ class DDPGPolicy(BasePolicy):
                 **kwargs) -> Batch:
         """Compute action over the given batch data.
 
-        :param float eps: in [0, 1], for exploration use.
-
         :return: A :class:`~tianshou.data.Batch` which has 2 keys:
 
             * ``act`` the action.
@@ -141,12 +134,12 @@ class DDPGPolicy(BasePolicy):
         """
         model = getattr(self, model)
         obs = getattr(batch, input)
-        logits, h = model(obs, state=state, info=batch.info)
-        logits += self._action_bias
+        actions, h = model(obs, state=state, info=batch.info)
+        actions += self._action_bias
         if self.training and explorating:
-            logits += to_torch_as(self._noise(logits.shape), logits)
-        logits = logits.clamp(self._range[0], self._range[1])
-        return Batch(act=logits, state=h)
+            actions += to_torch_as(self._noise(actions.shape), actions)
+        actions = actions.clamp(self._range[0], self._range[1])
+        return Batch(act=actions, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         current_q = self.critic(batch.obs, batch.act)
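The renamed actions variable above carries the exploration step: during training, noise is added before the result is clamped to the action range. A self-contained sketch of that additive-noise-and-clamp step, with made-up values standing in for self._noise and self._range:

import torch

actions = torch.tensor([[0.3, -1.4]])          # pretend model output, bias already added
action_range = (-1.0, 1.0)                     # stands in for self._range
training, explorating = True, True
if training and explorating:
    noise = 0.1 * torch.randn(actions.shape)   # stands in for self._noise(actions.shape)
    actions = actions + noise
actions = actions.clamp(action_range[0], action_range[1])
print(actions)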
@@ -54,15 +54,11 @@ class DQNPolicy(BasePolicy):
         """Set the eps for epsilon-greedy exploration."""
         self.eps = eps
 
-    def train(self) -> None:
+    def train(self, mode=True) -> torch.nn.Module:
         """Set the module in training mode, except for the target network."""
-        self.training = True
-        self.model.train()
-
-    def eval(self) -> None:
-        """Set the module in evaluation mode, except for the target network."""
-        self.training = False
-        self.model.eval()
+        self.training = mode
+        self.model.train(mode)
+        return self
 
     def sync_weight(self) -> None:
         """Synchronize the weight for the target network."""
@@ -89,17 +89,12 @@ class SACPolicy(DDPGPolicy):
 
         self.__eps = np.finfo(np.float32).eps.item()
 
-    def train(self) -> None:
-        self.training = True
-        self.actor.train()
-        self.critic1.train()
-        self.critic2.train()
-
-    def eval(self) -> None:
-        self.training = False
-        self.actor.eval()
-        self.critic1.eval()
-        self.critic2.eval()
+    def train(self, mode=True) -> torch.nn.Module:
+        self.training = mode
+        self.actor.train(mode)
+        self.critic1.train(mode)
+        self.critic2.train(mode)
+        return self
 
     def sync_weight(self) -> None:
         for o, n in zip(
@@ -82,17 +82,12 @@ class TD3Policy(DDPGPolicy):
         self._cnt = 0
         self._last = 0
 
-    def train(self) -> None:
-        self.training = True
-        self.actor.train()
-        self.critic1.train()
-        self.critic2.train()
-
-    def eval(self) -> None:
-        self.training = False
-        self.actor.eval()
-        self.critic1.eval()
-        self.critic2.eval()
+    def train(self, mode=True) -> torch.nn.Module:
+        self.training = mode
+        self.actor.train(mode)
+        self.critic1.train(mode)
+        self.critic2.train(mode)
+        return self
 
     def sync_weight(self) -> None:
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):