doc update and do not force save 'policy' in np format (#168)
parent e024afab8c
commit b7a4015db7
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
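Note: both test scripts now collect 10 units per network update instead of 100; the onpolicy_trainer docstring updated later in this commit counts collect_per_step in episodes for on-policy algorithms. A rough sketch of how these parsed arguments are typically passed to the trainer in Tianshou 0.2.x follows; the collector/policy setup is omitted and names such as args.test_num are assumptions, not part of this diff:

    from tianshou.trainer import onpolicy_trainer

    args = get_args()
    # policy, train_collector, test_collector are assumed to be built elsewhere
    result = onpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        collect_per_step=args.collect_per_step,     # episodes per collect for on-policy
        repeat_per_collect=args.repeat_per_collect,
        episode_per_test=args.test_num,             # assumed argument, not shown above
        batch_size=args.batch_size,
    )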
@@ -258,13 +258,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(self.data, last_state)

+            # convert None to Batch(), since None is reserved for 0-init
             state = result.get('state', Batch())
-            # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.state = state
-            if hasattr(result, 'policy'):
-                self.data.policy = to_numpy(result.policy)
+            self.data.update(state=state, policy=result.get('policy', Batch()))
             # save hidden state to policy._state, in order to save into buffer
             self.data.policy._state = self.data.state
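In effect, the collector no longer forces result.policy through to_numpy; it keeps whatever Batch the policy returned under the reserved ``policy`` key (falling back to an empty Batch) and stashes the hidden state in policy._state so it is saved into the buffer together with the rest. A toy illustration of the resulting structure, with made-up field values:

    import torch
    from tianshou.data import Batch

    # Pretend this is what self.policy(self.data, last_state) returned.
    result = Batch(act=torch.tensor([1]),
                   state=torch.zeros(1, 4),                      # RNN hidden state
                   policy=Batch(log_prob=torch.tensor([-0.7])))  # extra per-step data

    data = Batch()
    # convert None to Batch(), since None is reserved for 0-init
    state = result.get('state', Batch())
    if state is None:
        state = Batch()
    data.update(state=state, policy=result.get('policy', Batch()))
    data.policy._state = data.state      # hidden state rides along under `policy`

    print(data.policy.log_prob)          # still a torch.Tensor at this point;
    print(data.policy._state.shape)      # numpy conversion is left to the buffer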
@@ -86,9 +86,8 @@ class BasePolicy(ABC, nn.Module):
             # some code
             return Batch(logits=..., act=..., state=None, dist=...)

-    After version >= 0.2.3, the keyword "policy" is reserverd and the
-    corresponding data will be stored into the replay buffer in numpy. For
-    instance,
+    The keyword ``policy`` is reserved and the corresponding data will be
+    stored into the replay buffer in numpy. For instance,
     ::

         # some code
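For context, a minimal sketch of a forward() that uses the reserved ``policy`` key as described; the actor network and the categorical distribution are illustrative assumptions, not part of this commit:

    def forward(self, batch, state=None, **kwargs):
        # `self.actor` is an assumed network returning action logits and a
        # new hidden state.
        logits, hidden = self.actor(batch.obs, state=state)
        dist = torch.distributions.Categorical(logits=logits)
        act = dist.sample()
        # Anything placed under `policy` is saved per-step and comes back from
        # the replay buffer (as numpy) when the batch is sampled in learn(),
        # e.g. batch.policy.log_prob.
        return Batch(logits=logits, act=act, state=hidden,
                     policy=Batch(log_prob=dist.log_prob(act)))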
@@ -40,9 +40,9 @@ def onpolicy_trainer(
         process might be finished before reaching the ``max_epoch``.
     :param int step_per_epoch: the number of step for updating policy network
         in one epoch.
-    :param int collect_per_step: the number of frames the collector would
-        collect before the network update. In other words, collect some frames
-        and do one policy network update.
+    :param int collect_per_step: the number of episodes the collector would
+        collect before the network update. In other words, collect some
+        episodes and do one policy network update.
     :param int repeat_per_collect: the number of repeat time for policy
         learning, for example, set it to 2 means the policy needs to learn each
         given batch data twice.