diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py
index 800cc2d..31e439d 100644
--- a/examples/pong_a2c.py
+++ b/examples/pong_a2c.py
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py
index 991d3d4..5202a38 100644
--- a/examples/pong_ppo.py
+++ b/examples/pong_ppo.py
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index fad5e12..9105374 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -258,13 +258,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(self.data, last_state)
 
-            # convert None to Batch(), since None is reserved for 0-init
             state = result.get('state', Batch())
+            # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.state = state
-            if hasattr(result, 'policy'):
-                self.data.policy = to_numpy(result.policy)
+            self.data.update(state=state, policy=result.get('policy', Batch()))
             # save hidden state to policy._state, in order to save into buffer
             self.data.policy._state = self.data.state
 
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 4690749..9d7711c 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -86,9 +86,8 @@ class BasePolicy(ABC, nn.Module):
         # some code
         return Batch(logits=..., act=..., state=None, dist=...)
 
-    After version >= 0.2.3, the keyword "policy" is reserverd and the
-    corresponding data will be stored into the replay buffer in numpy. For
-    instance,
+    The keyword ``policy`` is reserved and the corresponding data will be
+    stored into the replay buffer in numpy. For instance,
     ::
 
         # some code
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 8f77dd1..dec42b2 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -40,9 +40,9 @@ def onpolicy_trainer(
         process might be finished before reaching the ``max_epoch``.
     :param int step_per_epoch: the number of step for updating policy network
         in one epoch.
-    :param int collect_per_step: the number of frames the collector would
-        collect before the network update. In other words, collect some frames
-        and do one policy network update.
+    :param int collect_per_step: the number of episodes the collector would
+        collect before the network update. In other words, collect some
+        episodes and do one policy network update.
     :param int repeat_per_collect: the number of repeat time for policy
         learning, for example, set it to 2 means the policy needs to learn
         each given batch data twice.
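
Note: a minimal sketch (not part of the diff) of the pattern the collector and BasePolicy changes above rely on: a policy forward may attach extra per-step data under the reserved ``policy`` key, and the collector now pulls it out with ``result.get('policy', Batch())`` and merges it via ``Batch.update``. The ``log_prob`` field and the toy shapes are illustrative assumptions, not code from this PR.

    # sketch only: how the reserved 'policy' key flows from a policy forward
    # into the collector's per-step data
    import numpy as np
    from tianshou.data import Batch

    # a toy forward() result following the BasePolicy docstring above;
    # 'log_prob' is a hypothetical example of extra data to store in the buffer
    result = Batch(
        logits=np.random.randn(4, 2),
        act=np.random.randint(0, 2, size=4),
        state=None,
        policy=Batch(log_prob=np.random.randn(4)),
    )

    data = Batch()
    # mirror the updated collector: a None state becomes an empty Batch, and
    # 'policy' falls back to an empty Batch when the forward omits it
    state = result.get('state', Batch())
    if state is None:
        state = Batch()
    data.update(state=state, policy=result.get('policy', Batch()))

    print(data.policy.log_prob.shape)  # (4,)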