diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py index 2fb1c50..09be2d4 100644 --- a/dreamer4/mocks.py +++ b/dreamer4/mocks.py @@ -73,8 +73,10 @@ class MockEnv(Module): out = (state, reward) if self.can_terminate: + shape = (self.num_envs,) if self.vectorized else (1,) + terminate = ( - (torch.rand((self.num_envs)) < self.rand_terminate_prob) & + (torch.rand(shape) < self.rand_terminate_prob) & (self._step > self.terminate_after_step) )