Formalize variable names (#509)

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
ChenDRAG 2022-01-30 00:53:56 +08:00 committed by GitHub
parent bc53ead273
commit c25926dd8f
42 changed files with 607 additions and 581 deletions

View File

@ -350,7 +350,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
def reset():
return copy.deepcopy(self.graph)
def step(a):
def step(action):
...
return copy.deepcopy(self.graph), reward, done, {}
@ -391,13 +391,13 @@ In addition, legal actions in multi-agent RL often vary with timestep (just like
The above description gives rise to the following formulation of multi-agent RL:
::
action = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(action)
act = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(act)
By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL:
::
action = policy(state_)
next_state_, reward = env.step(action)
act = policy(state_)
next_state_, reward = env.step(act)
Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
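
As a further illustration of the ``state_ = (state, agent_id, mask)`` wrapping above, here is a minimal sketch (not part of this commit) of a turn-based loop driven by a single-agent-style policy. The environment is assumed to return a dict with ``obs``, ``agent_id`` and ``mask`` keys, as in the ``MultiAgentEnv`` docstring changed later in this diff; ``random_masked_policy`` is a hypothetical stand-in::

    import numpy as np

    def random_masked_policy(state_):
        # state_ bundles (state, agent_id, mask); sample uniformly among legal actions
        legal_actions = np.flatnonzero(state_["mask"])
        return np.random.choice(legal_actions)

    def play_one_episode(env, policy=random_masked_policy):
        state_ = env.reset()              # dict with keys: obs, agent_id, mask
        done = False
        while not done:
            act = policy(state_)
            state_, rew, done, info = env.step(act)
        return rew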

View File

@ -298,14 +298,14 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is
::
# pseudocode, cannot work
s = env.reset()
obs = env.reset()
buffer = Buffer(size=10000)
agent = DQN()
for i in range(int(1e6)):
a = agent.compute_action(s)
s_, r, d, _ = env.step(a)
buffer.store(s, a, s_, r, d)
s = s_
act = agent.compute_action(obs)
obs_next, rew, done, _ = env.step(act)
buffer.store(obs, act, obs_next, rew, done)
obs = obs_next
if i % 1000 == 0:
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
# compute 2-step returns. How?
@ -390,14 +390,14 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr
::
# pseudocode, cannot work # methods in tianshou
s = env.reset()
obs = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # act = policy(batch, ...).act
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
act = agent.compute_action(obs) # act = policy(batch, ...).act
obs_next, rew, done, _ = env.step(act) # collector.collect(...)
buffer.store(obs, act, obs_next, rew, done) # collector.collect(...)
obs = obs_next # collector.collect(...)
if i % 1000 == 0: # done in trainer
# the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size)
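
To answer the "# compute 2-step returns. How?" question posed in the first pseudocode block, here is a minimal NumPy sketch of the n-step return. The helper name and arguments are illustrative, not tianshou API; the repository's reference implementation ``compute_nstep_return_base`` appears in a later file of this diff::

    import numpy as np

    def nstep_return(rew, done, q_bootstrap, n=2, gamma=0.99):
        """n-step return for one transition: discounted rewards over the next n
        steps, bootstrapped with a Q estimate unless the episode ends earlier."""
        ret, discount = 0.0, 1.0
        for k in range(n):
            ret += discount * rew[k]
            if done[k]:               # episode terminated inside the n-step window
                return ret
            discount *= gamma
        return ret + discount * q_bootstrap

    # nstep_return(np.array([1., 1.]), np.array([False, False]), q_bootstrap=10.,
    #              n=2, gamma=0.9) == 1 + 0.9 * 1 + 0.81 * 10 == 10.0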

View File

@ -42,13 +42,13 @@ class DQN(nn.Module):
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x), state
r"""Mapping: s -> Q(s, \*)."""
obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
return self.net(obs), state
class C51(DQN):
@ -73,15 +73,15 @@ class C51(DQN):
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, self.action_num, self.num_atoms)
return x, state
obs, state = super().forward(obs)
obs = obs.view(-1, self.num_atoms).softmax(dim=-1)
obs = obs.view(-1, self.action_num, self.num_atoms)
return obs, state
class Rainbow(DQN):
@ -127,22 +127,22 @@ class Rainbow(DQN):
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
q = self.Q(x)
obs, state = super().forward(obs)
q = self.Q(obs)
q = q.view(-1, self.action_num, self.num_atoms)
if self._is_dueling:
v = self.V(x)
v = self.V(obs)
v = v.view(-1, 1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
else:
logits = q
y = logits.softmax(dim=2)
return y, state
probs = logits.softmax(dim=2)
return probs, state
class QRDQN(DQN):
@ -168,11 +168,11 @@ class QRDQN(DQN):
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.action_num, self.num_quantiles)
return x, state
obs, state = super().forward(obs)
obs = obs.view(-1, self.action_num, self.num_quantiles)
return obs, state
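
All four networks above now follow the same ``forward(obs, state=None, info={}) -> (logits, state)`` convention. A self-contained sketch of that convention follows; ``TinyQNet`` and its layer sizes are hypothetical, not from the repository::

    from typing import Any, Dict, Optional, Tuple, Union

    import numpy as np
    import torch
    from torch import nn

    class TinyQNet(nn.Module):
        """Mapping: obs -> Q(obs, \*), mirroring the renamed signature above."""

        def __init__(self, state_dim: int, action_num: int) -> None:
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_num)
            )

        def forward(
            self,
            obs: Union[np.ndarray, torch.Tensor],
            state: Optional[Any] = None,
            info: Dict[str, Any] = {},
        ) -> Tuple[torch.Tensor, Any]:
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.net(obs), state

    # q_values, state = TinyQNet(state_dim=4, action_num=2)(np.zeros((1, 4)))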

View File

@ -56,16 +56,16 @@ class Wrapper(gym.Wrapper):
self.rm_done = rm_done
def step(self, action):
r = 0.0
rew_sum = 0.0
for _ in range(self.action_repeat):
obs, reward, done, info = self.env.step(action)
obs, rew, done, info = self.env.step(action)
# remove done reward penalty
if not done or not self.rm_done:
r = r + reward
rew_sum = rew_sum + rew
if done:
break
# scale reward
return obs, self.reward_scale * r, done, info
return obs, self.reward_scale * rew_sum, done, info
def test_sac_bipedal(args=get_args()):

View File

@ -86,10 +86,10 @@ class MyTestEnv(gym.Env):
def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True."""
x = int(self.done)
end_flag = int(self.done)
if self.ma_rew > 0:
return [x] * self.ma_rew
return x
return [end_flag] * self.ma_rew
return end_flag
def _get_state(self):
"""Generate state(observation) of MyTestEnv"""

View File

@ -32,10 +32,12 @@ def test_replaybuffer(size=10, bufsize=20):
assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(act)
buf.add(
Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)
Batch(
obs=obs, act=[act], rew=rew, done=done, obs_next=obs_next, info=info
)
)
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
@ -220,11 +222,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(act)
batch = Batch(
obs=obs,
act=a,
act=act,
rew=rew,
done=done,
obs_next=obs_next,

View File

@ -331,20 +331,20 @@ def test_collector_with_ma():
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rews']
assert len(r) == 0
r = c0.collect(n_episode=2)['rews']
assert r.shape == (2, 4) and np.all(r == 1)
rew = c0.collect(n_step=3)['rews']
assert len(rew) == 0
rew = c0.collect(n_episode=2)['rews']
assert rew.shape == (2, 4) and np.all(rew == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c1 = Collector(
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn
)
r = c1.collect(n_step=12)['rews']
assert r.shape == (2, 4) and np.all(r == 1), r
r = c1.collect(n_episode=8)['rews']
assert r.shape == (8, 4) and np.all(r == 1)
rew = c1.collect(n_step=12)['rews']
assert rew.shape == (2, 4) and np.all(rew == 1), rew
rew = c1.collect(n_episode=8)['rews']
assert rew.shape == (8, 4) and np.all(rew == 1)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
@ -446,8 +446,8 @@ def test_collector_with_ma():
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn
)
r = c2.collect(n_episode=10)['rews']
assert r.shape == (10, 4) and np.all(r == 1)
rew = c2.collect(n_episode=10)['rews']
assert rew.shape == (10, 4) and np.all(rew == 1)
batch, _ = c2.buffer.sample(10)

View File

@ -58,22 +58,22 @@ def test_async_env(size=10000, num=8, sleep=0.1):
# should be smaller
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
current_idx_start = 0
action = action_list[:num]
act = action_list[:num]
env_ids = list(range(num))
o = []
spent_time = time.time()
while current_idx_start < len(action_list):
A, B, C, D = v.step(action=action, id=env_ids)
A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
env_ids = b.info.env_id
o.append(b)
current_idx_start += len(action)
current_idx_start += len(act)
# len of action may be smaller than len(A) in the end
action = action_list[current_idx_start:current_idx_start + len(A)]
act = action_list[current_idx_start:current_idx_start + len(A)]
# truncate env_ids with the first terms
# typically len(env_ids) == len(A) == len(action), except for the
# last batch when actions are not enough
env_ids = env_ids[:len(action)]
env_ids = env_ids[:len(act)]
spent_time = time.time() - spent_time
Batch.cat(o)
v.close()

View File

@ -142,11 +142,11 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
returns = np.zeros_like(indices, dtype=float)
buf_len = len(buffer)
for i in range(len(indices)):
flag, r = False, 0.
flag, rew = False, 0.
real_step_n = nstep
for n in range(nstep):
idx = (indices[i] + n) % buf_len
r += buffer.rew[idx] * gamma**n
rew += buffer.rew[idx] * gamma**n
if buffer.done[idx]:
if not (
hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx]
@ -156,8 +156,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
break
if not flag:
idx = (indices[i] + real_step_n - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = r
rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = rew
return returns

View File

@ -41,7 +41,7 @@ def gomoku(args=get_args()):
return TicTacToeEnv(args.board_size, args.win_size)
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for r in range(args.self_play_round):
for round in range(args.self_play_round):
rews = []
agent_learn.set_eps(0.0)
# compute the reward over previous learner
@ -66,12 +66,12 @@ def gomoku(args=get_args()):
# previous learner can only be used for forward
agent.forward = opponent.forward
args.model_save_path = os.path.join(
args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth'
args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
)
result, agent_learn = train_agent(
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
)
print(f'round_{r}_epoch_{epoch}')
print(f'round_{round}_epoch_{epoch}')
pprint.pprint(result)
learnt_agent = deepcopy(agent_learn)
learnt_agent.set_eps(0.0)

View File

@ -11,17 +11,18 @@ import torch
IndexType = Union[slice, int, np.ndarray, List[int]]
def _is_batch_set(data: Any) -> bool:
def _is_batch_set(obj: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with object type,
# where each element is a dict/Batch object
if isinstance(data, np.ndarray): # most often case
# "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data)
elif isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
if isinstance(obj, np.ndarray): # most often case
# "for element in obj" will just unpack the first dimension,
# but obj.tolist() will flatten ndarray of objects
# so do not use obj.tolist()
return obj.dtype == object and \
all(isinstance(element, (dict, Batch)) for element in obj)
elif isinstance(obj, (list, tuple)):
if len(obj) > 0 and all(isinstance(element, (dict, Batch)) for element in obj):
return True
return False
@ -48,28 +49,29 @@ def _is_number(value: Any) -> bool:
return isinstance(value, (Number, np.number, np.bool_))
def _to_array_with_correct_type(v: Any) -> np.ndarray:
if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)):
return v # most often case
def _to_array_with_correct_type(obj: Any) -> np.ndarray:
if isinstance(obj, np.ndarray) and \
issubclass(obj.dtype.type, (np.bool_, np.number)):
return obj # most often case
# convert the value to np.ndarray
# convert to object data type if neither bool nor number
# convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(object)
if v.dtype == object:
# scalar ndarray with object data type is very annoying
obj_array = np.asanyarray(obj)
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
obj_array = obj_array.astype(object)
if obj_array.dtype == object:
# scalar ndarray with object obj type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
# array([{}, array({}, dtype=object)], dtype=object)
if not v.shape:
v = v.item(0)
elif all(isinstance(e, np.ndarray) for e in v.reshape(-1)):
return v # various length, np.array([[1], [2, 3], [4, 5, 6]])
elif any(isinstance(e, torch.Tensor) for e in v.reshape(-1)):
if not obj_array.shape:
obj_array = obj_array.item(0)
elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]])
elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return v
return obj_array
def _create_value(
@ -113,44 +115,45 @@ def _create_value(
def _assert_type_keys(keys: Iterable[str]) -> None:
assert all(isinstance(e, str) for e in keys), \
assert all(isinstance(key, str) for key in keys), \
f"keys should all be string, but got {keys}"
def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if isinstance(v, Batch): # most often case
return v
elif (isinstance(v, np.ndarray) and
issubclass(v.dtype.type, (np.bool_, np.number))) or \
isinstance(v, torch.Tensor) or v is None: # third often case
return v
elif _is_number(v): # second often case, but it is more time-consuming
return np.asanyarray(v)
elif isinstance(v, dict):
return Batch(v)
def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if isinstance(obj, Batch): # most often case
return obj
elif (isinstance(obj, np.ndarray) and
issubclass(obj.dtype.type, (np.bool_, np.number))) or \
isinstance(obj, torch.Tensor) or obj is None: # third often case
return obj
elif _is_number(obj): # second often case, but it is more time-consuming
return np.asanyarray(obj)
elif isinstance(obj, dict):
return Batch(obj)
else:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
if not isinstance(obj, np.ndarray) and \
isinstance(obj, Collection) and len(obj) > 0 and \
all(isinstance(element, torch.Tensor) for element in obj):
try:
return torch.stack(v) # type: ignore
except RuntimeError as e:
return torch.stack(obj) # type: ignore
except RuntimeError as exception:
raise TypeError(
"Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet."
) from e
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
) from exception
if _is_batch_set(obj):
obj = Batch(obj) # list of dict / Batch
else:
# None, scalar, normal data list (main case)
# None, scalar, normal obj list (main case)
# or an actual list of objects
try:
v = _to_array_with_correct_type(v)
except ValueError as e:
obj = _to_array_with_correct_type(obj)
except ValueError as exception:
raise TypeError(
"Batch does not support heterogeneous list/"
"tuple of tensors as unique value yet."
) from e
return v
) from exception
return obj
def _alloc_by_keys_diff(
@ -189,8 +192,8 @@ class Batch:
if batch_dict is not None:
if isinstance(batch_dict, (dict, Batch)):
_assert_type_keys(batch_dict.keys())
for k, v in batch_dict.items():
self.__dict__[k] = _parse_value(v)
for batch_key, obj in batch_dict.items():
self.__dict__[batch_key] = _parse_value(obj)
elif _is_batch_set(batch_dict):
self.stack_(batch_dict) # type: ignore
if len(kwargs) > 0:
@ -214,10 +217,10 @@ class Batch:
Only the actual data are serialized for both efficiency and simplicity.
"""
state = {}
for k, v in self.items():
if isinstance(v, Batch):
v = v.__getstate__()
state[k] = v
for batch_key, obj in self.items():
if isinstance(obj, Batch):
obj = obj.__getstate__()
state[batch_key] = obj
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
@ -234,13 +237,13 @@ class Batch:
return self.__dict__[index]
batch_items = self.items()
if len(batch_items) > 0:
b = Batch()
for k, v in batch_items:
if isinstance(v, Batch) and v.is_empty():
b.__dict__[k] = Batch()
new_batch = Batch()
for batch_key, obj in batch_items:
if isinstance(obj, Batch) and obj.is_empty():
new_batch.__dict__[batch_key] = Batch()
else:
b.__dict__[k] = v[index]
return b
new_batch.__dict__[batch_key] = obj[index]
return new_batch
else:
raise IndexError("Cannot access item from empty Batch object.")
@ -273,20 +276,20 @@ class Batch:
def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
"""Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(
for (batch_key, obj), value in zip(
self.__dict__.items(), other.__dict__.values()
): # TODO are keys consistent?
if isinstance(r, Batch) and r.is_empty():
if isinstance(obj, Batch) and obj.is_empty():
continue
else:
self.__dict__[k] += v
self.__dict__[batch_key] += value
return self
elif _is_number(other):
for k, r in self.items():
if isinstance(r, Batch) and r.is_empty():
for batch_key, obj in self.items():
if isinstance(obj, Batch) and obj.is_empty():
continue
else:
self.__dict__[k] += other
self.__dict__[batch_key] += other
return self
else:
raise TypeError("Only addition of Batch or number is supported.")
@ -295,54 +298,54 @@ class Batch:
"""Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Union[Number, np.number]) -> "Batch":
def __imul__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(val), "Only multiplication by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
continue
self.__dict__[k] *= val
self.__dict__[batch_key] *= value
return self
def __mul__(self, val: Union[Number, np.number]) -> "Batch":
def __mul__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val)
return deepcopy(self).__imul__(value)
def __itruediv__(self, val: Union[Number, np.number]) -> "Batch":
def __itruediv__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value in-place."""
assert _is_number(val), "Only division by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
continue
self.__dict__[k] /= val
self.__dict__[batch_key] /= value
return self
def __truediv__(self, val: Union[Number, np.number]) -> "Batch":
def __truediv__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val)
return deepcopy(self).__itruediv__(value)
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + "(\n"
self_str = self.__class__.__name__ + "(\n"
flag = False
for k, v in self.__dict__.items():
rpl = "\n" + " " * (6 + len(k))
obj = pprint.pformat(v).replace("\n", rpl)
s += f" {k}: {obj},\n"
for batch_key, obj in self.__dict__.items():
rpl = "\n" + " " * (6 + len(batch_key))
obj_name = pprint.pformat(obj).replace("\n", rpl)
self_str += f" {batch_key}: {obj_name},\n"
flag = True
if flag:
s += ")"
self_str += ")"
else:
s = self.__class__.__name__ + "()"
return s
self_str = self.__class__.__name__ + "()"
return self_str
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.detach().cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj.to_numpy()
def to_torch(
self,
@ -353,24 +356,24 @@ class Batch:
if not isinstance(device, torch.device):
device = torch.device(device)
for k, v in self.items():
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index != v.device.index:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
if dtype is not None and obj.dtype != dtype or \
obj.device.type != device.type or \
device.index != obj.device.index:
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v.to(device)
elif isinstance(v, Batch):
v.to_torch(dtype, device)
obj = obj.type(dtype)
self.__dict__[batch_key] = obj.to(device)
elif isinstance(obj, Batch):
obj.to_torch(dtype, device)
else:
# ndarray or scalar
if not isinstance(v, np.ndarray):
v = np.asanyarray(v)
v = torch.from_numpy(v).to(device)
if not isinstance(obj, np.ndarray):
obj = np.asanyarray(obj)
obj = torch.from_numpy(obj).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
obj = obj.type(dtype)
self.__dict__[batch_key] = obj
def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None:
"""Private method for Batch.cat_.
@ -395,50 +398,51 @@ class Batch:
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
sum_lens = [0]
for x in lens:
sum_lens.append(sum_lens[-1] + x)
for len_ in lens:
sum_lens.append(sum_lens[-1] + len_)
# collect non-empty keys
keys_map = [
set(
k for k, v in batch.items()
if not (isinstance(v, Batch) and v.is_empty())
batch_key for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and obj.is_empty())
) for batch in batches
]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for key, shared_value in zip(keys_shared, values_shared):
if all(isinstance(element, (dict, Batch)) for element in shared_value):
batch_holder = Batch()
batch_holder.__cat(v, lens=lens)
self.__dict__[k] = batch_holder
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.cat(v)
batch_holder.__cat(shared_value, lens=lens)
self.__dict__[key] = batch_holder
elif all(isinstance(element, torch.Tensor) for element in shared_value):
self.__dict__[key] = torch.cat(shared_value)
else:
# cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
# will fail here
v = np.concatenate(v)
self.__dict__[k] = _to_array_with_correct_type(v)
keys_total = set.union(*[set(b.keys()) for b in batches])
shared_value = np.concatenate(shared_value)
self.__dict__[key] = _to_array_with_correct_type(shared_value)
keys_total = set.union(*[set(batch.keys()) for batch in batches])
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
for k in keys_reserve:
for key in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
for k in keys_partial:
for i, e in enumerate(batches):
if k not in e.__dict__:
self.__dict__[key] = Batch()
for key in keys_partial:
for i, batch in enumerate(batches):
if key not in batch.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
value = batch.get(key)
if isinstance(value, Batch) and value.is_empty():
continue
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
except KeyError:
self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
self.__dict__[key] = \
_create_value(value, sum_lens[-1], stack=False)
self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
"""Concatenate a list of (or one) Batch objects into current batch."""
@ -446,16 +450,16 @@ class Batch:
batches = [batches]
# check input format
batch_list = []
for b in batches:
if isinstance(b, dict):
if len(b) > 0:
batch_list.append(Batch(b))
elif isinstance(b, Batch):
for batch in batches:
if isinstance(batch, dict):
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty():
batch_list.append(b)
if not batch.is_empty():
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
if len(batch_list) == 0:
return
batches = batch_list
@ -463,13 +467,15 @@ class Batch:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches]
except TypeError as e:
lens = [
0 if batch.is_empty(recurse=True) else len(batch) for batch in batches
]
except TypeError as exception:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar."
) from e
) from exception
if not self.is_empty():
batches = [self] + list(batches)
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
@ -501,16 +507,16 @@ class Batch:
"""Stack a list of Batch object into current batch."""
# check input format
batch_list = []
for b in batches:
if isinstance(b, dict):
if len(b) > 0:
batch_list.append(Batch(b))
elif isinstance(b, Batch):
for batch in batches:
if isinstance(batch, dict):
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty():
batch_list.append(b)
if not batch.is_empty():
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_")
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
@ -519,28 +525,31 @@ class Batch:
# collect non-empty keys
keys_map = [
set(
k for k, v in batch.items()
if not (isinstance(v, Batch) and v.is_empty())
batch_key for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and obj.is_empty())
) for batch in batches
]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, torch.Tensor) for e in v): # second often
self.__dict__[k] = torch.stack(v, axis)
elif all(isinstance(e, (Batch, dict)) for e in v): # third often
self.__dict__[k] = Batch.stack(v, axis)
values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for shared_key, value in zip(keys_shared, values_shared):
# second often
if all(isinstance(element, torch.Tensor) for element in value):
self.__dict__[shared_key] = torch.stack(value, axis)
# third often
elif all(isinstance(element, (Batch, dict)) for element in value):
self.__dict__[shared_key] = Batch.stack(value, axis)
else: # most often case is np.ndarray
try:
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
self.__dict__[shared_key] = \
_to_array_with_correct_type(np.stack(value, axis))
except ValueError:
warnings.warn(
"You are using tensors with different shape,"
" fallback to dtype=object by default."
)
self.__dict__[k] = np.array(v, dtype=object)
self.__dict__[shared_key] = np.array(value, dtype=object)
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
keys_total = set.union(*[set(batch.keys()) for batch in batches])
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that are either partial or reserved
@ -552,21 +561,21 @@ class Batch:
f"Stack of Batch with non-shared keys {keys_partial} is only "
f"supported with axis=0, but got axis={axis}!"
)
for k in keys_reserve:
for key in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
for k in keys_partial:
for i, e in enumerate(batches):
if k not in e.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
self.__dict__[key] = Batch()
for key in keys_partial:
for i, batch in enumerate(batches):
if key not in batch.__dict__:
continue
value = batch.get(key)
if isinstance(value, Batch) and value.is_empty(): # type: ignore
continue # type: ignore
try:
self.__dict__[k][i] = val
self.__dict__[key][i] = value
except KeyError:
self.__dict__[k] = _create_value(val, len(batches))
self.__dict__[k][i] = val
self.__dict__[key] = _create_value(value, len(batches))
self.__dict__[key][i] = value
@staticmethod
def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch":
@ -620,27 +629,27 @@ class Batch:
),
)
"""
for k, v in self.items():
if isinstance(v, torch.Tensor): # most often case
self.__dict__[k][index] = 0
elif v is None:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor): # most often case
self.__dict__[batch_key][index] = 0
elif obj is None:
continue
elif isinstance(v, np.ndarray):
if v.dtype == object:
self.__dict__[k][index] = None
elif isinstance(obj, np.ndarray):
if obj.dtype == object:
self.__dict__[batch_key][index] = None
else:
self.__dict__[k][index] = 0
elif isinstance(v, Batch):
self.__dict__[k].empty_(index=index)
self.__dict__[batch_key][index] = 0
elif isinstance(obj, Batch):
self.__dict__[batch_key].empty_(index=index)
else: # scalar value
warnings.warn(
"You are calling Batch.empty on a NumPy scalar, "
"which may cause undefined behaviors."
)
if _is_number(v):
self.__dict__[k] = v.__class__(0)
if _is_number(obj):
self.__dict__[batch_key] = obj.__class__(0)
else:
self.__dict__[k] = None
self.__dict__[batch_key] = None
return self
@staticmethod
@ -658,26 +667,26 @@ class Batch:
if batch is None:
self.update(kwargs)
return
for k, v in batch.items():
self.__dict__[k] = _parse_value(v)
for batch_key, obj in batch.items():
self.__dict__[batch_key] = _parse_value(obj)
if kwargs:
self.update(kwargs)
def __len__(self) -> int:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
lens = []
for obj in self.__dict__.values():
if isinstance(obj, Batch) and obj.is_empty(recurse=True):
continue
elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0):
r.append(len(v))
elif hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
lens.append(len(obj))
else:
raise TypeError(f"Object {v} in {self} has no len()")
if len(r) == 0:
raise TypeError(f"Object {obj} in {self} has no len()")
if len(lens) == 0:
# empty batch has the shape of any, like the tensorflow '?' shape.
# So it has no length.
raise TypeError(f"Object {self} has no len()")
return min(r)
return min(lens)
def is_empty(self, recurse: bool = False) -> bool:
"""Test if a Batch is empty.
@ -710,8 +719,8 @@ class Batch:
if not recurse:
return False
return all(
False if not isinstance(x, Batch) else x.is_empty(recurse=True)
for x in self.values()
False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
for obj in self.values()
)
@property
@ -721,9 +730,9 @@ class Batch:
return []
else:
data_shape = []
for v in self.__dict__.values():
for obj in self.__dict__.values():
try:
data_shape.append(list(v.shape))
data_shape.append(list(obj.shape))
except AttributeError:
data_shape.append([])
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
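
A short usage sketch of the ``Batch`` behaviours renamed above (construction through ``_parse_value``, indexing, stacking and ``__len__``); the values in the comments are what these methods are expected to return, assuming the current tianshou API::

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(obs=np.zeros((4, 3)), act=np.array([0, 1, 0, 1]), info=Batch())
    len(batch)       # 4: empty sub-batches are skipped, then min over leading dims
    batch[0]         # integer indexing recurses into every non-empty entry
    stacked = Batch.stack([Batch(a=np.ones(2)), Batch(a=np.zeros(2))])
    stacked.shape    # [2, 2]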

View File

@ -69,8 +69,8 @@ class ReplayBuffer:
"""Return self.key."""
try:
return self._meta[key]
except KeyError as e:
raise AttributeError from e
except KeyError as exception:
raise AttributeError from exception
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Unpickling interface.
@ -198,10 +198,10 @@ class ReplayBuffer:
episode_reward is 0.
"""
# preprocess batch
b = Batch()
new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()):
b.__dict__[key] = batch[key]
batch = b
new_batch.__dict__[key] = batch[key]
batch = new_batch
assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
stacked_batch = buffer_ids is not None
if stacked_batch:
@ -315,9 +315,9 @@ class ReplayBuffer:
return Batch.stack(stack, axis=indices.ndim)
else:
return np.stack(stack, axis=indices.ndim)
except IndexError as e:
except IndexError as exception:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
raise exception # val != Batch()
return Batch()
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:

View File

@ -114,10 +114,10 @@ class ReplayBufferManager(ReplayBuffer):
episode_reward is 0.
"""
# preprocess batch
b = Batch()
new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()):
b.__dict__[key] = batch[key]
batch = b
new_batch.__dict__[key] = batch[key]
batch = new_batch
assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
if self._save_only_last_obs:
batch.obs = batch.obs[:, -1]

View File

@ -111,11 +111,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
# and possibly in other cases like structured arrays.
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
except Exception as exception:
raise RuntimeError(
f"Attempted to pickle {v.__class__.__name__} due to "
"data type not supported by HDF5 and failed."
) from e
) from exception
y[k].attrs["__data_type__"] = "pickled_ndarray"
elif isinstance(v, (int, float)):
# ints and floats are stored as attributes of groups
@ -123,11 +123,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
else: # resort to pickle for any other type of object
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
except Exception as exception:
raise NotImplementedError(
f"No conversion to HDF5 for object of type '{type(v)}' "
"implemented and fallback to pickle failed."
) from e
) from exception
y[k].attrs["__data_type__"] = v.__class__.__name__

View File

@ -15,8 +15,8 @@ class MultiAgentEnv(ABC, gym.Env):
env = MultiAgentEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
obs, rew, done, info = env.step(action)
act = policy(obs)
obs, rew, done, info = env.step(act)
env.close()
The available action's mask is set to 1, otherwise it is set to 0. Further

View File

@ -412,10 +412,10 @@ class RayVectorEnv(BaseVectorEnv):
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
try:
import ray
except ImportError as e:
except ImportError as exception:
raise ImportError(
"Please install ray to support RayVectorEnv: pip install ray"
) from e
) from exception
if not ray.is_initialized():
ray.init()
super().__init__(env_fns, RayEnvWorker, **kwargs)

View File

@ -98,6 +98,12 @@ class BasePolicy(ABC, nn.Module):
"""
return act
def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
"""Softly update the parameters of target module towards the parameters \
of source module."""
for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
@abstractmethod
def forward(
self,
@ -387,10 +393,10 @@ def _gae_return(
) -> np.ndarray:
returns = np.zeros(rew.shape)
delta = rew + v_s_ * gamma - v_s
m = (1.0 - end_flag) * (gamma * gae_lambda)
discount = (1.0 - end_flag) * (gamma * gae_lambda)
gae = 0.0
for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
gae = delta[i] + discount[i] * gae
returns[i] = gae
return returns
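
For reference, the renamed ``discount`` variable makes explicit that ``_gae_return`` implements the GAE recursion A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t). A standalone copy of the loop above with a small sanity check::

    import numpy as np

    def gae_return(v_s, v_s_, rew, end_flag, gamma=0.99, gae_lambda=0.95):
        returns = np.zeros(rew.shape)
        delta = rew + v_s_ * gamma - v_s
        discount = (1.0 - end_flag) * (gamma * gae_lambda)
        gae = 0.0
        for i in range(len(rew) - 1, -1, -1):
            gae = delta[i] + discount[i] * gae
            returns[i] = gae
        return returns

    # with gamma = gae_lambda = 1, zero values and unit rewards, the advantages
    # are simply the remaining episode lengths:
    # gae_return(np.zeros(3), np.zeros(3), np.ones(3), np.array([0., 0., 1.]),
    #            gamma=1.0, gae_lambda=1.0) -> [3., 2., 1.]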

View File

@ -40,23 +40,23 @@ class ImitationPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info)
logits, hidden = self.model(batch.obs, state=state, info=batch.info)
if self.action_type == "discrete":
a = logits.max(dim=1)[1]
act = logits.max(dim=1)[1]
else:
a = logits
return Batch(logits=logits, act=a, state=h)
act = logits
return Batch(logits=logits, act=act, state=hidden)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.optim.zero_grad()
if self.action_type == "continuous": # regression
a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_) # type: ignore
act = self(batch).act
act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
loss = F.mse_loss(act, act_target) # type: ignore
elif self.action_type == "discrete": # classification
a = F.log_softmax(self(batch).logits, dim=-1)
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_) # type: ignore
act = F.log_softmax(self(batch).logits, dim=-1)
act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
loss = F.nll_loss(act, act_target) # type: ignore
loss.backward()
self.optim.step()
return {"loss": loss.item()}

View File

@ -105,32 +105,27 @@ class BCQPolicy(BasePolicy):
obs_group: torch.Tensor = to_torch( # type: ignore
batch.obs, device=self.device
)
act = []
act_group = []
for obs in obs_group:
# now obs is (state_dim)
obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
# now obs is (forward_sampled_times, state_dim)
# decode(obs) generates action and actor perturbs it
action = self.actor(obs, self.vae.decode(obs))
act = self.actor(obs, self.vae.decode(obs))
# now action is (forward_sampled_times, action_dim)
q1 = self.critic1(obs, action)
q1 = self.critic1(obs, act)
# q1 is (forward_sampled_times, 1)
ind = q1.argmax(0)
act.append(action[ind].cpu().data.numpy().flatten())
act = np.array(act)
return Batch(act=act)
max_indice = q1.argmax(0)
act_group.append(act[max_indice].cpu().data.numpy().flatten())
act_group = np.array(act_group)
return Batch(act=act_group)
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for net, net_target in [
[self.critic1, self.critic1_target], [self.critic2, self.critic2_target],
[self.actor, self.actor_target]
]:
for param, target_param in zip(net.parameters(), net_target.parameters()):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
)
self.soft_update(self.critic1_target, self.critic1, self.tau)
self.soft_update(self.critic2_target, self.critic2, self.tau)
self.soft_update(self.actor_target, self.actor, self.tau)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# batch: obs, act, rew, done, obs_next. (numpy array)

View File

@ -113,13 +113,8 @@ class CQLPolicy(SACPolicy):
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for net, net_old in [
[self.critic1, self.critic1_old], [self.critic2, self.critic2_old]
]:
for param, target_param in zip(net.parameters(), net_old.parameters()):
target_param.data.copy_(
self._tau * param.data + (1 - self._tau) * target_param.data
)
self.soft_update(self.critic1_old, self.critic1, self.tau)
self.soft_update(self.critic2_old, self.critic2, self.tau)
def actor_pred(self, obs: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:

View File

@ -94,13 +94,10 @@ class DiscreteBCQPolicy(DQNPolicy):
# mask actions for argmax
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
mask = (ratio < self._log_tau).float()
action = (q_value - np.inf * mask).argmax(dim=-1)
act = (q_value - np.inf * mask).argmax(dim=-1)
return Batch(
act=action,
state=state,
q_value=q_value,
imitation_logits=imitation_logits
act=act, state=state, q_value=q_value, imitation_logits=imitation_logits
)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

View File

@ -57,15 +57,15 @@ class DiscreteCQLPolicy(QRDQNPolicy):
curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (
u * (self.tau_hat -
(target_dist - curr_dist).detach().le(0.).float()).abs()
dist_diff *
(self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1)
qr_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
# add CQL loss
q = self.compute_q_value(all_dist, None)
dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
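
The ``dist_diff`` / ``huber_loss`` computation above (and the matching hunks in the FQF and IQN policies further below) is the element-wise quantile Huber loss. A standalone sketch, with the tensor shapes assumed here noted explicitly::

    import torch
    import torch.nn.functional as F

    def quantile_huber_loss(curr_dist, target_dist, tau_hat):
        """curr_dist: (batch, n_quantiles, 1); target_dist: (batch, 1, n_quantiles);
        tau_hat: (batch, n_quantiles, 1); returns a per-sample loss of shape (batch,)."""
        # broadcasting yields a (batch, n_quantiles, n_quantiles) pairwise table
        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        return (
            dist_diff *
            (tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
        ).sum(-1).mean(1)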

View File

@ -97,9 +97,9 @@ class DiscreteCRRPolicy(PGPolicy):
target = rew.unsqueeze(1) + self._gamma * expected_target_q
critic_loss = 0.5 * F.mse_loss(qa_t, target)
# Actor loss
a_t, _ = self.actor(batch.obs)
m = Categorical(logits=a_t)
expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True)
act_target, _ = self.actor(batch.obs)
dist = Categorical(logits=act_target)
expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True)
advantage = qa_t - expected_policy_q
if self._policy_improvement_mode == "binary":
actor_loss_coef = (advantage > 0).float()
@ -109,7 +109,7 @@ class DiscreteCRRPolicy(PGPolicy):
)
else:
actor_loss_coef = 1.0 # effectively behavior cloning
actor_loss = (-m.log_prob(act) * actor_loss_coef).mean()
actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean()
# CQL loss/regularizer
min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss

View File

@ -202,13 +202,13 @@ class PSRLPolicy(BasePolicy):
rew_sum = np.zeros((n_s, n_a))
rew_square_sum = np.zeros((n_s, n_a))
rew_count = np.zeros((n_s, n_a))
for b in batch.split(size=1):
obs, act, obs_next = b.obs, b.act, b.obs_next
for minibatch in batch.split(size=1):
obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
trans_count[obs, act, obs_next] += 1
rew_sum[obs, act] += b.rew
rew_square_sum[obs, act] += b.rew**2
rew_sum[obs, act] += minibatch.rew
rew_square_sum[obs, act] += minibatch.rew**2
rew_count[obs, act] += 1
if self._add_done_loop and b.done:
if self._add_done_loop and minibatch.done:
# special operation for terminal states: add a self-loop
trans_count[obs_next, :, obs_next] += 1
rew_count[obs_next, :] += 1

View File

@ -85,9 +85,9 @@ class A2CPolicy(PGPolicy):
) -> Batch:
v_s, v_s_ = [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_s.append(self.critic(b.obs))
v_s_.append(self.critic(b.obs_next))
for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
v_s.append(self.critic(minibatch.obs))
v_s_.append(self.critic(minibatch.obs_next))
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
v_s = batch.v_s.cpu().numpy()
v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
@ -122,14 +122,15 @@ class A2CPolicy(PGPolicy):
) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(batch_size, merge_last=True):
# calculate loss for actor
dist = self(b).dist
log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
actor_loss = -(log_prob * b.adv).mean()
dist = self(minibatch).dist
log_prob = dist.log_prob(minibatch.act)
log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
actor_loss = -(log_prob * minibatch.adv).mean()
# calculate loss for critic
value = self.critic(b.obs).flatten()
vf_loss = F.mse_loss(b.returns, value)
value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(minibatch.returns, value)
# calculate regularization and overall loss
ent_loss = dist.entropy().mean()
loss = actor_loss + self._weight_vf * vf_loss \

View File

@ -70,13 +70,13 @@ class C51Policy(DQNPolicy):
def _target_dist(self, batch: Batch) -> torch.Tensor:
if self._target:
a = self(batch, input="obs_next").act
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
next_batch = self(batch, input="obs_next")
act = next_batch.act
next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(act)), act, :]
target_support = batch.returns.clamp(self._v_min, self._v_max)
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL

View File

@ -73,7 +73,7 @@ class DDPGPolicy(BasePolicy):
self.critic_old.eval()
self.critic_optim: torch.optim.Optimizer = critic_optim
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau
self.tau = tau
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
self._gamma = gamma
self._noise = exploration_noise
@ -95,10 +95,8 @@ class DDPGPolicy(BasePolicy):
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
self.soft_update(self.actor_old, self.actor, self.tau)
self.soft_update(self.critic_old, self.critic, self.tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
@ -139,8 +137,8 @@ class DDPGPolicy(BasePolicy):
"""
model = getattr(self, model)
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
return Batch(act=actions, state=h)
actions, hidden = model(obs, state=state, info=batch.info)
return Batch(act=actions, state=hidden)
@staticmethod
def _mse_optimizer(
@ -163,8 +161,7 @@ class DDPGPolicy(BasePolicy):
td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
batch.weight = td # prio-buffer
# actor
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
actor_loss = -self.critic(batch.obs, self(batch).act).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()

View File

@ -76,10 +76,10 @@ class DiscreteSACPolicy(SACPolicy):
**kwargs: Any,
) -> Batch:
obs = batch[input]
logits, h = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
return Batch(logits=logits, act=act, state=hidden, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}

View File

@ -151,13 +151,13 @@ class DQNPolicy(BasePolicy):
"""
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
logits, h = model(obs_, state=state, info=batch.info)
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
return Batch(logits=logits, act=act, state=h)
return Batch(logits=logits, act=act, state=hidden)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
@ -166,10 +166,10 @@ class DQNPolicy(BasePolicy):
weight = batch.pop("weight", 1.0)
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns.flatten(), q)
td = r - q
loss = (td.pow(2) * weight).mean()
batch.weight = td # prio-buffer
returns = to_torch_as(batch.returns.flatten(), q)
td_error = returns - q
loss = (td_error.pow(2) * weight).mean()
batch.weight = td_error # prio-buffer
loss.backward()
self.optim.step()
self._iter += 1

View File

@ -60,15 +60,15 @@ class FQFPolicy(QRDQNPolicy):
batch = buffer[indices] # batch.obs_next: s_{t+n}
if self._target:
result = self(batch, input="obs_next")
a, fractions = result.act, result.fractions
act, fractions = result.act, result.fractions
next_dist = self(
batch, model="model_old", input="obs_next", fractions=fractions
).logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
next_batch = self(batch, input="obs_next")
act = next_batch.act
next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(act)), act, :]
return next_dist # shape: [bsz, num_quantiles]
def forward(
@ -82,14 +82,17 @@ class FQFPolicy(QRDQNPolicy):
) -> Batch:
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
obs_next = obs.obs if hasattr(obs, "obs") else obs
if fractions is None:
(logits, fractions, quantiles_tau), h = model(
obs_, propose_model=self.propose_model, state=state, info=batch.info
(logits, fractions, quantiles_tau), hidden = model(
obs_next,
propose_model=self.propose_model,
state=state,
info=batch.info
)
else:
(logits, _, quantiles_tau), h = model(
obs_,
(logits, _, quantiles_tau), hidden = model(
obs_next,
propose_model=self.propose_model,
fractions=fractions,
state=state,
@ -106,7 +109,7 @@ class FQFPolicy(QRDQNPolicy):
return Batch(
logits=logits,
act=act,
state=h,
state=hidden,
fractions=fractions,
quantiles_tau=quantiles_tau
)
@ -122,9 +125,9 @@ class FQFPolicy(QRDQNPolicy):
curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (
u * (
dist_diff * (
tau_hats.unsqueeze(2) -
(target_dist - curr_dist).detach().le(0.).float()
).abs()
@ -132,7 +135,7 @@ class FQFPolicy(QRDQNPolicy):
quantile_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
# calculate fraction loss
with torch.no_grad():
sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]

View File

@ -73,36 +73,37 @@ class IQNPolicy(QRDQNPolicy):
sample_size = self._sample_size
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
(logits,
taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info)
obs_next = obs.obs if hasattr(obs, "obs") else obs
(logits, taus), hidden = model(
obs_next, sample_size=sample_size, state=state, info=batch.info
)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
return Batch(logits=logits, act=act, state=h, taus=taus)
return Batch(logits=logits, act=act, state=hidden, taus=taus)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
out = self(batch)
curr_dist, taus = out.logits, out.taus
action_batch = self(batch)
curr_dist, taus = action_batch.logits, action_batch.taus
act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (
u *
dist_diff *
(taus.unsqueeze(2) -
(target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1)
loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
loss.backward()
self.optim.step()
self._iter += 1

View File

@ -71,8 +71,8 @@ class NPGPolicy(A2CPolicy):
batch = super().process_fn(batch, buffer, indices)
old_log_prob = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(b).dist.log_prob(b.act))
for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
batch.logp_old = torch.cat(old_log_prob, dim=0)
if self._norm_adv:
batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
@ -83,20 +83,20 @@ class NPGPolicy(A2CPolicy):
) -> Dict[str, List[float]]:
actor_losses, vf_losses, kls = [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(b).dist
log_prob = dist.log_prob(b.act)
dist = self(minibatch).dist
log_prob = dist.log_prob(minibatch.act)
log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
actor_loss = -(log_prob * b.adv).mean()
actor_loss = -(log_prob * minibatch.adv).mean()
flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True
).detach()
# direction: calculate natural gradient
with torch.no_grad():
old_dist = self(b).dist
old_dist = self(minibatch).dist
kl = kl_divergence(old_dist, dist).mean()
# calculate first order gradient of kl with respect to theta
@ -112,13 +112,13 @@ class NPGPolicy(A2CPolicy):
)
new_flat_params = flat_params + self._step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params)
new_dist = self(b).dist
new_dist = self(minibatch).dist
kl = kl_divergence(old_dist, new_dist).mean()
# optimize critic
for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten()
vf_loss = F.mse_loss(b.returns, value)
value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(minibatch.returns, value)
self.optim.zero_grad()
vf_loss.backward()
self.optim.step()
@ -147,14 +147,14 @@ class NPGPolicy(A2CPolicy):
def _conjugate_gradients(
self,
b: torch.Tensor,
minibatch: torch.Tensor,
flat_kl_grad: torch.Tensor,
nsteps: int = 10,
residual_tol: float = 1e-10
) -> torch.Tensor:
x = torch.zeros_like(b)
r, p = b.clone(), b.clone()
# Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
x = torch.zeros_like(minibatch)
r, p = minibatch.clone(), minibatch.clone()
# Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0.
# Change if doing warm start.
rdotr = r.dot(r)
for _ in range(nsteps):

View File

@ -107,7 +107,7 @@ class PGPolicy(BasePolicy):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
logits, h = self.actor(batch.obs, state=state)
logits, hidden = self.actor(batch.obs, state=state)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
@ -119,20 +119,20 @@ class PGPolicy(BasePolicy):
act = logits[0]
else:
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
return Batch(logits=logits, act=act, state=hidden, dist=dist)
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses = []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(batch_size, merge_last=True):
self.optim.zero_grad()
result = self(b)
result = self(minibatch)
dist = result.dist
a = to_torch_as(b.act, result.act)
ret = to_torch_as(b.returns, result.act)
log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1)
act = to_torch_as(minibatch.act, result.act)
ret = to_torch_as(minibatch.returns, result.act)
log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
loss = -(log_prob * ret).mean()
loss.backward()
self.optim.step()
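The ``learn`` loop above is the plain REINFORCE update. A minimal self-contained sketch of one such minibatch step; the tiny policy network and the random data are placeholders:
::

    import torch
    import torch.nn as nn

    policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
    optim = torch.optim.Adam(policy.parameters(), lr=1e-3)

    obs = torch.randn(8, 4)            # minibatch of observations
    act = torch.randint(0, 2, (8,))    # actions that were taken
    ret = torch.randn(8)               # discounted returns (normally taken from the buffer)

    dist = torch.distributions.Categorical(logits=policy(obs))
    loss = -(dist.log_prob(act) * ret).mean()  # REINFORCE objective
    optim.zero_grad()
    loss.backward()
    optim.step()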

View File

@ -96,8 +96,8 @@ class PPOPolicy(A2CPolicy):
batch.act = to_torch_as(batch.act, batch.v_s)
old_log_prob = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(b).dist.log_prob(b.act))
for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
batch.logp_old = torch.cat(old_log_prob, dim=0)
return batch
@ -108,32 +108,35 @@ class PPOPolicy(A2CPolicy):
for step in range(repeat):
if self._recompute_adv and step > 0:
batch = self._compute_returns(batch, self._buffer, self._indices)
for b in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(batch_size, merge_last=True):
# calculate loss for actor
dist = self(b).dist
dist = self(minibatch).dist
if self._norm_adv:
mean, std = b.adv.mean(), b.adv.std()
b.adv = (b.adv - mean) / std # per-batch norm
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
mean, std = minibatch.adv.mean(), minibatch.adv.std()
minibatch.adv = (minibatch.adv - mean) / std # per-batch norm
ratio = (dist.log_prob(minibatch.act) -
minibatch.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
surr1 = ratio * b.adv
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
surr1 = ratio * minibatch.adv
surr2 = ratio.clamp(
1.0 - self._eps_clip, 1.0 + self._eps_clip
) * minibatch.adv
if self._dual_clip:
clip1 = torch.min(surr1, surr2)
clip2 = torch.max(clip1, self._dual_clip * b.adv)
clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
else:
clip_loss = -torch.min(surr1, surr2).mean()
# calculate loss for critic
value = self.critic(b.obs).flatten()
value = self.critic(minibatch.obs).flatten()
if self._value_clip:
v_clip = b.v_s + (value -
b.v_s).clamp(-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2)
v_clip = minibatch.v_s + \
(value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
vf1 = (minibatch.returns - value).pow(2)
vf2 = (minibatch.returns - v_clip).pow(2)
vf_loss = torch.max(vf1, vf2).mean()
else:
vf_loss = (b.returns - value).pow(2).mean()
vf_loss = (minibatch.returns - value).pow(2).mean()
# calculate regularization and overall loss
ent_loss = dist.entropy().mean()
loss = clip_loss + self._weight_vf * vf_loss \
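For reference, the clipped surrogate objective assembled in the hunk above reduces to the following minimal sketch; all tensors are random placeholders, and the optional dual-clip branch is shown under the same assumption:
::

    import torch

    eps_clip, dual_clip = 0.2, 3.0
    log_prob, logp_old, adv = torch.randn(64), torch.randn(64), torch.randn(64)

    ratio = (log_prob - logp_old).exp()
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    clip_loss = -torch.min(surr1, surr2).mean()

    # dual clip: bound the objective from below for negative advantages
    clip1 = torch.min(surr1, surr2)
    clip2 = torch.max(clip1, dual_clip * adv)
    dual_clip_loss = -torch.where(adv < 0, clip2, clip1).mean()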

View File

@ -56,13 +56,13 @@ class QRDQNPolicy(DQNPolicy):
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
if self._target:
a = self(batch, input="obs_next").act
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
next_batch = self(batch, input="obs_next")
act = next_batch.act
next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(act)), act, :]
return next_dist # shape: [bsz, num_quantiles]
def compute_q_value(
@ -80,15 +80,15 @@ class QRDQNPolicy(DQNPolicy):
curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (
u * (self.tau_hat -
(target_dist - curr_dist).detach().le(0.).float()).abs()
dist_diff *
(self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1)
loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
loss.backward()
self.optim.step()
self._iter += 1
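The loss above is the quantile Huber loss. A minimal sketch with random placeholder quantiles, following the ``[bsz, num_quantiles]`` shapes noted in the hunk; ``tau_hat`` is built with the standard QR-DQN midpoints, an assumption since its construction is not part of the hunk:
::

    import torch
    import torch.nn.functional as F

    bsz, num_quantiles = 32, 200
    curr_dist = torch.randn(bsz, num_quantiles, 1)    # current quantiles, unsqueezed
    target_dist = torch.randn(bsz, 1, num_quantiles)  # n-step target quantiles
    tau_hat = ((torch.arange(num_quantiles) + 0.5) / num_quantiles).view(1, -1, 1)

    dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        dist_diff *
        (tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)                                 # shape: [bsz]
    loss = huber_loss.mean()                          # without prioritized-replay weights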

View File

@ -99,10 +99,8 @@ class SACPolicy(DDPGPolicy):
return self
def sync_weight(self) -> None:
for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
self.soft_update(self.critic1_old, self.critic1, self.tau)
self.soft_update(self.critic2_old, self.critic2, self.tau)
def forward( # type: ignore
self,
@ -112,7 +110,7 @@ class SACPolicy(DDPGPolicy):
**kwargs: Any,
) -> Batch:
obs = batch[input]
logits, h = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self._deterministic_eval and not self.training:
@ -134,16 +132,20 @@ class SACPolicy(DDPGPolicy):
action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True)
return Batch(
logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
logits=logits,
act=squashed_action,
state=hidden,
dist=dist,
log_prob=log_prob
)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
obs_next_result = self(batch, input="obs_next")
act_ = obs_next_result.act
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
self.critic1_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, act_),
) - self._alpha * obs_next_result.log_prob
return target_q
@ -159,9 +161,9 @@ class SACPolicy(DDPGPolicy):
# actor
obs_result = self(batch)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten()
current_q2a = self.critic2(batch.obs, a).flatten()
act = obs_result.act
current_q1a = self.critic1(batch.obs, act).flatten()
current_q2a = self.critic2(batch.obs, act).flatten()
actor_loss = (
self._alpha * obs_result.log_prob.flatten() -
torch.min(current_q1a, current_q2a)
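The per-parameter loops removed above are replaced by a single ``soft_update`` call; as the deleted lines show, it is the usual polyak averaging of target-network weights. A minimal standalone sketch with placeholder modules (the helper here is a local function, not Tianshou's ``BasePolicy`` method):
::

    import torch.nn as nn

    def soft_update(tgt: nn.Module, src: nn.Module, tau: float) -> None:
        """tgt <- (1 - tau) * tgt + tau * src, parameter by parameter."""
        for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
            tgt_param.data.copy_(tgt_param.data * (1.0 - tau) + src_param.data * tau)

    critic = nn.Linear(8, 1)
    critic_old = nn.Linear(8, 1)
    critic_old.load_state_dict(critic.state_dict())  # start fully synchronized
    soft_update(critic_old, critic, tau=0.005)       # small tau -> slowly moving target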

View File

@ -89,23 +89,20 @@ class TD3Policy(DDPGPolicy):
return self
def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
self.soft_update(self.critic1_old, self.critic1, self.tau)
self.soft_update(self.critic2_old, self.critic2, self.tau)
self.soft_update(self.actor_old, self.actor, self.tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
a_ = self(batch, model="actor_old", input="obs_next").act
dev = a_.device
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
act_ = self(batch, model="actor_old", input="obs_next").act
noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
if self._noise_clip > 0.0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
act_ += noise
target_q = torch.min(
self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)
self.critic1_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, act_),
)
return target_q
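The ``_target_q`` hunk above combines target policy smoothing with a clipped double-Q target. A minimal sketch with random placeholder tensors standing in for the target actor's actions and the two old critics' outputs:
::

    import torch

    policy_noise, noise_clip = 0.2, 0.5
    act_ = torch.randn(32, 6)                        # actions from the target actor
    noise = torch.randn_like(act_) * policy_noise
    if noise_clip > 0.0:
        noise = noise.clamp(-noise_clip, noise_clip)
    act_ = act_ + noise                              # smoothed target action

    q1 = torch.randn(32, 1)                          # stand-in for critic1_old(obs_next, act_)
    q2 = torch.randn(32, 1)                          # stand-in for critic2_old(obs_next, act_)
    target_q = torch.min(q1, q2)                     # clipped double-Q target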

View File

@ -71,20 +71,21 @@ class TRPOPolicy(NPGPolicy):
) -> Dict[str, List[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(b).dist # TODO could come from batch
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
dist = self(minibatch).dist # TODO could come from batch
ratio = (dist.log_prob(minibatch.act) -
minibatch.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
actor_loss = -(ratio * b.adv).mean()
actor_loss = -(ratio * minibatch.adv).mean()
flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True
).detach()
# direction: calculate natural gradient
with torch.no_grad():
old_dist = self(b).dist
old_dist = self(minibatch).dist
kl = kl_divergence(old_dist, dist).mean()
# calculate first order gradient of kl with respect to theta
@ -109,12 +110,13 @@ class TRPOPolicy(NPGPolicy):
new_flat_params = flat_params + step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params)
# calculate kl and check: kl within bound and loss actually decreased
new_dist = self(b).dist
new_dratio = (new_dist.log_prob(b.act) -
b.logp_old).exp().float()
new_dist = self(minibatch).dist
new_dratio = (
new_dist.log_prob(minibatch.act) - minibatch.logp_old
).exp().float()
new_dratio = new_dratio.reshape(new_dratio.size(0),
-1).transpose(0, 1)
new_actor_loss = -(new_dratio * b.adv).mean()
new_actor_loss = -(new_dratio * minibatch.adv).mean()
kl = kl_divergence(old_dist, new_dist).mean()
if kl < self._delta and new_actor_loss < actor_loss:
@ -133,8 +135,8 @@ class TRPOPolicy(NPGPolicy):
# optimize critic
for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten()
vf_loss = F.mse_loss(b.returns, value)
value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(minibatch.returns, value)
self.optim.zero_grad()
vf_loss.backward()
self.optim.step()
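The acceptance test ``kl < self._delta and new_actor_loss < actor_loss`` above is the core of a backtracking line search. A minimal sketch of that search, written against assumed ``set_params`` and ``evaluate`` helpers rather than Tianshou's internals:
::

    import torch
    from typing import Callable, Tuple

    def backtracking_line_search(
        set_params: Callable[[torch.Tensor], None],   # load flat parameters into the actor
        evaluate: Callable[[], Tuple[float, float]],  # -> (surrogate_loss, kl) at current params
        flat_params: torch.Tensor,
        full_step: torch.Tensor,
        old_loss: float,
        delta: float = 0.01,
        max_backtracks: int = 10,
    ) -> None:
        frac = 1.0
        for _ in range(max_backtracks):
            set_params(flat_params + frac * full_step)
            new_loss, kl = evaluate()
            if kl < delta and new_loss < old_loss:
                return                    # accept this step size
            frac *= 0.5                   # otherwise shrink the step and retry
        set_params(flat_params)           # every candidate failed: roll back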

View File

@ -87,14 +87,14 @@ class MLP(nn.Module):
self.output_dim = output_dim or hidden_sizes[-1]
self.model = nn.Sequential(*model)
def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if self.device is not None:
s = torch.as_tensor(
s,
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
return self.model(s.flatten(1)) # type: ignore
return self.model(obs.flatten(1)) # type: ignore
class Net(nn.Module):
@ -187,12 +187,12 @@ class Net(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> flatten (inside MLP)-> logits."""
logits = self.model(s)
"""Mapping: obs -> flatten (inside MLP)-> logits."""
logits = self.model(obs)
bsz = logits.shape[0]
if self.use_dueling: # Dueling DQN
q, v = self.Q(logits), self.V(logits)
@ -235,38 +235,45 @@ class Recurrent(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Mapping: s -> flatten -> logits.
"""Mapping: obs -> flatten -> logits.
In the evaluation mode, s should be with shape ``[bsz, dim]``; in the
training mode, s should be with shape ``[bsz, len, dim]``. See the code
In the evaluation mode, `obs` should have shape ``[bsz, dim]``; in the
training mode, `obs` should have shape ``[bsz, len, dim]``. See the code
and comment for more detail.
"""
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase.
if len(s.shape) == 2:
s = s.unsqueeze(-2)
s = self.fc1(s)
if len(obs.shape) == 2:
obs = obs.unsqueeze(-2)
obs = self.fc1(obs)
self.nn.flatten_parameters()
if state is None:
s, (h, c) = self.nn(s)
obs, (hidden, cell) = self.nn(obs)
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn(
s, (
state["h"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous()
obs, (hidden, cell) = self.nn(
obs, (
state["hidden"].transpose(0, 1).contiguous(),
state["cell"].transpose(0, 1).contiguous()
)
)
s = self.fc2(s[:, -1])
obs = self.fc2(obs[:, -1])
# please ensure the first dim is batch size: [bsz, len, ...]
return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
return obs, {
"hidden": hidden.transpose(0, 1).detach(),
"cell": cell.transpose(0, 1).detach()
}
class ActorCritic(nn.Module):
@ -299,8 +306,8 @@ class DataParallelNet(nn.Module):
super().__init__()
self.net = nn.DataParallel(net)
def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any,
def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any,
**kwargs: Any) -> Tuple[Any, Any]:
if not isinstance(s, torch.Tensor):
s = torch.as_tensor(s, dtype=torch.float32)
return self.net(s=s.cuda(), *args, **kwargs)
if not isinstance(obs, torch.Tensor):
obs = torch.as_tensor(obs, dtype=torch.float32)
return self.net(obs=obs.cuda(), *args, **kwargs)
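The ``Recurrent`` hunk above renames the LSTM state keys to ``"hidden"``/``"cell"`` and keeps them batch-first between calls. A minimal sketch of that bookkeeping with placeholder sizes:
::

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
    obs = torch.randn(4, 1, 16)                        # [bsz, len, dim]

    out, (hidden, cell) = rnn(obs)                     # first call: no carried state
    state = {
        "hidden": hidden.transpose(0, 1).detach(),     # stored batch-first: [bsz, num_layers, 32]
        "cell": cell.transpose(0, 1).detach(),
    }

    # next call: transpose back to the [num_layers, bsz, ...] layout PyTorch expects
    out, (hidden, cell) = rnn(
        obs, (
            state["hidden"].transpose(0, 1).contiguous(),
            state["cell"].transpose(0, 1).contiguous(),
        )
    )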

View File

@ -58,14 +58,14 @@ class Actor(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> logits -> action."""
logits, h = self.preprocess(s, state)
"""Mapping: obs -> logits -> action."""
logits, hidden = self.preprocess(obs, state)
logits = self._max * torch.tanh(self.last(logits))
return logits, h
return logits, hidden
class Critic(nn.Module):
@ -110,24 +110,24 @@ class Critic(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
obs: Union[np.ndarray, torch.Tensor],
act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
s = torch.as_tensor(
s,
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
).flatten(1)
if a is not None:
a = torch.as_tensor(
a,
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
dtype=torch.float32,
).flatten(1)
s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s)
obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs)
logits = self.last(logits)
return logits
@ -196,12 +196,12 @@ class ActorProb(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
"""Mapping: s -> logits -> (mu, sigma)."""
logits, h = self.preprocess(s, state)
"""Mapping: obs -> logits -> (mu, sigma)."""
logits, hidden = self.preprocess(obs, state)
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
@ -252,30 +252,34 @@ class RecurrentActorProb(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase.
if len(s.shape) == 2:
s = s.unsqueeze(-2)
if len(obs.shape) == 2:
obs = obs.unsqueeze(-2)
self.nn.flatten_parameters()
if state is None:
s, (h, c) = self.nn(s)
obs, (hidden, cell) = self.nn(obs)
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn(
s, (
state["h"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous()
obs, (hidden, cell) = self.nn(
obs, (
state["hidden"].transpose(0, 1).contiguous(),
state["cell"].transpose(0, 1).contiguous()
)
)
logits = s[:, -1]
logits = obs[:, -1]
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
@ -287,8 +291,8 @@ class RecurrentActorProb(nn.Module):
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), {
"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()
"hidden": hidden.transpose(0, 1).detach(),
"cell": cell.transpose(0, 1).detach()
}
@ -321,28 +325,32 @@ class RecurrentCritic(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
obs: Union[np.ndarray, torch.Tensor],
act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase.
assert len(s.shape) == 3
assert len(obs.shape) == 3
self.nn.flatten_parameters()
s, (h, c) = self.nn(s)
s = s[:, -1]
if a is not None:
a = torch.as_tensor(
a,
obs, (hidden, cell) = self.nn(obs)
obs = obs[:, -1]
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
dtype=torch.float32,
)
s = torch.cat([s, a], dim=1)
s = self.fc2(s)
return s
obs = torch.cat([obs, act], dim=1)
obs = self.fc2(obs)
return obs
class Perturbation(nn.Module):
@ -381,9 +389,9 @@ class Perturbation(nn.Module):
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
# preprocess_net
logits = self.preprocess_net(torch.cat([state, action], -1))[0]
a = self.phi * self.max_action * torch.tanh(logits)
noise = self.phi * self.max_action * torch.tanh(logits)
# clip to [-max_action, max_action]
return (a + action).clamp(-self.max_action, self.max_action)
return (noise + action).clamp(-self.max_action, self.max_action)
class VAE(nn.Module):
@ -434,31 +442,32 @@ class VAE(nn.Module):
self, state: torch.Tensor, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [state, action] -> z , [state, z] -> action
z = self.encoder(torch.cat([state, action], -1))
latent_z = self.encoder(torch.cat([state, action], -1))
# shape of z: (state.shape[:-1], hidden_dim)
mean = self.mean(z)
mean = self.mean(latent_z)
# Clamped for numerical stability
log_std = self.log_std(z).clamp(-4, 15)
log_std = self.log_std(latent_z).clamp(-4, 15)
std = torch.exp(log_std)
# shape of mean, std: (state.shape[:-1], latent_dim)
z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
u = self.decode(state, z) # (state.shape[:-1], action_dim)
return u, mean, std
reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim)
return reconstruction, mean, std
def decode(
self,
state: torch.Tensor,
z: Union[torch.Tensor, None] = None
latent_z: Union[torch.Tensor, None] = None
) -> torch.Tensor:
# decode(state) -> action
if z is None:
if latent_z is None:
# state.shape[0] may be batch_size
# latent vector clipped to [-0.5, 0.5]
z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
.to(self.device).clamp(-0.5, 0.5)
# decode z with state!
return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))
return self.max_action * \
torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
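The ``VAE`` hunk above renames the latent sample to ``latent_z``; the sample itself comes from the reparameterization trick. A minimal sketch with placeholder layer sizes, where the three linear heads are stand-ins for the real encoder:
::

    import torch
    import torch.nn as nn

    state_dim, action_dim, hidden_dim, latent_dim = 8, 2, 32, 4
    encoder = nn.Linear(state_dim + action_dim, hidden_dim)
    mean_head = nn.Linear(hidden_dim, latent_dim)
    log_std_head = nn.Linear(hidden_dim, latent_dim)

    state = torch.randn(16, state_dim)
    action = torch.randn(16, action_dim)

    latent_z = encoder(torch.cat([state, action], -1))
    mean = mean_head(latent_z)
    log_std = log_std_head(latent_z).clamp(-4, 15)   # clamped for numerical stability
    std = torch.exp(log_std)
    latent_z = mean + std * torch.randn_like(std)    # reparameterization trick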

View File

@ -59,16 +59,16 @@ class Actor(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state)
logits, hidden = self.preprocess(obs, state)
logits = self.last(logits)
if self.softmax_output:
logits = F.softmax(logits, dim=-1)
return logits, h
return logits, hidden
class Critic(nn.Module):
@ -114,10 +114,10 @@ class Critic(nn.Module):
)
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
self, obs: Union[np.ndarray, torch.Tensor], **kwargs: Any
) -> torch.Tensor:
"""Mapping: s -> V(s)."""
logits, _ = self.preprocess(s, state=kwargs.get("state", None))
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits)
@ -199,10 +199,10 @@ class ImplicitQuantileNetwork(Critic):
).to(device)
def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
self, obs: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
) -> Tuple[Any, torch.Tensor]:
r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None))
logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
# Sample fractions.
batch_size = logits.size(0)
taus = torch.rand(
@ -211,7 +211,7 @@ class ImplicitQuantileNetwork(Critic):
embedding = (logits.unsqueeze(1) *
self.embed_model(taus)).view(batch_size * sample_size, -1)
out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
return (out, taus), h
return (out, taus), hidden
class FractionProposalNetwork(nn.Module):
@ -235,17 +235,17 @@ class FractionProposalNetwork(nn.Module):
self.embedding_dim = embedding_dim
def forward(
self, state_embeddings: torch.Tensor
self, obs_embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Calculate (log of) probabilities q_i in the paper.
m = torch.distributions.Categorical(logits=self.net(state_embeddings))
taus_1_N = torch.cumsum(m.probs, dim=1)
dist = torch.distributions.Categorical(logits=self.net(obs_embeddings))
taus_1_N = torch.cumsum(dist.probs, dim=1)
# Calculate \tau_i (i=0,...,N).
taus = F.pad(taus_1_N, (1, 0))
# Calculate \hat \tau_i (i=0,...,N-1).
tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
# Calculate entropies of value distributions.
entropies = m.entropy()
entropies = dist.entropy()
return taus, tau_hats, entropies
@ -294,13 +294,13 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
return quantiles
def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor],
self, obs: Union[np.ndarray, torch.Tensor],
propose_model: FractionProposalNetwork,
fractions: Optional[Batch] = None,
**kwargs: Any
) -> Tuple[Any, torch.Tensor]:
r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None))
logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
# Propose fractions
if fractions is None:
taus, tau_hats, entropies = propose_model(logits.detach())
@ -313,7 +313,7 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
if self.training:
with torch.no_grad():
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
return (quantiles, fractions, quantiles_tau), h
return (quantiles, fractions, quantiles_tau), hidden
class NoisyLinear(nn.Module):
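The ``FractionProposalNetwork`` hunk above turns observation embeddings into the fractions used by FQF. A minimal sketch with placeholder sizes showing how the fractions and their midpoints are derived:
::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    num_fractions, embedding_dim = 32, 64
    net = nn.Linear(embedding_dim, num_fractions)
    obs_embeddings = torch.randn(8, embedding_dim)

    dist = torch.distributions.Categorical(logits=net(obs_embeddings))
    taus_1_N = torch.cumsum(dist.probs, dim=1)        # tau_1 ... tau_N, ending at 1
    taus = F.pad(taus_1_N, (1, 0))                    # prepend tau_0 = 0
    tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
    entropies = dist.entropy()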

View File

@ -31,20 +31,20 @@ class MovAvg(object):
self.banned = [np.inf, np.nan, -np.inf]
def add(
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor]
) -> float:
"""Add a scalar into :class:`MovAvg`.
You can add a ``torch.Tensor`` with only one element, a python scalar, or
a list of python scalars.
"""
if isinstance(x, torch.Tensor):
x = x.flatten().cpu().numpy()
if np.isscalar(x):
x = [x]
for i in x: # type: ignore
if i not in self.banned:
self.cache.append(i)
if isinstance(data_array, torch.Tensor):
data_array = data_array.flatten().cpu().numpy()
if np.isscalar(data_array):
data_array = [data_array]
for number in data_array: # type: ignore
if number not in self.banned:
self.cache.append(number)
if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size:]
return self.get()
@ -80,10 +80,10 @@ class RunningMeanStd(object):
self.mean, self.var = mean, std
self.count = 0
def update(self, x: np.ndarray) -> None:
def update(self, data_array: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
batch_count = len(x)
batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
batch_count = len(data_array)
delta = batch_mean - self.mean
total_count = self.count + batch_count
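The ``update`` hunk above stops at ``total_count``; the combination step that follows is not shown. For reference, a standalone sketch using the standard parallel mean/variance formula, which is an assumption about how the update completes rather than a copy of Tianshou's code:
::

    import numpy as np

    def update_mean_var(mean, var, count, data_array):
        """Fold a batch into running mean/var/count (parallel-variance formula)."""
        batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
        batch_count = len(data_array)
        delta = batch_mean - mean
        total_count = count + batch_count
        new_mean = mean + delta * batch_count / total_count
        m_a = var * count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + delta ** 2 * count * batch_count / total_count
        return new_mean, m_2 / total_count, total_count

    # usage sketch
    mean, var, count = np.zeros(3), np.ones(3), 0
    mean, var, count = update_mean_var(mean, var, count, np.random.randn(64, 3))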