bug fix when using multiple environments

This commit is contained in:
NM512 2023-04-03 08:00:16 +09:00
parent 942eae10a9
commit 8bd69bfcd4

View File

@ -312,12 +312,13 @@ def main(config):
print(f"Prefill dataset ({prefill} steps).")
if hasattr(acts, "discrete"):
random_actor = tools.OneHotDist(
torch.zeros_like(torch.Tensor(acts.low))[None]
torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1)
)
else:
random_actor = torchd.independent.Independent(
torchd.uniform.Uniform(
torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None]
torch.Tensor(acts.low).repeat(config.envs, 1),
torch.Tensor(acts.high).repeat(config.envs, 1),
),
1,
)