doc update and do not force save 'policy' in np format (#168)
commit b7a4015db7 · parent e024afab8c
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
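Both files lower the collect-per-step default from 100 to 10; as the trainer docstring fix further down spells out, this value counts episodes gathered per network update in the on-policy trainer. A minimal, self-contained sketch of how such a parser is typically consumed — only the argument names and defaults come from the hunks above, the rest is illustrative::

    import argparse

    def get_args():
        # defaults mirror the updated values in the second hunk above
        parser = argparse.ArgumentParser()
        parser.add_argument('--gamma', type=float, default=0.99)
        parser.add_argument('--epoch', type=int, default=100)
        parser.add_argument('--step-per-epoch', type=int, default=1000)
        parser.add_argument('--collect-per-step', type=int, default=10)
        parser.add_argument('--repeat-per-collect', type=int, default=2)
        parser.add_argument('--batch-size', type=int, default=64)
        parser.add_argument('--layer-num', type=int, default=1)
        return parser.parse_known_args()[0]

    args = get_args()
    # argparse exposes '--collect-per-step' as the attribute collect_per_step
    print(args.collect_per_step)  # -> 10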
@@ -258,13 +258,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(self.data, last_state)
 
-            # convert None to Batch(), since None is reserved for 0-init
             state = result.get('state', Batch())
+            # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.state = state
-            if hasattr(result, 'policy'):
-                self.data.policy = to_numpy(result.policy)
+            self.data.update(state=state, policy=result.get('policy', Batch()))
             # save hidden state to policy._state, in order to save into buffer
             self.data.policy._state = self.data.state
 
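This is the behavioral half of the commit: whatever the policy returns under the ``policy`` key is now stored as returned, with an empty ``Batch()`` as the fallback, instead of being forced through ``to_numpy`` at collect time. A small sketch of the new update logic using tianshou's ``Batch`` directly — the contents of ``result`` and ``data`` are made up for illustration::

    from tianshou.data import Batch

    # stand-in for a policy forward() output; the 'policy' entry is optional
    result = Batch(act=[1, 0], policy=Batch(log_prob=[-0.2, -1.3]))
    data = Batch(obs=[[0.0], [1.0]])

    # mirror the updated collector logic: a missing or None state becomes an
    # empty Batch, and 'policy' is stored as returned (no forced to_numpy)
    state = result.get('state', Batch())
    if state is None:
        state = Batch()
    data.update(state=state, policy=result.get('policy', Batch()))

    # the hidden state rides along under policy._state into the replay buffer
    data.policy._state = data.state
    print(data.policy)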
@@ -86,9 +86,8 @@ class BasePolicy(ABC, nn.Module):
         # some code
         return Batch(logits=..., act=..., state=None, dist=...)
 
-    After version >= 0.2.3, the keyword "policy" is reserverd and the
-    corresponding data will be stored into the replay buffer in numpy. For
-    instance,
+    The keyword ``policy`` is reserved and the corresponding data will be
+    stored into the replay buffer in numpy. For instance,
     ::
 
         # some code
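To make the reserved key concrete: anything a policy's ``forward`` places under ``policy`` is carried by the collector into the replay buffer and comes back as numpy when sampled. A toy sketch — the class, the placeholder network output and the ``log_prob`` key are illustrative, not part of the commit::

    import torch
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        """Toy policy: stashes per-step log-probabilities under 'policy'."""

        def forward(self, batch, state=None, **kwargs):
            logits = torch.zeros(len(batch.obs), 2)  # placeholder network output
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # 'policy' is the reserved key; its contents reach the replay buffer
            return Batch(act=act, state=state,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # when sampled back from the buffer, batch.policy.log_prob is numpy
            old_log_prob = batch.policy.log_prob
            return {}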
@@ -40,9 +40,9 @@ def onpolicy_trainer(
         process might be finished before reaching the ``max_epoch``.
     :param int step_per_epoch: the number of step for updating policy network
         in one epoch.
-    :param int collect_per_step: the number of frames the collector would
-        collect before the network update. In other words, collect some frames
-        and do one policy network update.
+    :param int collect_per_step: the number of episodes the collector would
+        collect before the network update. In other words, collect some
+        episodes and do one policy network update.
     :param int repeat_per_collect: the number of repeat time for policy
         learning, for example, set it to 2 means the policy needs to learn each
         given batch data twice.
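Putting the three documented parameters together (ignoring early stopping), this is roughly the bookkeeping they imply per epoch. The numbers reuse the defaults from the argument hunks above, and the arithmetic is only a sketch of the documented semantics, not of the trainer internals::

    step_per_epoch = 1000     # policy-network update steps in one epoch
    collect_per_step = 10     # episodes collected before each update
    repeat_per_collect = 2    # each collected batch is learned twice

    episodes_per_epoch = step_per_epoch * collect_per_step
    learn_passes_per_epoch = step_per_epoch * repeat_per_collect
    print(episodes_per_epoch)      # at most 10000 episodes per epoch
    print(learn_passes_per_epoch)  # at most 2000 learning passes per epoch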