This commit is contained in:
lucidrains 2025-10-27 10:47:24 -07:00
parent 995b1f64e5
commit 41ab83f691

View File

@ -73,8 +73,10 @@ class MockEnv(Module):
out = (state, reward)
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
terminate = (
(torch.rand((self.num_envs)) < self.rand_terminate_prob) &
(torch.rand(shape) < self.rand_terminate_prob) &
(self._step > self.terminate_after_step)
)