diff --git a/dreamer.py b/dreamer.py index 9bdfcd0..591994b 100644 --- a/dreamer.py +++ b/dreamer.py @@ -312,12 +312,13 @@ def main(config): print(f"Prefill dataset ({prefill} steps).") if hasattr(acts, "discrete"): random_actor = tools.OneHotDist( - torch.zeros_like(torch.Tensor(acts.low))[None] + torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1) ) else: random_actor = torchd.independent.Independent( torchd.uniform.Uniform( - torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None] + torch.Tensor(acts.low).repeat(config.envs, 1), + torch.Tensor(acts.high).repeat(config.envs, 1), ), 1, )