diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py index bae22ed..55de215 100644 --- a/dreamer4/mocks.py +++ b/dreamer4/mocks.py @@ -14,14 +14,14 @@ class MockEnv(Module): self, image_shape, reward_range = (-100., 100.), - batch_size = 1, + num_envs = 1, vectorized = False ): super().__init__() self.image_shape = image_shape self.reward_range = reward_range - self.batch_size = batch_size + self.num_envs = num_envs self.vectorized = vectorized self.register_buffer('_step', tensor(0)) @@ -42,8 +42,12 @@ class MockEnv(Module): state = self.get_random_state() reward = randint(*self.reward_range, ()).float() - if self.vectorized: - state = repeat(state, '... -> b ...', b = self.batch_size) - reward = repeat(rewardstate, ' -> b', b = self.batch_size) + if not self.vectorized: + return state, reward - return state, reward + assert actions.shape[0] == self.num_envs + + state = repeat(state, '... -> b ...', b = self.num_envs) + reward = repeat(reward, ' -> b', b = self.num_envs) + + return state, rewards