bug fix when using multiple environments
This commit is contained in:
parent
942eae10a9
commit
8bd69bfcd4
@ -312,12 +312,13 @@ def main(config):
|
|||||||
print(f"Prefill dataset ({prefill} steps).")
|
print(f"Prefill dataset ({prefill} steps).")
|
||||||
if hasattr(acts, "discrete"):
|
if hasattr(acts, "discrete"):
|
||||||
random_actor = tools.OneHotDist(
|
random_actor = tools.OneHotDist(
|
||||||
torch.zeros_like(torch.Tensor(acts.low))[None]
|
torch.zeros_like(torch.Tensor(acts.low)).repeat(config.envs, 1)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random_actor = torchd.independent.Independent(
|
random_actor = torchd.independent.Independent(
|
||||||
torchd.uniform.Uniform(
|
torchd.uniform.Uniform(
|
||||||
torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None]
|
torch.Tensor(acts.low).repeat(config.envs, 1),
|
||||||
|
torch.Tensor(acts.high).repeat(config.envs, 1),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user