From e3ee415b1a21338e3ac9630261d883ab78f9a560 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 8 Feb 2021 12:59:37 +0800 Subject: [PATCH 1/4] temporary fix numpy<1.20.0 (#281) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b22472..f8736fa 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( install_requires=[ "gym>=0.15.4", "tqdm", - "numpy!=1.16.0", # https://github.com/numpy/numpy/issues/12793 + "numpy!=1.16.0,<1.20.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard", "torch>=1.4.0", "numba>=0.51.0", From f528131da1de604f8158a1fe132e1fada71950e0 Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Tue, 9 Feb 2021 17:13:40 +0800 Subject: [PATCH 2/4] =?UTF-8?q?hotfix=EF=BC=9Afix=20test=20failure=20in=20?= =?UTF-8?q?cuda=20environment=20(#289)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/atari/runnable/pong_a2c.py | 4 ++-- examples/atari/runnable/pong_ppo.py | 4 ++-- test/discrete/test_a2c_with_il.py | 6 +++--- test/discrete/test_ppo.py | 4 ++-- test/discrete/test_sac.py | 9 ++++++--- tianshou/utils/net/continuous.py | 11 +++++++---- tianshou/utils/net/discrete.py | 10 ++++++++-- 7 files changed, 30 insertions(+), 18 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 100ae24..ffed169 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -65,8 +65,8 @@ def test_a2c(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 55219da..35ed0e7 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -65,8 +65,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 90f6681..08759f9 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical @@ -113,7 +113,7 @@ def test_a2c_with_il(args=get_args()): env.spec.reward_threshold = 190 # lower the goal net = 
Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - net = Actor(net, args.action_shape).to(args.device) + net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 231ad50..e2e671c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -68,8 +68,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 67da3fb..3d3df6f 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, softmax_output=False).to(args.device) + actor = Actor(net, args.action_shape, + softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device) + critic1 = Critic(net_c1, last_size=args.action_shape, + device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device) + critic2 = Critic(net_c2, last_size=args.action_shape, + device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # better not to use auto alpha in CartPole diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 333c2ab..a8f6675 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -49,7 +49,8 @@ class Actor(nn.Module): self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) self._max = max_action def forward( @@ -98,7 +99,7 @@ class Critic(nn.Module): self.output_dim = 1 input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, 1, hidden_sizes) + self.last = MLP(input_dim, 1, hidden_sizes, device=self.device) def forward( self, @@ -164,10 +165,12 @@ class ActorProb(nn.Module): self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, hidden_sizes) + self.mu = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: - self.sigma = MLP(input_dim, self.output_dim, hidden_sizes) + self.sigma = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) else: 
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 05c0236..fc7c9b0 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -40,13 +40,16 @@ class Actor(nn.Module): hidden_sizes: Sequence[int] = (), softmax_output: bool = True, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) self.softmax_output = softmax_output def forward( @@ -91,13 +94,16 @@ class Critic(nn.Module): hidden_sizes: Sequence[int] = (), last_size: int = 1, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = last_size input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, hidden_sizes) + self.last = MLP(input_dim, last_size, + hidden_sizes, device=self.device) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any From d003c8e566267ec8a497d568e0f01de16d04b787 Mon Sep 17 00:00:00 2001 From: n+e Date: Tue, 16 Feb 2021 09:01:54 +0800 Subject: [PATCH 3/4] fix 2 bugs of batch (#284) 1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)` before: ```python TypeError: cannot concatenate with Batch() which is scalar ``` after: ```python Batch( a: Batch(), b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ) ``` 2. creating keys in a batch's subkey, e.g. ```python a = Batch(info={"key1": [0, 1], "key2": [2, 3]}) a[0] = Batch(info={"key1": 2, "key3": 4}) print(a) ``` before: ```python Batch( info: Batch( key1: array([0, 1]), key2: array([0, 3]), ), ) ``` after: ```python ValueError: Creating keys is not supported by item assignment. ``` 3. 
small optimization for `Batch.stack_` and `Batch.cat_` --- test/base/test_batch.py | 12 ++++++++++- tianshou/data/batch.py | 47 +++++++++++++++++++++++++++++------------ tianshou/data/buffer.py | 2 +- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 650d560..4553edf 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -19,6 +19,8 @@ def test_batch(): assert not Batch(a=np.float64(1.0)).is_empty() assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() + b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None]) + assert b.c.dtype == np.object b = Batch() b.update() assert b.is_empty() @@ -143,8 +145,10 @@ def test_batch(): assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) assert batch3.a.d.f[0] == 5.0 - with pytest.raises(KeyError): + with pytest.raises(ValueError): batch3.a.d[0] = Batch(f=5.0, g=0.0) + with pytest.raises(ValueError): + batch3[0] = Batch(a={"c": 2, "e": 1}) # auto convert batch4 = Batch(a=np.array(['a', 'b'])) assert batch4.a.dtype == np.object # auto convert to np.object @@ -333,6 +337,12 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # test with illegal input format + with pytest.raises(ValueError): + Batch.cat([[Batch(a=1)], [Batch(a=1)]]) + with pytest.raises(ValueError): + Batch.stack([[Batch(a=1)], [Batch(a=1)]]) + # exceptions assert Batch.cat([]).is_empty() assert Batch.stack([]).is_empty() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index fe35160..4f15622 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -8,12 +8,6 @@ from collections.abc import Collection from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ Sequence -# Disable pickle warning related to torch, since it has been removed -# on torch master branch. See Pull Request #39003 for details: -# https://github.com/pytorch/pytorch/pull/39003 -warnings.filterwarnings( - "ignore", message="pickle support for Storage will be removed in 1.5.") - def _is_batch_set(data: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, @@ -91,6 +85,9 @@ def _create_value( has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: + # _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here + if isinstance(inst, Batch) and inst.is_empty(recurse=True): + return inst # should never hit since it has already checked in Batch.cat_ # here we do not consider scalar types, following the behavior of numpy # which does not support concatenation of zero-dimensional arrays @@ -257,7 +254,7 @@ class Batch: raise ValueError("Batch does not supported tensor assignment. 
" "Use a compatible Batch or dict instead.") if not set(value.keys()).issubset(self.__dict__.keys()): - raise KeyError( + raise ValueError( "Creating keys is not supported by item assignment.") for key, val in self.items(): try: @@ -449,12 +446,21 @@ class Batch: """Concatenate a list of (or one) Batch objects into current batch.""" if isinstance(batches, Batch): batches = [batches] - if len(batches) == 0: + # check input format + batch_list = [] + for b in batches: + if isinstance(b, dict): + if len(b) > 0: + batch_list.append(Batch(b)) + elif isinstance(b, Batch): + # x.is_empty() means that x is Batch() and should be ignored + if not b.is_empty(): + batch_list.append(b) + else: + raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_") + if len(batch_list) == 0: return - batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] - - # x.is_empty() means that x is Batch() and should be ignored - batches = [x for x in batches if not x.is_empty()] + batches = batch_list try: # x.is_empty(recurse=True) here means x is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and @@ -496,9 +502,22 @@ class Batch: self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0 ) -> None: """Stack a list of Batch object into current batch.""" - if len(batches) == 0: + # check input format + batch_list = [] + for b in batches: + if isinstance(b, dict): + if len(b) > 0: + batch_list.append(Batch(b)) + elif isinstance(b, Batch): + # x.is_empty() means that x is Batch() and should be ignored + if not b.is_empty(): + batch_list.append(b) + else: + raise ValueError( + f"Cannot concatenate {type(b)} in Batch.stack_") + if len(batch_list) == 0: return - batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] + batches = batch_list if not self.is_empty(): batches = [self] + batches # collect non-empty keys diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 74299df..11d1ca1 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -203,7 +203,7 @@ class ReplayBuffer: ) try: value[self._index] = inst - except KeyError: + except ValueError: for key in set(inst.keys()).difference(value.__dict__.keys()): value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst From cb65b56b13949353f6f8728d05721bd7419fd81c Mon Sep 17 00:00:00 2001 From: n+e Date: Tue, 16 Feb 2021 09:31:46 +0800 Subject: [PATCH 4/4] v0.3.2 (#292) Throw a warning in ListReplayBuffer. This version update is needed because of #289, the previous v0.3.1 cannot work well under torch<=1.6.0 with cuda environment. 
---
 tianshou/__init__.py | 2 +-
 tianshou/data/buffer.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index 0f7a59c..380167b 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, exploration


-__version__ = "0.3.1"
+__version__ = "0.3.2"

 __all__ = [
     "env",
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 11d1ca1..1d36ce1 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,5 +1,6 @@
 import h5py
 import torch
+import warnings
 import numpy as np
 from numbers import Number
 from typing import Any, Dict, List, Tuple, Union, Optional
@@ -412,6 +413,7 @@ class ListReplayBuffer(ReplayBuffer):

     def __init__(self, **kwargs: Any) -> None:
         super().__init__(size=0, ignore_obs_next=False, **kwargs)
+        warnings.warn("ListReplayBuffer will be removed in version 0.4.0.")

     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
         raise NotImplementedError("ListReplayBuffer cannot be sampled!")
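As a closing note on patch 2, here is a minimal sketch of the construction pattern the fixed examples and tests rely on. The environment shapes are hypothetical stand-ins for a CartPole-like task; the `Net`/`Actor`/`Critic` signatures are taken from the diffs above. Passing `device=` lets each wrapper's internal `MLP` allocate its layers on the same device that the later `.to(device)` call targets, which is what failed under torch<=1.6.0 in a CUDA environment before this fix.

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

device = "cuda" if torch.cuda.is_available() else "cpu"
state_shape, action_shape = (4,), 2  # hypothetical CartPole-like shapes

# Build the shared feature network and both heads on the same device:
# device= is forwarded so the inner MLP creates its tensors there,
# and .to(device) moves any remaining parameters, mirroring patch 2.
net = Net(state_shape, hidden_sizes=[64, 64], device=device)
actor = Actor(net, action_shape, device=device).to(device)
critic = Critic(net, device=device).to(device)
```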