From 41ab83f69115406aea3a3033c3432c4aa1fe126f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 27 Oct 2025 10:47:24 -0700 Subject: [PATCH] fix mock --- dreamer4/mocks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py index 2fb1c50..09be2d4 100644 --- a/dreamer4/mocks.py +++ b/dreamer4/mocks.py @@ -73,8 +73,10 @@ class MockEnv(Module): out = (state, reward) if self.can_terminate: + shape = (self.num_envs,) if self.vectorized else (1,) + terminate = ( - (torch.rand((self.num_envs)) < self.rand_terminate_prob) & + (torch.rand(shape) < self.rand_terminate_prob) & (self._step > self.terminate_after_step) )