diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index dce3379..772c9d1 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -115,7 +115,10 @@ class AgentFactory(ABC, ToStringMixin): train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, envs.test_envs) if self.sampling_config.start_timesteps > 0: - train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True) + train_collector.collect( + n_step=self.sampling_config.start_timesteps, + random=self.sampling_config.start_timesteps_random, + ) return train_collector, test_collector def set_policy_wrapper_factory(