diff --git a/test/base/test_env.py b/test/base/test_env.py index 432fdd1..6697b47 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -40,7 +40,12 @@ def test_async_env(num=8, sleep=0.1): env_ids = b.info.env_id o.append(b) current_index_start += len(action) + # len of action may be smaller than len(A) in the end action = action_list[current_index_start: current_index_start + len(A)] + # truncate env_ids with the first terms + # typically len(env_ids) == len(A) == len(action), except for the + # last batch when actions are not enough + env_ids = env_ids[: len(action)] spent_time = time.time() - spent_time data = Batch.cat(o) # assure 1/7 improvement