doc update and do not force save 'policy' in np format (#168)

Trinkle23897 2020-07-27 16:54:14 +08:00
parent e024afab8c
commit b7a4015db7
5 changed files with 9 additions and 12 deletions

View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)

View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
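
Both defaults drop from 100 to 10, presumably because ``collect_per_step`` is counted in episodes rather than frames for the on-policy trainer (see the docstring fix below), so 100 full episodes per network update would make the tests far too slow. A minimal sketch of how the trainer presumably consumes these arguments; the internals shown here are an assumption based on this commit, not code quoted from it:

    # hypothetical inner loop of onpolicy_trainer (names assumed)
    result = train_collector.collect(n_episode=args.collect_per_step)
    # learn on everything just collected, repeated repeat_per_collect times
    losses = policy.learn(train_collector.sample(0),
                          args.batch_size, args.repeat_per_collect)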

View File

@@ -258,13 +258,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(self.data, last_state)
-            # convert None to Batch(), since None is reserved for 0-init
             state = result.get('state', Batch())
+            # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.state = state
-            if hasattr(result, 'policy'):
-                self.data.policy = to_numpy(result.policy)
+            self.data.update(state=state, policy=result.get('policy', Batch()))
             # save hidden state to policy._state, in order to save into buffer
             self.data.policy._state = self.data.state
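
The replacement writes ``state`` and ``policy`` in a single ``Batch.update`` call, and ``result.get('policy', Batch())`` keeps whatever the policy returned instead of eagerly converting it with ``to_numpy`` (hence the commit title). A small self-contained sketch of these ``Batch`` semantics; the ``log_prob`` key is illustrative only, not from the commit:

    import torch
    from tianshou.data import Batch

    result = Batch(act=torch.tensor([0]),
                   policy=Batch(log_prob=torch.tensor([-0.7])))
    data = Batch()
    # Batch.get behaves like dict.get: a default for missing keys
    state = result.get('state', Batch())
    # one update() sets both keys; 'policy' is no longer forced to numpy here
    data.update(state=state, policy=result.get('policy', Batch()))
    assert torch.is_tensor(data.policy.log_prob)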

View File

@@ -86,9 +86,8 @@ class BasePolicy(ABC, nn.Module):
         # some code
         return Batch(logits=..., act=..., state=None, dist=...)

-    After version >= 0.2.3, the keyword "policy" is reserverd and the
-    corresponding data will be stored into the replay buffer in numpy. For
-    instance,
+    The keyword ``policy`` is reserved and the corresponding data will be
+    stored into the replay buffer in numpy. For instance,
     ::

         # some code
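
For illustration, a hypothetical ``forward`` that fills the reserved key; the ``log_prob`` payload and the ``self.actor``/``self.dist_fn`` attributes are assumptions for this sketch, not part of the diff:

    def forward(self, batch, state=None, **kwargs):
        logits, hidden = self.actor(batch.obs, state=state)
        dist = self.dist_fn(logits)
        act = dist.sample()
        # whatever sits under "policy" is saved with the transition into
        # the replay buffer, e.g. per-action log-probabilities
        return Batch(act=act, state=hidden, dist=dist,
                     policy=Batch(log_prob=dist.log_prob(act)))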

View File

@@ -40,9 +40,9 @@ def onpolicy_trainer(
         process might be finished before reaching the ``max_epoch``.
     :param int step_per_epoch: the number of step for updating policy network
         in one epoch.
-    :param int collect_per_step: the number of frames the collector would
-        collect before the network update. In other words, collect some frames
-        and do one policy network update.
+    :param int collect_per_step: the number of episodes the collector would
+        collect before the network update. In other words, collect some
+        episodes and do one policy network update.
     :param int repeat_per_collect: the number of repeat time for policy
         learning, for example, set it to 2 means the policy needs to learn each
         given batch data twice.
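
A hedged usage sketch of the documented semantics (keyword names taken from this docstring; treat the exact signature as an assumption about the 0.2.x API):

    result = onpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=100,
        step_per_epoch=1000,
        collect_per_step=10,    # gather 10 *episodes* per network update
        repeat_per_collect=2,   # learn each collected batch twice
        episode_per_test=100,
        batch_size=64)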