bug fix when using multiple environments

This commit is contained in:
NM512 2023-04-03 08:00:16 +09:00
parent 942eae10a9
commit 8bd69bfcd4

View File

@ -312,12 +312,13 @@ def main(config):
print(f"Prefill dataset ({prefill} steps).") print(f"Prefill dataset ({prefill} steps).")
if hasattr(acts, "discrete"): if hasattr(acts, "discrete"):
random_actor = tools.OneHotDist( random_actor = tools.OneHotDist(
torch.zeros_like(torch.Tensor(acts.low))[None] torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1)
) )
else: else:
random_actor = torchd.independent.Independent( random_actor = torchd.independent.Independent(
torchd.uniform.Uniform( torchd.uniform.Uniform(
torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None] torch.Tensor(acts.low).repeat(config.envs, 1),
torch.Tensor(acts.high).repeat(config.envs, 1),
), ),
1, 1,
) )