From c25926dd8f5b6179f7f76486ee228982f48b4469 Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Sun, 30 Jan 2022 00:53:56 +0800 Subject: [PATCH] Formalize variable names (#509) Co-authored-by: Jiayi Weng --- docs/tutorials/cheatsheet.rst | 10 +- docs/tutorials/concepts.rst | 20 +- examples/atari/atari_network.py | 38 +- examples/box2d/bipedal_hardcore_sac.py | 8 +- test/base/env.py | 6 +- test/base/test_buffer.py | 14 +- test/base/test_collector.py | 20 +- test/base/test_env.py | 10 +- test/base/test_returns.py | 8 +- test/multiagent/Gomoku.py | 6 +- tianshou/data/batch.py | 405 +++++++++++----------- tianshou/data/buffer/base.py | 14 +- tianshou/data/buffer/manager.py | 6 +- tianshou/data/utils/converter.py | 8 +- tianshou/env/maenv.py | 4 +- tianshou/env/venvs.py | 4 +- tianshou/policy/base.py | 10 +- tianshou/policy/imitation/base.py | 20 +- tianshou/policy/imitation/bcq.py | 25 +- tianshou/policy/imitation/cql.py | 9 +- tianshou/policy/imitation/discrete_bcq.py | 7 +- tianshou/policy/imitation/discrete_cql.py | 8 +- tianshou/policy/imitation/discrete_crr.py | 8 +- tianshou/policy/modelbased/psrl.py | 10 +- tianshou/policy/modelfree/a2c.py | 19 +- tianshou/policy/modelfree/c51.py | 10 +- tianshou/policy/modelfree/ddpg.py | 15 +- tianshou/policy/modelfree/discrete_sac.py | 4 +- tianshou/policy/modelfree/dqn.py | 14 +- tianshou/policy/modelfree/fqf.py | 31 +- tianshou/policy/modelfree/iqn.py | 19 +- tianshou/policy/modelfree/npg.py | 28 +- tianshou/policy/modelfree/pg.py | 14 +- tianshou/policy/modelfree/ppo.py | 37 +- tianshou/policy/modelfree/qrdqn.py | 18 +- tianshou/policy/modelfree/sac.py | 28 +- tianshou/policy/modelfree/td3.py | 19 +- tianshou/policy/modelfree/trpo.py | 24 +- tianshou/utils/net/common.py | 61 ++-- tianshou/utils/net/continuous.py | 117 ++++--- tianshou/utils/net/discrete.py | 30 +- tianshou/utils/statistics.py | 22 +- 42 files changed, 607 insertions(+), 581 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index bb04ab5..a1feb29 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -350,7 +350,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y def reset(): return copy.deepcopy(self.graph) - def step(a): + def step(action): ... return copy.deepcopy(self.graph), reward, done, {} @@ -391,13 +391,13 @@ In addition, legal actions in multi-agent RL often vary with timestep (just like The above description gives rise to the following formulation of multi-agent RL: :: - action = policy(state, agent_id, mask) - (next_state, next_agent_id, next_mask), reward = env.step(action) + act = policy(state, agent_id, mask) + (next_state, next_agent_id, next_mask), reward = env.step(act) By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL: :: - action = policy(state_) - next_state_, reward = env.step(action) + act = policy(state_) + next_state_, reward = env.step(act) Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index b1f76de..d500787 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -298,14 +298,14 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. 
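As a complement to the pseudo-code below, here is a minimal NumPy sketch of how such a 2-step return can be assembled by hand. The names ``two_step_return``, ``rew``, ``done`` and ``target_q`` are hypothetical helpers for illustration, not Tianshou's API; bootstrapping is cut off as soon as ``done`` is encountered::

    import numpy as np

    def two_step_return(rew, done, target_q, i, gamma=0.99):
        # r_t + gamma * r_{t+1} + gamma^2 * Q_target, truncated at episode end
        ret = rew[i]
        if done[i]:
            return ret
        ret += gamma * rew[i + 1]
        if done[i + 1]:
            return ret
        return ret + gamma ** 2 * target_q

    rew = np.array([1.0, 1.0, 1.0, 1.0])
    done = np.array([False, False, True, False])
    print(two_step_return(rew, done, target_q=0.5, i=0))
    # 1 + 0.99 * 1 + 0.99 ** 2 * 0.5 = 2.48005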
Here is :: # pseudocode, cannot work - s = env.reset() + obs = env.reset() buffer = Buffer(size=10000) agent = DQN() for i in range(int(1e6)): - a = agent.compute_action(s) - s_, r, d, _ = env.step(a) - buffer.store(s, a, s_, r, d) - s = s_ + act = agent.compute_action(obs) + obs_next, rew, done, _ = env.step(act) + buffer.store(obs, act, obs_next, rew, done) + obs = obs_next if i % 1000 == 0: b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # compute 2-step returns. How? @@ -390,14 +390,14 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr :: # pseudocode, cannot work # methods in tianshou - s = env.reset() + obs = env.reset() buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000) agent = DQN() # policy.__init__(...) for i in range(int(1e6)): # done in trainer - a = agent.compute_action(s) # act = policy(batch, ...).act - s_, r, d, _ = env.step(a) # collector.collect(...) - buffer.store(s, a, s_, r, d) # collector.collect(...) - s = s_ # collector.collect(...) + act = agent.compute_action(obs) # act = policy(batch, ...).act + obs_next, rew, done, _ = env.step(act) # collector.collect(...) + buffer.store(obs, act, obs_next, rew, done) # collector.collect(...) + obs = obs_next # collector.collect(...) if i % 1000 == 0: # done in trainer # the following is done in policy.update(batch_size, buffer) b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 4598fce..e12dee0 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -42,13 +42,13 @@ class DQN(nn.Module): def forward( self, - x: Union[np.ndarray, torch.Tensor], + obs: Union[np.ndarray, torch.Tensor], state: Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: - r"""Mapping: x -> Q(x, \*).""" - x = torch.as_tensor(x, device=self.device, dtype=torch.float32) - return self.net(x), state + r"""Mapping: s -> Q(s, \*).""" + obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + return self.net(obs), state class C51(DQN): @@ -73,15 +73,15 @@ class C51(DQN): def forward( self, - x: Union[np.ndarray, torch.Tensor], + obs: Union[np.ndarray, torch.Tensor], state: Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - x, state = super().forward(x) - x = x.view(-1, self.num_atoms).softmax(dim=-1) - x = x.view(-1, self.action_num, self.num_atoms) - return x, state + obs, state = super().forward(obs) + obs = obs.view(-1, self.num_atoms).softmax(dim=-1) + obs = obs.view(-1, self.action_num, self.num_atoms) + return obs, state class Rainbow(DQN): @@ -127,22 +127,22 @@ class Rainbow(DQN): def forward( self, - x: Union[np.ndarray, torch.Tensor], + obs: Union[np.ndarray, torch.Tensor], state: Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - x, state = super().forward(x) - q = self.Q(x) + obs, state = super().forward(obs) + q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) if self._is_dueling: - v = self.V(x) + v = self.V(obs) v = v.view(-1, 1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v else: logits = q - y = logits.softmax(dim=2) - return y, state + probs = logits.softmax(dim=2) + return probs, state class QRDQN(DQN): @@ -168,11 +168,11 @@ class QRDQN(DQN): def forward( self, - x: Union[np.ndarray, torch.Tensor], + obs: Union[np.ndarray, torch.Tensor], state: 
Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - x, state = super().forward(x) - x = x.view(-1, self.action_num, self.num_quantiles) - return x, state + obs, state = super().forward(obs) + obs = obs.view(-1, self.action_num, self.num_quantiles) + return obs, state diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 598622d..1d2d7f1 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -56,16 +56,16 @@ class Wrapper(gym.Wrapper): self.rm_done = rm_done def step(self, action): - r = 0.0 + rew_sum = 0.0 for _ in range(self.action_repeat): - obs, reward, done, info = self.env.step(action) + obs, rew, done, info = self.env.step(action) # remove done reward penalty if not done or not self.rm_done: - r = r + reward + rew_sum = rew_sum + rew if done: break # scale reward - return obs, self.reward_scale * r, done, info + return obs, self.reward_scale * rew_sum, done, info def test_sac_bipedal(args=get_args()): diff --git a/test/base/env.py b/test/base/env.py index cdcb51e..3ae031c 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -86,10 +86,10 @@ class MyTestEnv(gym.Env): def _get_reward(self): """Generate a non-scalar reward if ma_rew is True.""" - x = int(self.done) + end_flag = int(self.done) if self.ma_rew > 0: - return [x] * self.ma_rew - return x + return [end_flag] * self.ma_rew + return end_flag def _get_state(self): """Generate state(observation) of MyTestEnv""" diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index c1568c7..83ef975 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -32,10 +32,12 @@ def test_replaybuffer(size=10, bufsize=20): assert str(buf) == buf.__class__.__name__ + '()' obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 - for i, a in enumerate(action_list): - obs_next, rew, done, info = env.step(a) + for i, act in enumerate(action_list): + obs_next, rew, done, info = env.step(act) buf.add( - Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info) + Batch( + obs=obs, act=[act], rew=rew, done=done, obs_next=obs_next, info=info + ) ) obs = obs_next assert len(buf) == min(bufsize, i + 1) @@ -220,11 +222,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15): buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 - for i, a in enumerate(action_list): - obs_next, rew, done, info = env.step(a) + for i, act in enumerate(action_list): + obs_next, rew, done, info = env.step(act) batch = Batch( obs=obs, - act=a, + act=act, rew=rew, done=done, obs_next=obs_next, diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 61bd5a6..9a8d749 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -331,20 +331,20 @@ def test_collector_with_ma(): policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode - r = c0.collect(n_step=3)['rews'] - assert len(r) == 0 - r = c0.collect(n_episode=2)['rews'] - assert r.shape == (2, 4) and np.all(r == 1) + rew = c0.collect(n_step=3)['rews'] + assert len(rew) == 0 + rew = c0.collect(n_episode=2)['rews'] + assert rew.shape == (2, 4) and np.all(rew == 1) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c1 = Collector( policy, envs, 
VectorReplayBuffer(total_size=100, buffer_num=4), Logger.single_preprocess_fn ) - r = c1.collect(n_step=12)['rews'] - assert r.shape == (2, 4) and np.all(r == 1), r - r = c1.collect(n_episode=8)['rews'] - assert r.shape == (8, 4) and np.all(r == 1) + rew = c1.collect(n_step=12)['rews'] + assert rew.shape == (2, 4) and np.all(rew == 1), rew + rew = c1.collect(n_episode=8)['rews'] + assert rew.shape == (8, 4) and np.all(rew == 1) batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) @@ -446,8 +446,8 @@ def test_collector_with_ma(): policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), Logger.single_preprocess_fn ) - r = c2.collect(n_episode=10)['rews'] - assert r.shape == (10, 4) and np.all(r == 1) + rew = c2.collect(n_episode=10)['rews'] + assert rew.shape == (10, 4) and np.all(rew == 1) batch, _ = c2.buffer.sample(10) diff --git a/test/base/test_env.py b/test/base/test_env.py index dbd651d..f0471e7 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -58,22 +58,22 @@ def test_async_env(size=10000, num=8, sleep=0.1): # should be smaller action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4) current_idx_start = 0 - action = action_list[:num] + act = action_list[:num] env_ids = list(range(num)) o = [] spent_time = time.time() while current_idx_start < len(action_list): - A, B, C, D = v.step(action=action, id=env_ids) + A, B, C, D = v.step(action=act, id=env_ids) b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D}) env_ids = b.info.env_id o.append(b) - current_idx_start += len(action) + current_idx_start += len(act) # len of action may be smaller than len(A) in the end - action = action_list[current_idx_start:current_idx_start + len(A)] + act = action_list[current_idx_start:current_idx_start + len(A)] # truncate env_ids with the first terms # typically len(env_ids) == len(A) == len(action), except for the # last batch when actions are not enough - env_ids = env_ids[:len(action)] + env_ids = env_ids[:len(act)] spent_time = time.time() - spent_time Batch.cat(o) v.close() diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 3adcdaf..797c03c 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -142,11 +142,11 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): returns = np.zeros_like(indices, dtype=float) buf_len = len(buffer) for i in range(len(indices)): - flag, r = False, 0. + flag, rew = False, 0. 
real_step_n = nstep for n in range(nstep): idx = (indices[i] + n) % buf_len - r += buffer.rew[idx] * gamma**n + rew += buffer.rew[idx] * gamma**n if buffer.done[idx]: if not ( hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx] @@ -156,8 +156,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): break if not flag: idx = (indices[i] + real_step_n - 1) % buf_len - r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n - returns[i] = r + rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n + returns[i] = rew return returns diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 7d17542..91d3cbc 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -41,7 +41,7 @@ def gomoku(args=get_args()): return TicTacToeEnv(args.board_size, args.win_size) test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) - for r in range(args.self_play_round): + for round in range(args.self_play_round): rews = [] agent_learn.set_eps(0.0) # compute the reward over previous learner @@ -66,12 +66,12 @@ def gomoku(args=get_args()): # previous learner can only be used for forward agent.forward = opponent.forward args.model_save_path = os.path.join( - args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth' + args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth' ) result, agent_learn = train_agent( args, agent_learn=agent_learn, agent_opponent=agent, optim=optim ) - print(f'round_{r}_epoch_{epoch}') + print(f'round_{round}_epoch_{epoch}') pprint.pprint(result) learnt_agent = deepcopy(agent_learn) learnt_agent.set_eps(0.0) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 98adf68..7fddda2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -11,17 +11,18 @@ import torch IndexType = Union[slice, int, np.ndarray, List[int]] -def _is_batch_set(data: Any) -> bool: +def _is_batch_set(obj: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, # or 1-D np.ndarray with object type, # where each element is a dict/Batch object - if isinstance(data, np.ndarray): # most often case - # "for e in data" will just unpack the first dimension, - # but data.tolist() will flatten ndarray of objects - # so do not use data.tolist() - return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data) - elif isinstance(data, (list, tuple)): - if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): + if isinstance(obj, np.ndarray): # most often case + # "for element in obj" will just unpack the first dimension, + # but obj.tolist() will flatten ndarray of objects + # so do not use obj.tolist() + return obj.dtype == object and \ + all(isinstance(element, (dict, Batch)) for element in obj) + elif isinstance(obj, (list, tuple)): + if len(obj) > 0 and all(isinstance(element, (dict, Batch)) for element in obj): return True return False @@ -48,28 +49,29 @@ def _is_number(value: Any) -> bool: return isinstance(value, (Number, np.number, np.bool_)) -def _to_array_with_correct_type(v: Any) -> np.ndarray: - if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)): - return v # most often case +def _to_array_with_correct_type(obj: Any) -> np.ndarray: + if isinstance(obj, np.ndarray) and \ + issubclass(obj.dtype.type, (np.bool_, np.number)): + return obj # most often case # convert the value to np.ndarray - # convert to object data type if neither bool nor number + # convert to object obj type if neither bool nor number # raises an exception if 
array's elements are tensors themselves - v = np.asanyarray(v) - if not issubclass(v.dtype.type, (np.bool_, np.number)): - v = v.astype(object) - if v.dtype == object: - # scalar ndarray with object data type is very annoying + obj_array = np.asanyarray(obj) + if not issubclass(obj_array.dtype.type, (np.bool_, np.number)): + obj_array = obj_array.astype(object) + if obj_array.dtype == object: + # scalar ndarray with object obj type is very annoying # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a is not array([{}, {}], dtype=object), and a[0]={} results in # something very strange: # array([{}, array({}, dtype=object)], dtype=object) - if not v.shape: - v = v.item(0) - elif all(isinstance(e, np.ndarray) for e in v.reshape(-1)): - return v # various length, np.array([[1], [2, 3], [4, 5, 6]]) - elif any(isinstance(e, torch.Tensor) for e in v.reshape(-1)): + if not obj_array.shape: + obj_array = obj_array.item(0) + elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)): + return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]]) + elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)): raise ValueError("Numpy arrays of tensors are not supported yet.") - return v + return obj_array def _create_value( @@ -113,44 +115,45 @@ def _create_value( def _assert_type_keys(keys: Iterable[str]) -> None: - assert all(isinstance(e, str) for e in keys), \ + assert all(isinstance(key, str) for key in keys), \ f"keys should all be string, but got {keys}" -def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: - if isinstance(v, Batch): # most often case - return v - elif (isinstance(v, np.ndarray) and - issubclass(v.dtype.type, (np.bool_, np.number))) or \ - isinstance(v, torch.Tensor) or v is None: # third often case - return v - elif _is_number(v): # second often case, but it is more time-consuming - return np.asanyarray(v) - elif isinstance(v, dict): - return Batch(v) +def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: + if isinstance(obj, Batch): # most often case + return obj + elif (isinstance(obj, np.ndarray) and + issubclass(obj.dtype.type, (np.bool_, np.number))) or \ + isinstance(obj, torch.Tensor) or obj is None: # third often case + return obj + elif _is_number(obj): # second often case, but it is more time-consuming + return np.asanyarray(obj) + elif isinstance(obj, dict): + return Batch(obj) else: - if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ - len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): + if not isinstance(obj, np.ndarray) and \ + isinstance(obj, Collection) and len(obj) > 0 and \ + all(isinstance(element, torch.Tensor) for element in obj): try: - return torch.stack(v) # type: ignore - except RuntimeError as e: + return torch.stack(obj) # type: ignore + except RuntimeError as exception: raise TypeError( "Batch does not support non-stackable iterable" " of torch.Tensor as unique value yet." - ) from e - if _is_batch_set(v): - v = Batch(v) # list of dict / Batch + ) from exception + if _is_batch_set(obj): + obj = Batch(obj) # list of dict / Batch else: - # None, scalar, normal data list (main case) + # None, scalar, normal obj list (main case) # or an actual list of objects try: - v = _to_array_with_correct_type(v) - except ValueError as e: + obj = _to_array_with_correct_type(obj) + except ValueError as exception: raise TypeError( "Batch does not support heterogeneous list/" "tuple of tensors as unique value yet." 
- ) from e - return v + ) from exception + return obj def _alloc_by_keys_diff( @@ -189,8 +192,8 @@ class Batch: if batch_dict is not None: if isinstance(batch_dict, (dict, Batch)): _assert_type_keys(batch_dict.keys()) - for k, v in batch_dict.items(): - self.__dict__[k] = _parse_value(v) + for batch_key, obj in batch_dict.items(): + self.__dict__[batch_key] = _parse_value(obj) elif _is_batch_set(batch_dict): self.stack_(batch_dict) # type: ignore if len(kwargs) > 0: @@ -214,10 +217,10 @@ class Batch: Only the actual data are serialized for both efficiency and simplicity. """ state = {} - for k, v in self.items(): - if isinstance(v, Batch): - v = v.__getstate__() - state[k] = v + for batch_key, obj in self.items(): + if isinstance(obj, Batch): + obj = obj.__getstate__() + state[batch_key] = obj return state def __setstate__(self, state: Dict[str, Any]) -> None: @@ -234,13 +237,13 @@ class Batch: return self.__dict__[index] batch_items = self.items() if len(batch_items) > 0: - b = Batch() - for k, v in batch_items: - if isinstance(v, Batch) and v.is_empty(): - b.__dict__[k] = Batch() + new_batch = Batch() + for batch_key, obj in batch_items: + if isinstance(obj, Batch) and obj.is_empty(): + new_batch.__dict__[batch_key] = Batch() else: - b.__dict__[k] = v[index] - return b + new_batch.__dict__[batch_key] = obj[index] + return new_batch else: raise IndexError("Cannot access item from empty Batch object.") @@ -273,20 +276,20 @@ class Batch: def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch": """Algebraic addition with another Batch instance in-place.""" if isinstance(other, Batch): - for (k, r), v in zip( + for (batch_key, obj), value in zip( self.__dict__.items(), other.__dict__.values() ): # TODO are keys consistent? - if isinstance(r, Batch) and r.is_empty(): + if isinstance(obj, Batch) and obj.is_empty(): continue else: - self.__dict__[k] += v + self.__dict__[batch_key] += value return self elif _is_number(other): - for k, r in self.items(): - if isinstance(r, Batch) and r.is_empty(): + for batch_key, obj in self.items(): + if isinstance(obj, Batch) and obj.is_empty(): continue else: - self.__dict__[k] += other + self.__dict__[batch_key] += other return self else: raise TypeError("Only addition of Batch or number is supported.") @@ -295,54 +298,54 @@ class Batch: """Algebraic addition with another Batch instance out-of-place.""" return deepcopy(self).__iadd__(other) - def __imul__(self, val: Union[Number, np.number]) -> "Batch": + def __imul__(self, value: Union[Number, np.number]) -> "Batch": """Algebraic multiplication with a scalar value in-place.""" - assert _is_number(val), "Only multiplication by a number is supported." - for k, r in self.__dict__.items(): - if isinstance(r, Batch) and r.is_empty(): + assert _is_number(value), "Only multiplication by a number is supported." 
+ for batch_key, obj in self.__dict__.items(): + if isinstance(obj, Batch) and obj.is_empty(): continue - self.__dict__[k] *= val + self.__dict__[batch_key] *= value return self - def __mul__(self, val: Union[Number, np.number]) -> "Batch": + def __mul__(self, value: Union[Number, np.number]) -> "Batch": """Algebraic multiplication with a scalar value out-of-place.""" - return deepcopy(self).__imul__(val) + return deepcopy(self).__imul__(value) - def __itruediv__(self, val: Union[Number, np.number]) -> "Batch": + def __itruediv__(self, value: Union[Number, np.number]) -> "Batch": """Algebraic division with a scalar value in-place.""" - assert _is_number(val), "Only division by a number is supported." - for k, r in self.__dict__.items(): - if isinstance(r, Batch) and r.is_empty(): + assert _is_number(value), "Only division by a number is supported." + for batch_key, obj in self.__dict__.items(): + if isinstance(obj, Batch) and obj.is_empty(): continue - self.__dict__[k] /= val + self.__dict__[batch_key] /= value return self - def __truediv__(self, val: Union[Number, np.number]) -> "Batch": + def __truediv__(self, value: Union[Number, np.number]) -> "Batch": """Algebraic division with a scalar value out-of-place.""" - return deepcopy(self).__itruediv__(val) + return deepcopy(self).__itruediv__(value) def __repr__(self) -> str: """Return str(self).""" - s = self.__class__.__name__ + "(\n" + self_str = self.__class__.__name__ + "(\n" flag = False - for k, v in self.__dict__.items(): - rpl = "\n" + " " * (6 + len(k)) - obj = pprint.pformat(v).replace("\n", rpl) - s += f" {k}: {obj},\n" + for batch_key, obj in self.__dict__.items(): + rpl = "\n" + " " * (6 + len(batch_key)) + obj_name = pprint.pformat(obj).replace("\n", rpl) + self_str += f" {batch_key}: {obj_name},\n" flag = True if flag: - s += ")" + self_str += ")" else: - s = self.__class__.__name__ + "()" - return s + self_str = self.__class__.__name__ + "()" + return self_str def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" - for k, v in self.items(): - if isinstance(v, torch.Tensor): - self.__dict__[k] = v.detach().cpu().numpy() - elif isinstance(v, Batch): - v.to_numpy() + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): + self.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj.to_numpy() def to_torch( self, @@ -353,24 +356,24 @@ class Batch: if not isinstance(device, torch.device): device = torch.device(device) - for k, v in self.items(): - if isinstance(v, torch.Tensor): - if dtype is not None and v.dtype != dtype or \ - v.device.type != device.type or \ - device.index != v.device.index: + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): + if dtype is not None and obj.dtype != dtype or \ + obj.device.type != device.type or \ + device.index != obj.device.index: if dtype is not None: - v = v.type(dtype) - self.__dict__[k] = v.to(device) - elif isinstance(v, Batch): - v.to_torch(dtype, device) + obj = obj.type(dtype) + self.__dict__[batch_key] = obj.to(device) + elif isinstance(obj, Batch): + obj.to_torch(dtype, device) else: # ndarray or scalar - if not isinstance(v, np.ndarray): - v = np.asanyarray(v) - v = torch.from_numpy(v).to(device) + if not isinstance(obj, np.ndarray): + obj = np.asanyarray(obj) + obj = torch.from_numpy(obj).to(device) if dtype is not None: - v = v.type(dtype) - self.__dict__[k] = v + obj = obj.type(dtype) + self.__dict__[batch_key] = obj def __cat(self, batches: Sequence[Union[dict, "Batch"]], 
lens: List[int]) -> None: """Private method for Batch.cat_. @@ -395,50 +398,51 @@ class Batch: # partial keys will be padded by zeros # with the shape of [len, rest_shape] sum_lens = [0] - for x in lens: - sum_lens.append(sum_lens[-1] + x) + for len_ in lens: + sum_lens.append(sum_lens[-1] + len_) # collect non-empty keys keys_map = [ set( - k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty()) + batch_key for batch_key, obj in batch.items() + if not (isinstance(obj, Batch) and obj.is_empty()) ) for batch in batches ] keys_shared = set.intersection(*keys_map) - values_shared = [[e[k] for e in batches] for k in keys_shared] - for k, v in zip(keys_shared, values_shared): - if all(isinstance(e, (dict, Batch)) for e in v): + values_shared = [[batch[key] for batch in batches] for key in keys_shared] + for key, shared_value in zip(keys_shared, values_shared): + if all(isinstance(element, (dict, Batch)) for element in shared_value): batch_holder = Batch() - batch_holder.__cat(v, lens=lens) - self.__dict__[k] = batch_holder - elif all(isinstance(e, torch.Tensor) for e in v): - self.__dict__[k] = torch.cat(v) + batch_holder.__cat(shared_value, lens=lens) + self.__dict__[key] = batch_holder + elif all(isinstance(element, torch.Tensor) for element in shared_value): + self.__dict__[key] = torch.cat(shared_value) else: # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) # will fail here - v = np.concatenate(v) - self.__dict__[k] = _to_array_with_correct_type(v) - keys_total = set.union(*[set(b.keys()) for b in batches]) + shared_value = np.concatenate(shared_value) + self.__dict__[key] = _to_array_with_correct_type(shared_value) + keys_total = set.union(*[set(batch.keys()) for batch in batches]) keys_reserve_or_partial = set.difference(keys_total, keys_shared) # keys that are reserved in all batches keys_reserve = set.difference(keys_total, set.union(*keys_map)) # keys that occur only in some batches, but not all keys_partial = keys_reserve_or_partial.difference(keys_reserve) - for k in keys_reserve: + for key in keys_reserve: # reserved keys - self.__dict__[k] = Batch() - for k in keys_partial: - for i, e in enumerate(batches): - if k not in e.__dict__: + self.__dict__[key] = Batch() + for key in keys_partial: + for i, batch in enumerate(batches): + if key not in batch.__dict__: continue - val = e.get(k) - if isinstance(val, Batch) and val.is_empty(): + value = batch.get(key) + if isinstance(value, Batch) and value.is_empty(): continue try: - self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value except KeyError: - self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False) - self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + self.__dict__[key] = \ + _create_value(value, sum_lens[-1], stack=False) + self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" @@ -446,16 +450,16 @@ class Batch: batches = [batches] # check input format batch_list = [] - for b in batches: - if isinstance(b, dict): - if len(b) > 0: - batch_list.append(Batch(b)) - elif isinstance(b, Batch): + for batch in batches: + if isinstance(batch, dict): + if len(batch) > 0: + batch_list.append(Batch(batch)) + elif isinstance(batch, Batch): # x.is_empty() means that x is Batch() and should be ignored - if not b.is_empty(): - batch_list.append(b) + if not batch.is_empty(): + 
batch_list.append(batch) else: - raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_") + raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_") if len(batch_list) == 0: return batches = batch_list @@ -463,13 +467,15 @@ class Batch: # x.is_empty(recurse=True) here means x is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and # keep it. - lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches] - except TypeError as e: + lens = [ + 0 if batch.is_empty(recurse=True) else len(batch) for batch in batches + ] + except TypeError as exception: raise ValueError( "Batch.cat_ meets an exception. Maybe because there is any " f"scalar in {batches} but Batch.cat_ does not support the " "concatenation of scalar." - ) from e + ) from exception if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens @@ -501,16 +507,16 @@ class Batch: """Stack a list of Batch object into current batch.""" # check input format batch_list = [] - for b in batches: - if isinstance(b, dict): - if len(b) > 0: - batch_list.append(Batch(b)) - elif isinstance(b, Batch): + for batch in batches: + if isinstance(batch, dict): + if len(batch) > 0: + batch_list.append(Batch(batch)) + elif isinstance(batch, Batch): # x.is_empty() means that x is Batch() and should be ignored - if not b.is_empty(): - batch_list.append(b) + if not batch.is_empty(): + batch_list.append(batch) else: - raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_") + raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_") if len(batch_list) == 0: return batches = batch_list @@ -519,28 +525,31 @@ class Batch: # collect non-empty keys keys_map = [ set( - k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty()) + batch_key for batch_key, obj in batch.items() + if not (isinstance(obj, Batch) and obj.is_empty()) ) for batch in batches ] keys_shared = set.intersection(*keys_map) - values_shared = [[e[k] for e in batches] for k in keys_shared] - for k, v in zip(keys_shared, values_shared): - if all(isinstance(e, torch.Tensor) for e in v): # second often - self.__dict__[k] = torch.stack(v, axis) - elif all(isinstance(e, (Batch, dict)) for e in v): # third often - self.__dict__[k] = Batch.stack(v, axis) + values_shared = [[batch[key] for batch in batches] for key in keys_shared] + for shared_key, value in zip(keys_shared, values_shared): + # second often + if all(isinstance(element, torch.Tensor) for element in value): + self.__dict__[shared_key] = torch.stack(value, axis) + # third often + elif all(isinstance(element, (Batch, dict)) for element in value): + self.__dict__[shared_key] = Batch.stack(value, axis) else: # most often case is np.ndarray try: - self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) + self.__dict__[shared_key] = \ + _to_array_with_correct_type(np.stack(value, axis)) except ValueError: warnings.warn( "You are using tensors with different shape," " fallback to dtype=object by default." 
) - self.__dict__[k] = np.array(v, dtype=object) + self.__dict__[shared_key] = np.array(value, dtype=object) # all the keys - keys_total = set.union(*[set(b.keys()) for b in batches]) + keys_total = set.union(*[set(batch.keys()) for batch in batches]) # keys that are reserved in all batches keys_reserve = set.difference(keys_total, set.union(*keys_map)) # keys that are either partial or reserved @@ -552,21 +561,21 @@ class Batch: f"Stack of Batch with non-shared keys {keys_partial} is only " f"supported with axis=0, but got axis={axis}!" ) - for k in keys_reserve: + for key in keys_reserve: # reserved keys - self.__dict__[k] = Batch() - for k in keys_partial: - for i, e in enumerate(batches): - if k not in e.__dict__: - continue - val = e.get(k) - if isinstance(val, Batch) and val.is_empty(): + self.__dict__[key] = Batch() + for key in keys_partial: + for i, batch in enumerate(batches): + if key not in batch.__dict__: continue + value = batch.get(key) + if isinstance(value, Batch) and value.is_empty(): # type: ignore + continue # type: ignore try: - self.__dict__[k][i] = val + self.__dict__[key][i] = value except KeyError: - self.__dict__[k] = _create_value(val, len(batches)) - self.__dict__[k][i] = val + self.__dict__[key] = _create_value(value, len(batches)) + self.__dict__[key][i] = value @staticmethod def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch": @@ -620,27 +629,27 @@ class Batch: ), ) """ - for k, v in self.items(): - if isinstance(v, torch.Tensor): # most often case - self.__dict__[k][index] = 0 - elif v is None: + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): # most often case + self.__dict__[batch_key][index] = 0 + elif obj is None: continue - elif isinstance(v, np.ndarray): - if v.dtype == object: - self.__dict__[k][index] = None + elif isinstance(obj, np.ndarray): + if obj.dtype == object: + self.__dict__[batch_key][index] = None else: - self.__dict__[k][index] = 0 - elif isinstance(v, Batch): - self.__dict__[k].empty_(index=index) + self.__dict__[batch_key][index] = 0 + elif isinstance(obj, Batch): + self.__dict__[batch_key].empty_(index=index) else: # scalar value warnings.warn( "You are calling Batch.empty on a NumPy scalar, " "which may cause undefined behaviors." ) - if _is_number(v): - self.__dict__[k] = v.__class__(0) + if _is_number(obj): + self.__dict__[batch_key] = obj.__class__(0) else: - self.__dict__[k] = None + self.__dict__[batch_key] = None return self @staticmethod @@ -658,26 +667,26 @@ class Batch: if batch is None: self.update(kwargs) return - for k, v in batch.items(): - self.__dict__[k] = _parse_value(v) + for batch_key, obj in batch.items(): + self.__dict__[batch_key] = _parse_value(obj) if kwargs: self.update(kwargs) def __len__(self) -> int: """Return len(self).""" - r = [] - for v in self.__dict__.values(): - if isinstance(v, Batch) and v.is_empty(recurse=True): + lens = [] + for obj in self.__dict__.values(): + if isinstance(obj, Batch) and obj.is_empty(recurse=True): continue - elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0): - r.append(len(v)) + elif hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0): + lens.append(len(obj)) else: - raise TypeError(f"Object {v} in {self} has no len()") - if len(r) == 0: + raise TypeError(f"Object {obj} in {self} has no len()") + if len(lens) == 0: # empty batch has the shape of any, like the tensorflow '?' shape. # So it has no length. 
raise TypeError(f"Object {self} has no len()") - return min(r) + return min(lens) def is_empty(self, recurse: bool = False) -> bool: """Test if a Batch is empty. @@ -710,8 +719,8 @@ class Batch: if not recurse: return False return all( - False if not isinstance(x, Batch) else x.is_empty(recurse=True) - for x in self.values() + False if not isinstance(obj, Batch) else obj.is_empty(recurse=True) + for obj in self.values() ) @property @@ -721,9 +730,9 @@ class Batch: return [] else: data_shape = [] - for v in self.__dict__.values(): + for obj in self.__dict__.values(): try: - data_shape.append(list(v.shape)) + data_shape.append(list(obj.shape)) except AttributeError: data_shape.append([]) return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 05595af..b384e1f 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -69,8 +69,8 @@ class ReplayBuffer: """Return self.key.""" try: return self._meta[key] - except KeyError as e: - raise AttributeError from e + except KeyError as exception: + raise AttributeError from exception def __setstate__(self, state: Dict[str, Any]) -> None: """Unpickling interface. @@ -198,10 +198,10 @@ class ReplayBuffer: episode_reward is 0. """ # preprocess batch - b = Batch() + new_batch = Batch() for key in set(self._reserved_keys).intersection(batch.keys()): - b.__dict__[key] = batch[key] - batch = b + new_batch.__dict__[key] = batch[key] + batch = new_batch assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) stacked_batch = buffer_ids is not None if stacked_batch: @@ -315,9 +315,9 @@ class ReplayBuffer: return Batch.stack(stack, axis=indices.ndim) else: return np.stack(stack, axis=indices.ndim) - except IndexError as e: + except IndexError as exception: if not (isinstance(val, Batch) and val.is_empty()): - raise e # val != Batch() + raise exception # val != Batch() return Batch() def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 70ebcab..2b32911 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -114,10 +114,10 @@ class ReplayBufferManager(ReplayBuffer): episode_reward is 0. """ # preprocess batch - b = Batch() + new_batch = Batch() for key in set(self._reserved_keys).intersection(batch.keys()): - b.__dict__[key] = batch[key] - batch = b + new_batch.__dict__[key] = batch[key] + batch = new_batch assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index bd1dd53..12fa724 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -111,11 +111,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None: # and possibly in other cases like structured arrays. try: to_hdf5_via_pickle(v, y, k) - except Exception as e: + except Exception as exception: raise RuntimeError( f"Attempted to pickle {v.__class__.__name__} due to " "data type not supported by HDF5 and failed." 
- ) from e + ) from exception y[k].attrs["__data_type__"] = "pickled_ndarray" elif isinstance(v, (int, float)): # ints and floats are stored as attributes of groups @@ -123,11 +123,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None: else: # resort to pickle for any other type of object try: to_hdf5_via_pickle(v, y, k) - except Exception as e: + except Exception as exception: raise NotImplementedError( f"No conversion to HDF5 for object of type '{type(v)}' " "implemented and fallback to pickle failed." - ) from e + ) from exception y[k].attrs["__data_type__"] = v.__class__.__name__ diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index 456bbca..6fa1ad2 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -15,8 +15,8 @@ class MultiAgentEnv(ABC, gym.Env): env = MultiAgentEnv(...) # obs is a dict containing obs, agent_id, and mask obs = env.reset() - action = policy(obs) - obs, rew, done, info = env.step(action) + act = policy(obs) + obs, rew, done, info = env.step(act) env.close() The available action's mask is set to 1, otherwise it is set to 0. Further diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index b46593c..3942146 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -412,10 +412,10 @@ class RayVectorEnv(BaseVectorEnv): def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: try: import ray - except ImportError as e: + except ImportError as exception: raise ImportError( "Please install ray to support RayVectorEnv: pip install ray" - ) from e + ) from exception if not ray.is_initialized(): ray.init() super().__init__(env_fns, RayEnvWorker, **kwargs) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index feb6479..3572e83 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -98,6 +98,12 @@ class BasePolicy(ABC, nn.Module): """ return act + def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: + """Softly update the parameters of target module towards the parameters \ + of source module.""" + for tgt_param, src_param in zip(tgt.parameters(), src.parameters()): + tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + @abstractmethod def forward( self, @@ -387,10 +393,10 @@ def _gae_return( ) -> np.ndarray: returns = np.zeros(rew.shape) delta = rew + v_s_ * gamma - v_s - m = (1.0 - end_flag) * (gamma * gae_lambda) + discount = (1.0 - end_flag) * (gamma * gae_lambda) gae = 0.0 for i in range(len(rew) - 1, -1, -1): - gae = delta[i] + m[i] * gae + gae = delta[i] + discount[i] * gae returns[i] = gae return returns diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index a5321ac..405b8c7 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -40,23 +40,23 @@ class ImitationPolicy(BasePolicy): state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs: Any, ) -> Batch: - logits, h = self.model(batch.obs, state=state, info=batch.info) + logits, hidden = self.model(batch.obs, state=state, info=batch.info) if self.action_type == "discrete": - a = logits.max(dim=1)[1] + act = logits.max(dim=1)[1] else: - a = logits - return Batch(logits=logits, act=a, state=h) + act = logits + return Batch(logits=logits, act=act, state=hidden) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.optim.zero_grad() if self.action_type == "continuous": # regression - a = self(batch).act - a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) - loss = F.mse_loss(a, 
a_) # type: ignore + act = self(batch).act + act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) + loss = F.mse_loss(act, act_target) # type: ignore elif self.action_type == "discrete": # classification - a = F.log_softmax(self(batch).logits, dim=-1) - a_ = to_torch(batch.act, dtype=torch.long, device=a.device) - loss = F.nll_loss(a, a_) # type: ignore + act = F.log_softmax(self(batch).logits, dim=-1) + act_target = to_torch(batch.act, dtype=torch.long, device=act.device) + loss = F.nll_loss(act, act_target) # type: ignore loss.backward() self.optim.step() return {"loss": loss.item()} diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 2aeeb32..afd9be9 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -105,32 +105,27 @@ class BCQPolicy(BasePolicy): obs_group: torch.Tensor = to_torch( # type: ignore batch.obs, device=self.device ) - act = [] + act_group = [] for obs in obs_group: # now obs is (state_dim) obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1) # now obs is (forward_sampled_times, state_dim) # decode(obs) generates action and actor perturbs it - action = self.actor(obs, self.vae.decode(obs)) + act = self.actor(obs, self.vae.decode(obs)) # now action is (forward_sampled_times, action_dim) - q1 = self.critic1(obs, action) + q1 = self.critic1(obs, act) # q1 is (forward_sampled_times, 1) - ind = q1.argmax(0) - act.append(action[ind].cpu().data.numpy().flatten()) - act = np.array(act) - return Batch(act=act) + max_indice = q1.argmax(0) + act_group.append(act[max_indice].cpu().data.numpy().flatten()) + act_group = np.array(act_group) + return Batch(act=act_group) def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - for net, net_target in [ - [self.critic1, self.critic1_target], [self.critic2, self.critic2_target], - [self.actor, self.actor_target] - ]: - for param, target_param in zip(net.parameters(), net_target.parameters()): - target_param.data.copy_( - self.tau * param.data + (1 - self.tau) * target_param.data - ) + self.soft_update(self.critic1_target, self.critic1, self.tau) + self.soft_update(self.critic2_target, self.critic2, self.tau) + self.soft_update(self.actor_target, self.actor, self.tau) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # batch: obs, act, rew, done, obs_next. 
(numpy array) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index c2df77c..102ede0 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -113,13 +113,8 @@ class CQLPolicy(SACPolicy): def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - for net, net_old in [ - [self.critic1, self.critic1_old], [self.critic2, self.critic2_old] - ]: - for param, target_param in zip(net.parameters(), net_old.parameters()): - target_param.data.copy_( - self._tau * param.data + (1 - self._tau) * target_param.data - ) + self.soft_update(self.critic1_old, self.critic1, self.tau) + self.soft_update(self.critic2_old, self.critic2, self.tau) def actor_pred(self, obs: torch.Tensor) -> \ Tuple[torch.Tensor, torch.Tensor]: diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index d9cac65..bca9b09 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -94,13 +94,10 @@ class DiscreteBCQPolicy(DQNPolicy): # mask actions for argmax ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values mask = (ratio < self._log_tau).float() - action = (q_value - np.inf * mask).argmax(dim=-1) + act = (q_value - np.inf * mask).argmax(dim=-1) return Batch( - act=action, - state=state, - q_value=q_value, - imitation_logits=imitation_logits + act=act, state=state, q_value=q_value, imitation_logits=imitation_logits ) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index ad4ed19..1adbb26 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -57,15 +57,15 @@ class DiscreteCQLPolicy(QRDQNPolicy): curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist - u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( - u * (self.tau_hat - - (target_dist - curr_dist).detach().le(0.).float()).abs() + dist_diff * + (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs() ).sum(-1).mean(1) qr_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # add CQL loss q = self.compute_q_value(all_dist, None) dataset_expec = q.gather(1, act.unsqueeze(1)).mean() diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index dd4efe7..b182ead 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -97,9 +97,9 @@ class DiscreteCRRPolicy(PGPolicy): target = rew.unsqueeze(1) + self._gamma * expected_target_q critic_loss = 0.5 * F.mse_loss(qa_t, target) # Actor loss - a_t, _ = self.actor(batch.obs) - m = Categorical(logits=a_t) - expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True) + act_target, _ = self.actor(batch.obs) + dist = Categorical(logits=act_target) + expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) advantage = qa_t - expected_policy_q if self._policy_improvement_mode == "binary": actor_loss_coef = (advantage > 0).float() @@ 
-109,7 +109,7 @@ class DiscreteCRRPolicy(PGPolicy): ) else: actor_loss_coef = 1.0 # effectively behavior cloning - actor_loss = (-m.log_prob(act) * actor_loss_coef).mean() + actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() # CQL loss/regularizer min_q_loss = (q_t.logsumexp(1) - qa_t).mean() loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index e5ea5c9..3caab0d 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -202,13 +202,13 @@ class PSRLPolicy(BasePolicy): rew_sum = np.zeros((n_s, n_a)) rew_square_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) - for b in batch.split(size=1): - obs, act, obs_next = b.obs, b.act, b.obs_next + for minibatch in batch.split(size=1): + obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next trans_count[obs, act, obs_next] += 1 - rew_sum[obs, act] += b.rew - rew_square_sum[obs, act] += b.rew**2 + rew_sum[obs, act] += minibatch.rew + rew_square_sum[obs, act] += minibatch.rew**2 rew_count[obs, act] += 1 - if self._add_done_loop and b.done: + if self._add_done_loop and minibatch.done: # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index f67d25d..2ad5bb3 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -85,9 +85,9 @@ class A2CPolicy(PGPolicy): ) -> Batch: v_s, v_s_ = [], [] with torch.no_grad(): - for b in batch.split(self._batch, shuffle=False, merge_last=True): - v_s.append(self.critic(b.obs)) - v_s_.append(self.critic(b.obs_next)) + for minibatch in batch.split(self._batch, shuffle=False, merge_last=True): + v_s.append(self.critic(minibatch.obs)) + v_s_.append(self.critic(minibatch.obs_next)) batch.v_s = torch.cat(v_s, dim=0).flatten() # old value v_s = batch.v_s.cpu().numpy() v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() @@ -122,14 +122,15 @@ class A2CPolicy(PGPolicy): ) -> Dict[str, List[float]]: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): - for b in batch.split(batch_size, merge_last=True): + for minibatch in batch.split(batch_size, merge_last=True): # calculate loss for actor - dist = self(b).dist - log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1) - actor_loss = -(log_prob * b.adv).mean() + dist = self(minibatch).dist + log_prob = dist.log_prob(minibatch.act) + log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) + actor_loss = -(log_prob * minibatch.adv).mean() # calculate loss for critic - value = self.critic(b.obs).flatten() - vf_loss = F.mse_loss(b.returns, value) + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) # calculate regularization and overall loss ent_loss = dist.entropy().mean() loss = actor_loss + self._weight_vf * vf_loss \ diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 4e79eb3..e49b2d6 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -70,13 +70,13 @@ class C51Policy(DQNPolicy): def _target_dist(self, batch: Batch) -> torch.Tensor: if self._target: - a = self(batch, input="obs_next").act + act = self(batch, input="obs_next").act next_dist = self(batch, model="model_old", input="obs_next").logits else: - next_b = self(batch, input="obs_next") - a = next_b.act - 
next_dist = next_b.logits - next_dist = next_dist[np.arange(len(a)), a, :] + next_batch = self(batch, input="obs_next") + act = next_batch.act + next_dist = next_batch.logits + next_dist = next_dist[np.arange(len(act)), act, :] target_support = batch.returns.clamp(self._v_min, self._v_max) # An amazing trick for calculating the projection gracefully. # ref: https://github.com/ShangtongZhang/DeepRL diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 18bb81b..0a779c6 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -73,7 +73,7 @@ class DDPGPolicy(BasePolicy): self.critic_old.eval() self.critic_optim: torch.optim.Optimizer = critic_optim assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]" - self._tau = tau + self.tau = tau assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]" self._gamma = gamma self._noise = exploration_noise @@ -95,10 +95,8 @@ class DDPGPolicy(BasePolicy): def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): - o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - for o, n in zip(self.critic_old.parameters(), self.critic.parameters()): - o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) + self.soft_update(self.actor_old, self.actor, self.tau) + self.soft_update(self.critic_old, self.critic, self.tau) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} @@ -139,8 +137,8 @@ class DDPGPolicy(BasePolicy): """ model = getattr(self, model) obs = batch[input] - actions, h = model(obs, state=state, info=batch.info) - return Batch(act=actions, state=h) + actions, hidden = model(obs, state=state, info=batch.info) + return Batch(act=actions, state=hidden) @staticmethod def _mse_optimizer( @@ -163,8 +161,7 @@ class DDPGPolicy(BasePolicy): td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # actor - action = self(batch).act - actor_loss = -self.critic(batch.obs, action).mean() + actor_loss = -self.critic(batch.obs, self(batch).act).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 7c580f3..33e06da 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -76,10 +76,10 @@ class DiscreteSACPolicy(SACPolicy): **kwargs: Any, ) -> Batch: obs = batch[input] - logits, h = self.actor(obs, state=state, info=batch.info) + logits, hidden = self.actor(obs, state=state, info=batch.info) dist = Categorical(logits=logits) act = dist.sample() - return Batch(logits=logits, act=act, state=h, dist=dist) + return Batch(logits=logits, act=act, state=hidden, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 8504909..d03c3e3 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -151,13 +151,13 @@ class DQNPolicy(BasePolicy): """ model = getattr(self, model) obs = batch[input] - obs_ = obs.obs if hasattr(obs, "obs") else obs - logits, h = model(obs_, state=state, info=batch.info) + obs_next = obs.obs if hasattr(obs, "obs") else obs + logits, hidden = model(obs_next, state=state, 
info=batch.info) q = self.compute_q_value(logits, getattr(obs, "mask", None)) if not hasattr(self, "max_action_num"): self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) - return Batch(logits=logits, act=act, state=h) + return Batch(logits=logits, act=act, state=hidden) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: @@ -166,10 +166,10 @@ class DQNPolicy(BasePolicy): weight = batch.pop("weight", 1.0) q = self(batch).logits q = q[np.arange(len(q)), batch.act] - r = to_torch_as(batch.returns.flatten(), q) - td = r - q - loss = (td.pow(2) * weight).mean() - batch.weight = td # prio-buffer + returns = to_torch_as(batch.returns.flatten(), q) + td_error = returns - q + loss = (td_error.pow(2) * weight).mean() + batch.weight = td_error # prio-buffer loss.backward() self.optim.step() self._iter += 1 diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 3c015b3..054781a 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -60,15 +60,15 @@ class FQFPolicy(QRDQNPolicy): batch = buffer[indices] # batch.obs_next: s_{t+n} if self._target: result = self(batch, input="obs_next") - a, fractions = result.act, result.fractions + act, fractions = result.act, result.fractions next_dist = self( batch, model="model_old", input="obs_next", fractions=fractions ).logits else: - next_b = self(batch, input="obs_next") - a = next_b.act - next_dist = next_b.logits - next_dist = next_dist[np.arange(len(a)), a, :] + next_batch = self(batch, input="obs_next") + act = next_batch.act + next_dist = next_batch.logits + next_dist = next_dist[np.arange(len(act)), act, :] return next_dist # shape: [bsz, num_quantiles] def forward( @@ -82,14 +82,17 @@ class FQFPolicy(QRDQNPolicy): ) -> Batch: model = getattr(self, model) obs = batch[input] - obs_ = obs.obs if hasattr(obs, "obs") else obs + obs_next = obs.obs if hasattr(obs, "obs") else obs if fractions is None: - (logits, fractions, quantiles_tau), h = model( - obs_, propose_model=self.propose_model, state=state, info=batch.info + (logits, fractions, quantiles_tau), hidden = model( + obs_next, + propose_model=self.propose_model, + state=state, + info=batch.info ) else: - (logits, _, quantiles_tau), h = model( - obs_, + (logits, _, quantiles_tau), hidden = model( + obs_next, propose_model=self.propose_model, fractions=fractions, state=state, @@ -106,7 +109,7 @@ class FQFPolicy(QRDQNPolicy): return Batch( logits=logits, act=act, - state=h, + state=hidden, fractions=fractions, quantiles_tau=quantiles_tau ) @@ -122,9 +125,9 @@ class FQFPolicy(QRDQNPolicy): curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist - u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( - u * ( + dist_diff * ( tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() ).abs() @@ -132,7 +135,7 @@ class FQFPolicy(QRDQNPolicy): quantile_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # calculate fraction loss with torch.no_grad(): sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] diff 
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
index 9d9777b..502dd69 100644
--- a/tianshou/policy/modelfree/iqn.py
+++ b/tianshou/policy/modelfree/iqn.py
@@ -73,36 +73,37 @@ class IQNPolicy(QRDQNPolicy):
             sample_size = self._sample_size
         model = getattr(self, model)
         obs = batch[input]
-        obs_ = obs.obs if hasattr(obs, "obs") else obs
-        (logits,
-         taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info)
+        obs_next = obs.obs if hasattr(obs, "obs") else obs
+        (logits, taus), hidden = model(
+            obs_next, sample_size=sample_size, state=state, info=batch.info
+        )
         q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
         act = to_numpy(q.max(dim=1)[1])
-        return Batch(logits=logits, act=act, state=h, taus=taus)
+        return Batch(logits=logits, act=act, state=hidden, taus=taus)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
         weight = batch.pop("weight", 1.0)
-        out = self(batch)
-        curr_dist, taus = out.logits, out.taus
+        action_batch = self(batch)
+        curr_dist, taus = action_batch.logits, action_batch.taus
         act = batch.act
         curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u *
+            dist_diff *
             (taus.unsqueeze(2) -
              (target_dist - curr_dist).detach().le(0.).float()).abs()
         ).sum(-1).mean(1)
         loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         loss.backward()
         self.optim.step()
         self._iter += 1
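
The ``next_dist[np.arange(len(act)), act, :]`` pattern that ``_target_q`` uses above selects, for every sample, the quantile vector of its greedy action. A toy illustration of that indexing (shapes and values are arbitrary)::

    import numpy as np
    import torch

    bsz, num_actions, num_quantiles = 5, 3, 8
    next_dist = torch.randn(bsz, num_actions, num_quantiles)
    # greedy action per sample, chosen from the expected Q-value of each action
    act = next_dist.mean(-1).argmax(-1).numpy()
    # row i keeps only the quantiles of action act[i]
    chosen = next_dist[np.arange(len(act)), act, :]
    assert chosen.shape == (bsz, num_quantiles)
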
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
index 758093d..ce91fdb 100644
--- a/tianshou/policy/modelfree/npg.py
+++ b/tianshou/policy/modelfree/npg.py
@@ -71,8 +71,8 @@ class NPGPolicy(A2CPolicy):
         batch = super().process_fn(batch, buffer, indices)
         old_log_prob = []
         with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
+            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
         if self._norm_adv:
             batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
@@ -83,20 +83,20 @@ class NPGPolicy(A2CPolicy):
     ) -> Dict[str, List[float]]:
         actor_losses, vf_losses, kls = [], [], []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
-                dist = self(b).dist
-                log_prob = dist.log_prob(b.act)
+                dist = self(minibatch).dist
+                log_prob = dist.log_prob(minibatch.act)
                 log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
-                actor_loss = -(log_prob * b.adv).mean()
+                actor_loss = -(log_prob * minibatch.adv).mean()
                 flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True
                 ).detach()
 
                 # direction: calculate natural gradient
                 with torch.no_grad():
-                    old_dist = self(b).dist
+                    old_dist = self(minibatch).dist
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
@@ -112,13 +112,13 @@ class NPGPolicy(A2CPolicy):
                 )
                 new_flat_params = flat_params + self._step_size * search_direction
                 self._set_from_flat_params(self.actor, new_flat_params)
-                new_dist = self(b).dist
+                new_dist = self(minibatch).dist
                 kl = kl_divergence(old_dist, new_dist).mean()
 
                 # optimize citirc
                 for _ in range(self._optim_critic_iters):
-                    value = self.critic(b.obs).flatten()
-                    vf_loss = F.mse_loss(b.returns, value)
+                    value = self.critic(minibatch.obs).flatten()
+                    vf_loss = F.mse_loss(minibatch.returns, value)
                     self.optim.zero_grad()
                     vf_loss.backward()
                     self.optim.step()
@@ -147,14 +147,14 @@ class NPGPolicy(A2CPolicy):
 
     def _conjugate_gradients(
         self,
-        b: torch.Tensor,
+        minibatch: torch.Tensor,
         flat_kl_grad: torch.Tensor,
         nsteps: int = 10,
        residual_tol: float = 1e-10
    ) -> torch.Tensor:
-        x = torch.zeros_like(b)
-        r, p = b.clone(), b.clone()
-        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
+        x = torch.zeros_like(minibatch)
+        r, p = minibatch.clone(), minibatch.clone()
+        # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0.
         # Change if doing warm start.
         rdotr = r.dot(r)
         for _ in range(nsteps):
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index a648288..a525746 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -107,7 +107,7 @@ class PGPolicy(BasePolicy):
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
         more detailed explanation.
         """
-        logits, h = self.actor(batch.obs, state=state)
+        logits, hidden = self.actor(batch.obs, state=state)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@@ -119,20 +119,20 @@ class PGPolicy(BasePolicy):
             act = logits[0]
         else:
             act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+        return Batch(logits=logits, act=act, state=hidden, dist=dist)
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses = []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 self.optim.zero_grad()
-                result = self(b)
+                result = self(minibatch)
                 dist = result.dist
-                a = to_torch_as(b.act, result.act)
-                ret = to_torch_as(b.returns, result.act)
-                log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1)
+                act = to_torch_as(minibatch.act, result.act)
+                ret = to_torch_as(minibatch.returns, result.act)
+                log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                 loss = -(log_prob * ret).mean()
                 loss.backward()
                 self.optim.step()
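
``PGPolicy.learn`` above is plain REINFORCE: minimize the negative of ``log_prob * return`` over each minibatch. A self-contained toy version of that objective; the distribution and returns here are stand-ins, not the policy's real outputs::

    import torch
    from torch.distributions import Categorical

    logits = torch.randn(4, 3, requires_grad=True)   # 4 states, 3 discrete actions
    dist = Categorical(logits=logits)
    act = dist.sample()                              # actions taken in the batch
    ret = torch.tensor([1.0, 0.5, -0.2, 2.0])        # discounted returns
    loss = -(dist.log_prob(act) * ret).mean()        # score-function estimator
    loss.backward()
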
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 45439ed..fe5aa2f 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -96,8 +96,8 @@ class PPOPolicy(A2CPolicy):
         batch.act = to_torch_as(batch.act, batch.v_s)
         old_log_prob = []
         with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
+            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
         return batch
 
@@ -108,32 +108,35 @@ class PPOPolicy(A2CPolicy):
         for step in range(repeat):
             if self._recompute_adv and step > 0:
                 batch = self._compute_returns(batch, self._buffer, self._indices)
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
-                dist = self(b).dist
+                dist = self(minibatch).dist
                 if self._norm_adv:
-                    mean, std = b.adv.mean(), b.adv.std()
-                    b.adv = (b.adv - mean) / std  # per-batch norm
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
+                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                ratio = (dist.log_prob(minibatch.act) -
+                         minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
-                surr1 = ratio * b.adv
-                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
+                surr1 = ratio * minibatch.adv
+                surr2 = ratio.clamp(
+                    1.0 - self._eps_clip, 1.0 + self._eps_clip
+                ) * minibatch.adv
                 if self._dual_clip:
                     clip1 = torch.min(surr1, surr2)
-                    clip2 = torch.max(clip1, self._dual_clip * b.adv)
-                    clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
+                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                 else:
                     clip_loss = -torch.min(surr1, surr2).mean()
                 # calculate loss for critic
-                value = self.critic(b.obs).flatten()
+                value = self.critic(minibatch.obs).flatten()
                 if self._value_clip:
-                    v_clip = b.v_s + (value -
-                                      b.v_s).clamp(-self._eps_clip, self._eps_clip)
-                    vf1 = (b.returns - value).pow(2)
-                    vf2 = (b.returns - v_clip).pow(2)
+                    v_clip = minibatch.v_s + \
+                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+                    vf1 = (minibatch.returns - value).pow(2)
+                    vf2 = (minibatch.returns - v_clip).pow(2)
                     vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = (b.returns - value).pow(2).mean()
+                    vf_loss = (minibatch.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
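
The reflowed block above is the standard PPO clipped surrogate for the actor. Isolated from the policy class, with toy tensors in place of a real minibatch, that part of the loss reads::

    import torch

    def ppo_clip_loss(log_prob, logp_old, adv, eps_clip=0.2):
        ratio = (log_prob - logp_old).exp()
        surr1 = ratio * adv
        surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
        return -torch.min(surr1, surr2).mean()  # minimized by the optimizer

    log_prob = torch.randn(8, requires_grad=True)
    logp_old = (log_prob + 0.1 * torch.randn(8)).detach()
    adv = torch.randn(8)
    ppo_clip_loss(log_prob, logp_old, adv).backward()
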
diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py
index fe3e101..ea4913f 100644
--- a/tianshou/policy/modelfree/qrdqn.py
+++ b/tianshou/policy/modelfree/qrdqn.py
@@ -56,13 +56,13 @@ class QRDQNPolicy(DQNPolicy):
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs_next: s_{t+n}
         if self._target:
-            a = self(batch, input="obs_next").act
+            act = self(batch, input="obs_next").act
             next_dist = self(batch, model="model_old", input="obs_next").logits
         else:
-            next_b = self(batch, input="obs_next")
-            a = next_b.act
-            next_dist = next_b.logits
-        next_dist = next_dist[np.arange(len(a)), a, :]
+            next_batch = self(batch, input="obs_next")
+            act = next_batch.act
+            next_dist = next_batch.logits
+        next_dist = next_dist[np.arange(len(act)), act, :]
         return next_dist  # shape: [bsz, num_quantiles]
 
     def compute_q_value(
@@ -80,15 +80,15 @@ class QRDQNPolicy(DQNPolicy):
         curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u * (self.tau_hat -
-                 (target_dist - curr_dist).detach().le(0.).float()).abs()
+            dist_diff *
+            (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
         ).sum(-1).mean(1)
         loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         loss.backward()
         self.optim.step()
         self._iter += 1
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 2657a1e..fc89cf3 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -99,10 +99,8 @@ class SACPolicy(DDPGPolicy):
         return self
 
     def sync_weight(self) -> None:
-        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
+        self.soft_update(self.critic1_old, self.critic1, self.tau)
+        self.soft_update(self.critic2_old, self.critic2, self.tau)
 
     def forward(  # type: ignore
         self,
@@ -112,7 +110,7 @@ class SACPolicy(DDPGPolicy):
         **kwargs: Any,
     ) -> Batch:
         obs = batch[input]
-        logits, h = self.actor(obs, state=state, info=batch.info)
+        logits, hidden = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self._deterministic_eval and not self.training:
@@ -134,16 +132,20 @@ class SACPolicy(DDPGPolicy):
             action_scale * (1 - squashed_action.pow(2)) + self.__eps
         ).sum(-1, keepdim=True)
         return Batch(
-            logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
+            logits=logits,
+            act=squashed_action,
+            state=hidden,
+            dist=dist,
+            log_prob=log_prob
         )
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs: s_{t+n}
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
+        obs_next_result = self(batch, input="obs_next")
+        act_ = obs_next_result.act
         target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
+            self.critic1_old(batch.obs_next, act_),
+            self.critic2_old(batch.obs_next, act_),
         ) - self._alpha * obs_next_result.log_prob
         return target_q
 
@@ -159,9 +161,9 @@ class SACPolicy(DDPGPolicy):
 
         # actor
         obs_result = self(batch)
-        a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).flatten()
-        current_q2a = self.critic2(batch.obs, a).flatten()
+        act = obs_result.act
+        current_q1a = self.critic1(batch.obs, act).flatten()
+        current_q2a = self.critic2(batch.obs, act).flatten()
         actor_loss = (
             self._alpha * obs_result.log_prob.flatten() -
             torch.min(current_q1a, current_q2a)
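
``sync_weight`` now delegates to a ``soft_update`` helper on the base policy. A minimal standalone Polyak update in the same spirit; the function name and module sizes here are illustrative, not the library API::

    import torch
    import torch.nn as nn

    def polyak_update(target: nn.Module, source: nn.Module, tau: float) -> None:
        # theta_target <- (1 - tau) * theta_target + tau * theta_source
        with torch.no_grad():
            for tgt, src in zip(target.parameters(), source.parameters()):
                tgt.data.mul_(1.0 - tau).add_(tau * src.data)

    critic, critic_old = nn.Linear(4, 1), nn.Linear(4, 1)
    critic_old.load_state_dict(critic.state_dict())
    polyak_update(critic_old, critic, tau=0.005)
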
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index a033237..8ad31db 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -89,23 +89,20 @@ class TD3Policy(DDPGPolicy):
         return self
 
     def sync_weight(self) -> None:
-        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
+        self.soft_update(self.critic1_old, self.critic1, self.tau)
+        self.soft_update(self.critic2_old, self.critic2, self.tau)
+        self.soft_update(self.actor_old, self.actor, self.tau)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs: s_{t+n}
-        a_ = self(batch, model="actor_old", input="obs_next").act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+        act_ = self(batch, model="actor_old", input="obs_next").act
+        noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
         if self._noise_clip > 0.0:
             noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
+        act_ += noise
         target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)
+            self.critic1_old(batch.obs_next, act_),
+            self.critic2_old(batch.obs_next, act_),
         )
         return target_q
 
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
index 75956d9..2803a2d 100644
--- a/tianshou/policy/modelfree/trpo.py
+++ b/tianshou/policy/modelfree/trpo.py
@@ -71,20 +71,21 @@ class TRPOPolicy(NPGPolicy):
     ) -> Dict[str, List[float]]:
         actor_losses, vf_losses, step_sizes, kls = [], [], [], []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
-                dist = self(b).dist  # TODO could come from batch
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                dist = self(minibatch).dist  # TODO could come from batch
+                ratio = (dist.log_prob(minibatch.act) -
+                         minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
-                actor_loss = -(ratio * b.adv).mean()
+                actor_loss = -(ratio * minibatch.adv).mean()
                 flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True
                 ).detach()
 
                 # direction: calculate natural gradient
                 with torch.no_grad():
-                    old_dist = self(b).dist
+                    old_dist = self(minibatch).dist
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
@@ -109,12 +110,13 @@ class TRPOPolicy(NPGPolicy):
                     new_flat_params = flat_params + step_size * search_direction
                     self._set_from_flat_params(self.actor, new_flat_params)
                     # calculate kl and if in bound, loss actually down
-                    new_dist = self(b).dist
-                    new_dratio = (new_dist.log_prob(b.act) -
-                                  b.logp_old).exp().float()
+                    new_dist = self(minibatch).dist
+                    new_dratio = (
+                        new_dist.log_prob(minibatch.act) - minibatch.logp_old
+                    ).exp().float()
                     new_dratio = new_dratio.reshape(new_dratio.size(0),
                                                     -1).transpose(0, 1)
-                    new_actor_loss = -(new_dratio * b.adv).mean()
+                    new_actor_loss = -(new_dratio * minibatch.adv).mean()
                     kl = kl_divergence(old_dist, new_dist).mean()
 
                     if kl < self._delta and new_actor_loss < actor_loss:
@@ -133,8 +135,8 @@ class TRPOPolicy(NPGPolicy):
 
                 # optimize citirc
                 for _ in range(self._optim_critic_iters):
-                    value = self.critic(b.obs).flatten()
-                    vf_loss = F.mse_loss(b.returns, value)
+                    value = self.critic(minibatch.obs).flatten()
+                    vf_loss = F.mse_loss(minibatch.returns, value)
                     self.optim.zero_grad()
                     vf_loss.backward()
                     self.optim.step()
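
Both NPG and TRPO above solve a linear system with the ``_conjugate_gradients`` routine whose argument was renamed in this patch. A generic sketch of that solver, with the Fisher-vector product replaced by an explicit SPD matrix and toy data standing in for the policy gradient::

    import torch

    def conjugate_gradients(mvp, b, nsteps=10, residual_tol=1e-10):
        # solve A x = b given only the matrix-vector product mvp(v) = A v
        x = torch.zeros_like(b)
        r, p = b.clone(), b.clone()
        rdotr = r.dot(r)
        for _ in range(nsteps):
            z = mvp(p)
            alpha = rdotr / p.dot(z)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr
        return x

    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # stand-in for the Fisher matrix
    g = torch.tensor([1.0, 5.0])                 # stand-in for the policy gradient
    x = conjugate_gradients(lambda v: A @ v, g)
    assert torch.allclose(A @ x, g, atol=1e-4)
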
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 5052c59..9d03b61 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -87,14 +87,14 @@ class MLP(nn.Module):
         self.output_dim = output_dim or hidden_sizes[-1]
         self.model = nn.Sequential(*model)
 
-    def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+    def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
         if self.device is not None:
-            s = torch.as_tensor(
-                s,
+            obs = torch.as_tensor(
+                obs,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             )
-        return self.model(s.flatten(1))  # type: ignore
+        return self.model(obs.flatten(1))  # type: ignore
 
 
 class Net(nn.Module):
@@ -187,12 +187,12 @@ class Net(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
-        """Mapping: s -> flatten (inside MLP)-> logits."""
-        logits = self.model(s)
+        """Mapping: obs -> flatten (inside MLP)-> logits."""
+        logits = self.model(obs)
         bsz = logits.shape[0]
         if self.use_dueling:  # Dueling DQN
             q, v = self.Q(logits), self.V(logits)
@@ -235,38 +235,45 @@ class Recurrent(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Dict[str, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
-        """Mapping: s -> flatten -> logits.
+        """Mapping: obs -> flatten -> logits.
 
-        In the evaluation mode, s should be with shape ``[bsz, dim]``; in the
-        training mode, s should be with shape ``[bsz, len, dim]``. See the code
+        In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the
+        training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code
         and comment for more detail.
         """
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
-        if len(s.shape) == 2:
-            s = s.unsqueeze(-2)
-        s = self.fc1(s)
+        if len(obs.shape) == 2:
+            obs = obs.unsqueeze(-2)
+        obs = self.fc1(obs)
         self.nn.flatten_parameters()
         if state is None:
-            s, (h, c) = self.nn(s)
+            obs, (hidden, cell) = self.nn(obs)
         else:
             # we store the stack data in [bsz, len, ...] format
             # but pytorch rnn needs [len, bsz, ...]
-            s, (h, c) = self.nn(
-                s, (
-                    state["h"].transpose(0, 1).contiguous(),
-                    state["c"].transpose(0, 1).contiguous()
+            obs, (hidden, cell) = self.nn(
+                obs, (
+                    state["hidden"].transpose(0, 1).contiguous(),
+                    state["cell"].transpose(0, 1).contiguous()
                 )
             )
-        s = self.fc2(s[:, -1])
+        obs = self.fc2(obs[:, -1])
         # please ensure the first dim is batch size: [bsz, len, ...]
-        return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
+        return obs, {
+            "hidden": hidden.transpose(0, 1).detach(),
+            "cell": cell.transpose(0, 1).detach()
+        }
 
 
 class ActorCritic(nn.Module):
@@ -299,8 +306,8 @@ class DataParallelNet(nn.Module):
         super().__init__()
         self.net = nn.DataParallel(net)
 
-    def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any,
+    def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any,
                 **kwargs: Any) -> Tuple[Any, Any]:
-        if not isinstance(s, torch.Tensor):
-            s = torch.as_tensor(s, dtype=torch.float32)
-        return self.net(s=s.cuda(), *args, **kwargs)
+        if not isinstance(obs, torch.Tensor):
+            obs = torch.as_tensor(obs, dtype=torch.float32)
+        return self.net(obs=obs.cuda(), *args, **kwargs)
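
``Recurrent.forward`` above stores the LSTM state batch-first under the new ``"hidden"``/``"cell"`` keys and transposes to PyTorch's ``[num_layers, bsz, dim]`` layout around the call. The convention in isolation; the layer sizes below are arbitrary::

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
    obs = torch.randn(4, 5, 8)                      # [bsz, len, dim]
    state = {
        "hidden": torch.zeros(4, 2, 16),            # stored batch-first in the buffer
        "cell": torch.zeros(4, 2, 16),
    }
    out, (hidden, cell) = lstm(
        obs,
        (
            state["hidden"].transpose(0, 1).contiguous(),  # -> [num_layers, bsz, dim]
            state["cell"].transpose(0, 1).contiguous(),
        ),
    )
    next_state = {
        "hidden": hidden.transpose(0, 1).detach(),  # back to batch-first for storage
        "cell": cell.transpose(0, 1).detach(),
    }
    assert out.shape == (4, 5, 16) and next_state["hidden"].shape == (4, 2, 16)
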
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 38effe3..d68f385 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -58,14 +58,14 @@ class Actor(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
-        """Mapping: s -> logits -> action."""
-        logits, h = self.preprocess(s, state)
+        """Mapping: obs -> logits -> action."""
+        logits, hidden = self.preprocess(obs, state)
         logits = self._max * torch.tanh(self.last(logits))
-        return logits, h
+        return logits, hidden
 
 
 class Critic(nn.Module):
@@ -110,24 +110,24 @@ class Critic(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
-        a: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        obs: Union[np.ndarray, torch.Tensor],
+        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> torch.Tensor:
         """Mapping: (s, a) -> logits -> Q(s, a)."""
-        s = torch.as_tensor(
-            s,
+        obs = torch.as_tensor(
+            obs,
             device=self.device,  # type: ignore
             dtype=torch.float32,
         ).flatten(1)
-        if a is not None:
-            a = torch.as_tensor(
-                a,
+        if act is not None:
+            act = torch.as_tensor(
+                act,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             ).flatten(1)
-            s = torch.cat([s, a], dim=1)
-        logits, h = self.preprocess(s)
+            obs = torch.cat([obs, act], dim=1)
+        logits, hidden = self.preprocess(obs)
         logits = self.last(logits)
         return logits
 
@@ -196,12 +196,12 @@ class ActorProb(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
-        """Mapping: s -> logits -> (mu, sigma)."""
-        logits, h = self.preprocess(s, state)
+        """Mapping: obs -> logits -> (mu, sigma)."""
+        logits, hidden = self.preprocess(obs, state)
         mu = self.mu(logits)
         if not self._unbounded:
             mu = self._max * torch.tanh(mu)
@@ -252,30 +252,34 @@ class RecurrentActorProb(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Dict[str, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
         """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
-        if len(s.shape) == 2:
-            s = s.unsqueeze(-2)
+        if len(obs.shape) == 2:
+            obs = obs.unsqueeze(-2)
         self.nn.flatten_parameters()
         if state is None:
-            s, (h, c) = self.nn(s)
+            obs, (hidden, cell) = self.nn(obs)
         else:
             # we store the stack data in [bsz, len, ...] format
             # but pytorch rnn needs [len, bsz, ...]
-            s, (h, c) = self.nn(
-                s, (
-                    state["h"].transpose(0, 1).contiguous(),
-                    state["c"].transpose(0, 1).contiguous()
+            obs, (hidden, cell) = self.nn(
+                obs, (
+                    state["hidden"].transpose(0, 1).contiguous(),
+                    state["cell"].transpose(0, 1).contiguous()
                 )
             )
-        logits = s[:, -1]
+        logits = obs[:, -1]
         mu = self.mu(logits)
         if not self._unbounded:
             mu = self._max * torch.tanh(mu)
@@ -287,8 +291,8 @@ class RecurrentActorProb(nn.Module):
         sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
         # please ensure the first dim is batch size: [bsz, len, ...]
         return (mu, sigma), {
-            "h": h.transpose(0, 1).detach(),
-            "c": c.transpose(0, 1).detach()
+            "hidden": hidden.transpose(0, 1).detach(),
+            "cell": cell.transpose(0, 1).detach()
         }
 
 
@@ -321,28 +325,32 @@ class RecurrentCritic(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
-        a: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        obs: Union[np.ndarray, torch.Tensor],
+        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> torch.Tensor:
         """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
-        assert len(s.shape) == 3
+        assert len(obs.shape) == 3
         self.nn.flatten_parameters()
-        s, (h, c) = self.nn(s)
-        s = s[:, -1]
-        if a is not None:
-            a = torch.as_tensor(
-                a,
+        obs, (hidden, cell) = self.nn(obs)
+        obs = obs[:, -1]
+        if act is not None:
+            act = torch.as_tensor(
+                act,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             )
-            s = torch.cat([s, a], dim=1)
-        s = self.fc2(s)
-        return s
+            obs = torch.cat([obs, act], dim=1)
+        obs = self.fc2(obs)
+        return obs
 
 
 class Perturbation(nn.Module):
@@ -381,9 +389,9 @@ class Perturbation(nn.Module):
     def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         # preprocess_net
         logits = self.preprocess_net(torch.cat([state, action], -1))[0]
-        a = self.phi * self.max_action * torch.tanh(logits)
+        noise = self.phi * self.max_action * torch.tanh(logits)
         # clip to [-max_action, max_action]
-        return (a + action).clamp(-self.max_action, self.max_action)
+        return (noise + action).clamp(-self.max_action, self.max_action)
 
 
 class VAE(nn.Module):
@@ -434,31 +442,32 @@ class VAE(nn.Module):
         self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # [state, action] -> z , [state, z] -> action
-        z = self.encoder(torch.cat([state, action], -1))
+        latent_z = self.encoder(torch.cat([state, action], -1))
         # shape of z: (state.shape[:-1], hidden_dim)
-        mean = self.mean(z)
+        mean = self.mean(latent_z)
         # Clamped for numerical stability
-        log_std = self.log_std(z).clamp(-4, 15)
+        log_std = self.log_std(latent_z).clamp(-4, 15)
         std = torch.exp(log_std)
         # shape of mean, std: (state.shape[:-1], latent_dim)
-        z = mean + std * torch.randn_like(std)  # (state.shape[:-1], latent_dim)
+        latent_z = mean + std * torch.randn_like(std)  # (state.shape[:-1], latent_dim)
 
-        u = self.decode(state, z)  # (state.shape[:-1], action_dim)
-        return u, mean, std
+        reconstruction = self.decode(state, latent_z)  # (state.shape[:-1], action_dim)
+        return reconstruction, mean, std
 
     def decode(
         self,
         state: torch.Tensor,
-        z: Union[torch.Tensor, None] = None
+        latent_z: Union[torch.Tensor, None] = None
    ) -> torch.Tensor:
         # decode(state) -> action
-        if z is None:
+        if latent_z is None:
             # state.shape[0] may be batch_size
             # latent vector clipped to [-0.5, 0.5]
-            z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
+            latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
                 .to(self.device).clamp(-0.5, 0.5)
 
         # decode z with state!
-        return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))
+        return self.max_action * \
+            torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
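
The ``VAE`` used by BCQ reparameterizes with a clamped ``log_std`` and, when decoding without an encoder pass, samples a latent clipped to ``[-0.5, 0.5]``. The two sampling steps in isolation, with toy shapes in place of real state/action batches::

    import torch

    mean = torch.zeros(4, 6)                       # (batch, latent_dim)
    log_std = torch.randn(4, 6).clamp(-4, 15)      # clamped for numerical stability
    std = log_std.exp()
    latent_z = mean + std * torch.randn_like(std)  # reparameterization trick

    # decode-time prior sample, used when no action is available to encode
    latent_prior = torch.randn(4, 6).clamp(-0.5, 0.5)
    assert latent_z.shape == latent_prior.shape == (4, 6)
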
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index f79a9d1..a81e1c6 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -59,16 +59,16 @@ class Actor(nn.Module):
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state)
+        logits, hidden = self.preprocess(obs, state)
         logits = self.last(logits)
         if self.softmax_output:
             logits = F.softmax(logits, dim=-1)
-        return logits, h
+        return logits, hidden
 
 
 class Critic(nn.Module):
@@ -114,10 +114,10 @@ class Critic(nn.Module):
         )
 
     def forward(
-        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
+        self, obs: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
         """Mapping: s -> V(s)."""
-        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
+        logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
         return self.last(logits)
 
 
@@ -199,10 +199,10 @@ class ImplicitQuantileNetwork(Critic):
         ).to(device)
 
     def forward(  # type: ignore
-        self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
+        self, obs: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
    ) -> Tuple[Any, torch.Tensor]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state=kwargs.get("state", None))
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
         # Sample fractions.
         batch_size = logits.size(0)
         taus = torch.rand(
@@ -211,7 +211,7 @@ class ImplicitQuantileNetwork(Critic):
         embedding = (logits.unsqueeze(1) *
                      self.embed_model(taus)).view(batch_size * sample_size, -1)
         out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
-        return (out, taus), h
+        return (out, taus), hidden
 
 
 class FractionProposalNetwork(nn.Module):
@@ -235,17 +235,17 @@ class FractionProposalNetwork(nn.Module):
         self.embedding_dim = embedding_dim
 
     def forward(
-        self, state_embeddings: torch.Tensor
+        self, obs_embeddings: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Calculate (log of) probabilities q_i in the paper.
-        m = torch.distributions.Categorical(logits=self.net(state_embeddings))
-        taus_1_N = torch.cumsum(m.probs, dim=1)
+        dist = torch.distributions.Categorical(logits=self.net(obs_embeddings))
+        taus_1_N = torch.cumsum(dist.probs, dim=1)
         # Calculate \tau_i (i=0,...,N).
         taus = F.pad(taus_1_N, (1, 0))
         # Calculate \hat \tau_i (i=0,...,N-1).
         tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
         # Calculate entropies of value distributions.
-        entropies = m.entropy()
+        entropies = dist.entropy()
         return taus, tau_hats, entropies
 
 
@@ -294,13 +294,13 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
         return quantiles
 
     def forward(  # type: ignore
-        self, s: Union[np.ndarray, torch.Tensor],
+        self, obs: Union[np.ndarray, torch.Tensor],
         propose_model: FractionProposalNetwork,
         fractions: Optional[Batch] = None,
         **kwargs: Any
    ) -> Tuple[Any, torch.Tensor]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state=kwargs.get("state", None))
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
         # Propose fractions
         if fractions is None:
             taus, tau_hats, entropies = propose_model(logits.detach())
@@ -313,7 +313,7 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
         if self.training:
             with torch.no_grad():
                 quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
-        return (quantiles, fractions, quantiles_tau), h
+        return (quantiles, fractions, quantiles_tau), hidden
 
 
 class NoisyLinear(nn.Module):
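
``FractionProposalNetwork.forward`` above turns the proposal logits into cumulative fractions ``taus`` and their midpoints ``tau_hats``. The same arithmetic on toy logits; the batch size and fraction count below are arbitrary::

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 8)                     # stand-in proposal-network output
    dist = torch.distributions.Categorical(logits=logits)
    taus_1_N = torch.cumsum(dist.probs, dim=1)     # tau_1 .. tau_N, last entry is 1
    taus = F.pad(taus_1_N, (1, 0))                 # prepend tau_0 = 0  -> shape [4, 9]
    tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0  # midpoints -> shape [4, 8]
    entropies = dist.entropy()                     # one entropy per sample -> shape [4]
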
diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py
index a81af60..aadbbbc 100644
--- a/tianshou/utils/statistics.py
+++ b/tianshou/utils/statistics.py
@@ -31,20 +31,20 @@ class MovAvg(object):
         self.banned = [np.inf, np.nan, -np.inf]
 
     def add(
-        self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
+        self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor]
    ) -> float:
         """Add a scalar into :class:`MovAvg`.
 
         You can add ``torch.Tensor`` with only one element, a python scalar, or
         a list of python scalar.
         """
-        if isinstance(x, torch.Tensor):
-            x = x.flatten().cpu().numpy()
-        if np.isscalar(x):
-            x = [x]
-        for i in x:  # type: ignore
-            if i not in self.banned:
-                self.cache.append(i)
+        if isinstance(data_array, torch.Tensor):
+            data_array = data_array.flatten().cpu().numpy()
+        if np.isscalar(data_array):
+            data_array = [data_array]
+        for number in data_array:  # type: ignore
+            if number not in self.banned:
+                self.cache.append(number)
         if self.size > 0 and len(self.cache) > self.size:
             self.cache = self.cache[-self.size:]
         return self.get()
@@ -80,10 +80,10 @@ class RunningMeanStd(object):
         self.mean, self.var = mean, std
         self.count = 0
 
-    def update(self, x: np.ndarray) -> None:
+    def update(self, data_array: np.ndarray) -> None:
         """Add a batch of item into RMS with the same shape, modify mean/var/count."""
-        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
-        batch_count = len(x)
+        batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
+        batch_count = len(data_array)
 
         delta = batch_mean - self.mean
         total_count = self.count + batch_count
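
The ``RunningMeanStd.update`` hunk ends right after ``total_count`` is computed; the usual continuation of such a method is the parallel mean/variance combination (Chan et al.). A standalone sketch of that update rule, with an illustrative class name rather than the library's exact code::

    import numpy as np

    class SimpleRunningMeanStd:
        """Minimal running mean/std using the parallel-update formula."""

        def __init__(self, shape=()):
            self.mean = np.zeros(shape)
            self.var = np.ones(shape)
            self.count = 0

        def update(self, data_array: np.ndarray) -> None:
            batch_mean = np.mean(data_array, axis=0)
            batch_var = np.var(data_array, axis=0)
            batch_count = len(data_array)
            delta = batch_mean - self.mean
            total_count = self.count + batch_count
            # combine the two (mean, var, count) summaries
            new_mean = self.mean + delta * batch_count / total_count
            m_a = self.var * self.count
            m_b = batch_var * batch_count
            m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
            self.mean, self.var = new_mean, m_2 / total_count
            self.count = total_count

    rms = SimpleRunningMeanStd()
    data = np.random.randn(1000)
    rms.update(data[:500])
    rms.update(data[500:])
    assert np.isclose(rms.mean, data.mean()) and np.isclose(rms.var, data.var())
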