doc update and do not force save 'policy' in np format (#168)

This commit is contained in:
Trinkle23897 2020-07-27 16:54:14 +08:00
parent e024afab8c
commit b7a4015db7
5 changed files with 9 additions and 12 deletions

View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)

View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
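
Both test scripts feed this flag to the on-policy trainer, where ``collect_per_step`` counts episodes rather than frames (see the docstring fix below), so the old default of 100 meant a hundred full episodes between network updates. A minimal check of the new default, reusing ``get_args`` from either script::

    args = get_args()
    # ten episodes are now collected before each round of network updates
    assert args.collect_per_step == 10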

View File

@@ -258,13 +258,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(self.data, last_state)
+            # convert None to Batch(), since None is reserved for 0-init
             state = result.get('state', Batch())
-            # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.state = state
-            if hasattr(result, 'policy'):
-                self.data.policy = to_numpy(result.policy)
+            self.data.update(state=state, policy=result.get('policy', Batch()))
             # save hidden state to policy._state, in order to save into buffer
             self.data.policy._state = self.data.state
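
The removed branch forced every ``result.policy`` through ``to_numpy`` before storing it; the new ``Batch.update`` call keeps whatever the policy returned and defaults to an empty ``Batch``. A minimal sketch of the new behaviour outside the collector, where ``log_prob`` is a hypothetical policy output::

    from tianshou.data import Batch

    # hypothetical policy output: an actor that also reports a log-prob
    result = Batch(act=[0], state=None, policy=Batch(log_prob=[-0.3]))

    data = Batch()
    state = result.get('state', Batch())
    if state is None:  # None is reserved for 0-init
        state = Batch()
    # keep the policy output as-is instead of forcing it into numpy
    data.update(state=state, policy=result.get('policy', Batch()))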

View File

@@ -86,9 +86,8 @@ class BasePolicy(ABC, nn.Module):
         # some code
         return Batch(logits=..., act=..., state=None, dist=...)

-    After version >= 0.2.3, the keyword "policy" is reserverd and the
-    corresponding data will be stored into the replay buffer in numpy. For
-    instance,
+    The keyword ``policy`` is reserved and the corresponding data will be
+    stored into the replay buffer in numpy. For instance,
     ::

         # some code
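
For concreteness, a sketch of a subclass that uses the reserved key; the class and field names are illustrative only::

    import torch
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        def forward(self, batch, state=None, **kwargs):
            logits = torch.randn(len(batch.obs), 2)  # stand-in for a network
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # everything under "policy" is saved into the replay buffer
            # together with the transition and comes back when sampled
            return Batch(act=act, state=state,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # batch.policy.log_prob is restored from the buffer here
            return {}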

View File

@@ -40,9 +40,9 @@ def onpolicy_trainer(
         process might be finished before reaching the ``max_epoch``.
     :param int step_per_epoch: the number of step for updating policy network
         in one epoch.
-    :param int collect_per_step: the number of frames the collector would
-        collect before the network update. In other words, collect some frames
-        and do one policy network update.
+    :param int collect_per_step: the number of episodes the collector would
+        collect before the network update. In other words, collect some
+        episodes and do one policy network update.
     :param int repeat_per_collect: the number of repeat time for policy
         learning, for example, set it to 2 means the policy needs to learn each
         given batch data twice.
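
With the clarified semantics, the new test-script defaults translate into a call along these lines; this is a sketch assuming the 0.2.x positional signature, with ``policy``, the collectors, and ``args.test_num`` (episodes per test) taken from the surrounding script::

    result = onpolicy_trainer(
        policy, train_collector, test_collector,
        args.epoch, args.step_per_epoch,
        args.collect_per_step,    # 10 episodes collected per update round
        args.repeat_per_collect,  # learn each collected batch this many times
        args.test_num, args.batch_size)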