Formalize variable names (#509)

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
ChenDRAG 2022-01-30 00:53:56 +08:00 committed by GitHub
parent bc53ead273
commit c25926dd8f
42 changed files with 607 additions and 581 deletions

View File

@ -350,7 +350,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
def reset(): def reset():
return copy.deepcopy(self.graph) return copy.deepcopy(self.graph)
def step(a): def step(action):
... ...
return copy.deepcopy(self.graph), reward, done, {} return copy.deepcopy(self.graph), reward, done, {}
@ -391,13 +391,13 @@ In addition, legal actions in multi-agent RL often vary with timestep (just like
The above description gives rise to the following formulation of multi-agent RL: The above description gives rise to the following formulation of multi-agent RL:
:: ::
action = policy(state, agent_id, mask) act = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(action) (next_state, next_agent_id, next_mask), reward = env.step(act)
By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL: By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL:
:: ::
action = policy(state_) act = policy(state_)
next_state_, reward = env.step(action) next_state_, reward = env.step(act)
Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
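Purely as an illustration of this composite-state idea (the environment interface below is assumed, not a Tianshou API)::

    # single-agent view over the wrapped state ``state_ = (state, agent_id, mask)``
    (state, agent_id, mask), done = env.reset(), False
    while not done:
        state_ = (state, agent_id, mask)
        act = policy(state_)  # the mask restricts the choice to legal actions
        (state, agent_id, mask), rew, done, info = env.step(act)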

View File

@ -298,14 +298,14 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is
:: ::
# pseudocode, cannot work # pseudocode, cannot work
s = env.reset() obs = env.reset()
buffer = Buffer(size=10000) buffer = Buffer(size=10000)
agent = DQN() agent = DQN()
for i in range(int(1e6)): for i in range(int(1e6)):
a = agent.compute_action(s) act = agent.compute_action(obs)
s_, r, d, _ = env.step(a) obs_next, rew, done, _ = env.step(act)
buffer.store(s, a, s_, r, d) buffer.store(obs, act, obs_next, rew, done)
s = s_ obs = obs_next
if i % 1000 == 0: if i % 1000 == 0:
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
# compute 2-step returns. How? # compute 2-step returns. How?
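As a reference answer to the "How?" above: the 2-step target follows the usual n-step form, truncated at episode ends. A minimal sketch, assuming ``q_next`` holds precomputed ``max_a Q_target(obs[t+2], a)`` values::

    import numpy as np

    def two_step_target(rew, done, q_next, gamma=0.99):
        """G_t = rew_t + gamma * rew_{t+1} + gamma^2 * max_a Q_target(obs_{t+2}, a)."""
        rew = np.asarray(rew, dtype=float)
        not_done = 1.0 - np.asarray(done, dtype=float)
        # each later term is cut off once an episode terminates
        return rew[:-2] + gamma * not_done[:-2] * (
            rew[1:-1] + gamma * not_done[1:-1] * np.asarray(q_next, dtype=float)
        )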
@ -390,14 +390,14 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr
:: ::
# pseudocode, cannot work # methods in tianshou # pseudocode, cannot work # methods in tianshou
s = env.reset() obs = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000) buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...) agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # act = policy(batch, ...).act act = agent.compute_action(obs) # act = policy(batch, ...).act
s_, r, d, _ = env.step(a) # collector.collect(...) obs_next, rew, done, _ = env.step(act) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...) buffer.store(obs, act, obs_next, rew, done) # collector.collect(...)
s = s_ # collector.collect(...) obs = obs_next # collector.collect(...)
if i % 1000 == 0: # done in trainer if i % 1000 == 0: # done in trainer
# the following is done in policy.update(batch_size, buffer) # the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size) b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size)
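Assembled from the comments on the right-hand side, the same loop in actual Tianshou calls looks roughly like this (environment and policy construction omitted; a sketch, not a tested script)::

    from tianshou.data import Collector, ReplayBuffer

    buffer = ReplayBuffer(size=10000)
    collector = Collector(policy, env, buffer)  # wraps env.step + buffer storage
    for i in range(int(1e6)):                   # normally driven by a trainer
        collector.collect(n_step=1)             # act = policy(batch).act; step; store
        if i % 1000 == 0:
            policy.update(64, buffer)           # sample, compute returns, update network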

View File

@ -42,13 +42,13 @@ class DQN(nn.Module):
def forward( def forward(
self, self,
x: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None, state: Optional[Any] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*).""" r"""Mapping: s -> Q(s, \*)."""
x = torch.as_tensor(x, device=self.device, dtype=torch.float32) obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
return self.net(x), state return self.net(obs), state
class C51(DQN): class C51(DQN):
@ -73,15 +73,15 @@ class C51(DQN):
def forward( def forward(
self, self,
x: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None, state: Optional[Any] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*).""" r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x) obs, state = super().forward(obs)
x = x.view(-1, self.num_atoms).softmax(dim=-1) obs = obs.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, self.action_num, self.num_atoms) obs = obs.view(-1, self.action_num, self.num_atoms)
return x, state return obs, state
class Rainbow(DQN): class Rainbow(DQN):
@ -127,22 +127,22 @@ class Rainbow(DQN):
def forward( def forward(
self, self,
x: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None, state: Optional[Any] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*).""" r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x) obs, state = super().forward(obs)
q = self.Q(x) q = self.Q(obs)
q = q.view(-1, self.action_num, self.num_atoms) q = q.view(-1, self.action_num, self.num_atoms)
if self._is_dueling: if self._is_dueling:
v = self.V(x) v = self.V(obs)
v = v.view(-1, 1, self.num_atoms) v = v.view(-1, 1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v logits = q - q.mean(dim=1, keepdim=True) + v
else: else:
logits = q logits = q
y = logits.softmax(dim=2) probs = logits.softmax(dim=2)
return y, state return probs, state
class QRDQN(DQN): class QRDQN(DQN):
@ -168,11 +168,11 @@ class QRDQN(DQN):
def forward( def forward(
self, self,
x: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None, state: Optional[Any] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*).""" r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x) obs, state = super().forward(obs)
x = x.view(-1, self.action_num, self.num_quantiles) obs = obs.view(-1, self.action_num, self.num_quantiles)
return x, state return obs, state
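After the rename the call pattern itself is unchanged; a hedged sketch, assuming ``net`` is one of the networks above and ``obs`` is a batched Atari observation::

    # obs: np.ndarray or torch.Tensor of shape (batch, c, h, w)
    q_values, state = net(obs)            # DQN head: (batch, action_num)
    greedy_act = q_values.argmax(dim=-1)
    # C51 / Rainbow return (batch, action_num, num_atoms) probabilities,
    # QRDQN returns (batch, action_num, num_quantiles) quantile values.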

View File

@ -56,16 +56,16 @@ class Wrapper(gym.Wrapper):
self.rm_done = rm_done self.rm_done = rm_done
def step(self, action): def step(self, action):
r = 0.0 rew_sum = 0.0
for _ in range(self.action_repeat): for _ in range(self.action_repeat):
obs, reward, done, info = self.env.step(action) obs, rew, done, info = self.env.step(action)
# remove done reward penalty # remove done reward penalty
if not done or not self.rm_done: if not done or not self.rm_done:
r = r + reward rew_sum = rew_sum + rew
if done: if done:
break break
# scale reward # scale reward
return obs, self.reward_scale * r, done, info return obs, self.reward_scale * rew_sum, done, info
def test_sac_bipedal(args=get_args()): def test_sac_bipedal(args=get_args()):
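A hypothetical usage sketch of the wrapper above; the constructor signature and environment id are inferred from the attributes used and are not guaranteed to match the script::

    import gym

    # assumed signature: Wrapper(env, action_repeat, reward_scale, rm_done)
    env = Wrapper(gym.make("BipedalWalkerHardcore-v3"),
                  action_repeat=3, reward_scale=5, rm_done=True)
    obs = env.reset()
    obs, rew, done, info = env.step(env.action_space.sample())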

View File

@ -86,10 +86,10 @@ class MyTestEnv(gym.Env):
def _get_reward(self): def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True.""" """Generate a non-scalar reward if ma_rew is True."""
x = int(self.done) end_flag = int(self.done)
if self.ma_rew > 0: if self.ma_rew > 0:
return [x] * self.ma_rew return [end_flag] * self.ma_rew
return x return end_flag
def _get_state(self): def _get_state(self):
"""Generate state(observation) of MyTestEnv""" """Generate state(observation) of MyTestEnv"""

View File

@ -32,10 +32,12 @@ def test_replaybuffer(size=10, bufsize=20):
assert str(buf) == buf.__class__.__name__ + '()' assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset() obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10 action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list): for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(a) obs_next, rew, done, info = env.step(act)
buf.add( buf.add(
Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info) Batch(
obs=obs, act=[act], rew=rew, done=done, obs_next=obs_next, info=info
)
) )
obs = obs_next obs = obs_next
assert len(buf) == min(bufsize, i + 1) assert len(buf) == min(bufsize, i + 1)
@ -220,11 +222,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs = env.reset() obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10 action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list): for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(a) obs_next, rew, done, info = env.step(act)
batch = Batch( batch = Batch(
obs=obs, obs=obs,
act=a, act=act,
rew=rew, rew=rew,
done=done, done=done,
obs_next=obs_next, obs_next=obs_next,

View File

@ -331,20 +331,20 @@ def test_collector_with_ma():
policy = MyPolicy() policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
# n_step=3 will collect a full episode # n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rews'] rew = c0.collect(n_step=3)['rews']
assert len(r) == 0 assert len(rew) == 0
r = c0.collect(n_episode=2)['rews'] rew = c0.collect(n_episode=2)['rews']
assert r.shape == (2, 4) and np.all(r == 1) assert rew.shape == (2, 4) and np.all(rew == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns) envs = DummyVectorEnv(env_fns)
c1 = Collector( c1 = Collector(
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn Logger.single_preprocess_fn
) )
r = c1.collect(n_step=12)['rews'] rew = c1.collect(n_step=12)['rews']
assert r.shape == (2, 4) and np.all(r == 1), r assert rew.shape == (2, 4) and np.all(rew == 1), rew
r = c1.collect(n_episode=8)['rews'] rew = c1.collect(n_episode=8)['rews']
assert r.shape == (8, 4) and np.all(r == 1) assert rew.shape == (8, 4) and np.all(rew == 1)
batch, _ = c1.buffer.sample(10) batch, _ = c1.buffer.sample(10)
print(batch) print(batch)
c0.buffer.update(c1.buffer) c0.buffer.update(c1.buffer)
@ -446,8 +446,8 @@ def test_collector_with_ma():
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn Logger.single_preprocess_fn
) )
r = c2.collect(n_episode=10)['rews'] rew = c2.collect(n_episode=10)['rews']
assert r.shape == (10, 4) and np.all(r == 1) assert rew.shape == (10, 4) and np.all(rew == 1)
batch, _ = c2.buffer.sample(10) batch, _ = c2.buffer.sample(10)

View File

@ -58,22 +58,22 @@ def test_async_env(size=10000, num=8, sleep=0.1):
# should be smaller # should be smaller
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4) action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
current_idx_start = 0 current_idx_start = 0
action = action_list[:num] act = action_list[:num]
env_ids = list(range(num)) env_ids = list(range(num))
o = [] o = []
spent_time = time.time() spent_time = time.time()
while current_idx_start < len(action_list): while current_idx_start < len(action_list):
A, B, C, D = v.step(action=action, id=env_ids) A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D}) b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
env_ids = b.info.env_id env_ids = b.info.env_id
o.append(b) o.append(b)
current_idx_start += len(action) current_idx_start += len(act)
# len of action may be smaller than len(A) in the end # len of action may be smaller than len(A) in the end
action = action_list[current_idx_start:current_idx_start + len(A)] act = action_list[current_idx_start:current_idx_start + len(A)]
# truncate env_ids with the first terms # truncate env_ids with the first terms
# typically len(env_ids) == len(A) == len(action), except for the # typically len(env_ids) == len(A) == len(action), except for the
# last batch when actions are not enough # last batch when actions are not enough
env_ids = env_ids[:len(action)] env_ids = env_ids[:len(act)]
spent_time = time.time() - spent_time spent_time = time.time() - spent_time
Batch.cat(o) Batch.cat(o)
v.close() v.close()

View File

@ -142,11 +142,11 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
returns = np.zeros_like(indices, dtype=float) returns = np.zeros_like(indices, dtype=float)
buf_len = len(buffer) buf_len = len(buffer)
for i in range(len(indices)): for i in range(len(indices)):
flag, r = False, 0. flag, rew = False, 0.
real_step_n = nstep real_step_n = nstep
for n in range(nstep): for n in range(nstep):
idx = (indices[i] + n) % buf_len idx = (indices[i] + n) % buf_len
r += buffer.rew[idx] * gamma**n rew += buffer.rew[idx] * gamma**n
if buffer.done[idx]: if buffer.done[idx]:
if not ( if not (
hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx] hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx]
@ -156,8 +156,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
break break
if not flag: if not flag:
idx = (indices[i] + real_step_n - 1) % buf_len idx = (indices[i] + real_step_n - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = r returns[i] = rew
return returns return returns

View File

@ -41,7 +41,7 @@ def gomoku(args=get_args()):
return TicTacToeEnv(args.board_size, args.win_size) return TicTacToeEnv(args.board_size, args.win_size)
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for r in range(args.self_play_round): for round in range(args.self_play_round):
rews = [] rews = []
agent_learn.set_eps(0.0) agent_learn.set_eps(0.0)
# compute the reward over previous learner # compute the reward over previous learner
@ -66,12 +66,12 @@ def gomoku(args=get_args()):
# previous learner can only be used for forward # previous learner can only be used for forward
agent.forward = opponent.forward agent.forward = opponent.forward
args.model_save_path = os.path.join( args.model_save_path = os.path.join(
args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth' args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
) )
result, agent_learn = train_agent( result, agent_learn = train_agent(
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
) )
print(f'round_{r}_epoch_{epoch}') print(f'round_{round}_epoch_{epoch}')
pprint.pprint(result) pprint.pprint(result)
learnt_agent = deepcopy(agent_learn) learnt_agent = deepcopy(agent_learn)
learnt_agent.set_eps(0.0) learnt_agent.set_eps(0.0)

View File

@ -11,17 +11,18 @@ import torch
IndexType = Union[slice, int, np.ndarray, List[int]] IndexType = Union[slice, int, np.ndarray, List[int]]
def _is_batch_set(data: Any) -> bool: def _is_batch_set(obj: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects, # Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with object type, # or 1-D np.ndarray with object type,
# where each element is a dict/Batch object # where each element is a dict/Batch object
if isinstance(data, np.ndarray): # most often case if isinstance(obj, np.ndarray): # most often case
# "for e in data" will just unpack the first dimension, # "for element in obj" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects # but obj.tolist() will flatten ndarray of objects
# so do not use data.tolist() # so do not use obj.tolist()
return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data) return obj.dtype == object and \
elif isinstance(data, (list, tuple)): all(isinstance(element, (dict, Batch)) for element in obj)
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): elif isinstance(obj, (list, tuple)):
if len(obj) > 0 and all(isinstance(element, (dict, Batch)) for element in obj):
return True return True
return False return False
@ -48,28 +49,29 @@ def _is_number(value: Any) -> bool:
return isinstance(value, (Number, np.number, np.bool_)) return isinstance(value, (Number, np.number, np.bool_))
def _to_array_with_correct_type(v: Any) -> np.ndarray: def _to_array_with_correct_type(obj: Any) -> np.ndarray:
if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)): if isinstance(obj, np.ndarray) and \
return v # most often case issubclass(obj.dtype.type, (np.bool_, np.number)):
return obj # most often case
# convert the value to np.ndarray # convert the value to np.ndarray
# convert to object data type if neither bool nor number # convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves # raises an exception if array's elements are tensors themselves
v = np.asanyarray(v) obj_array = np.asanyarray(obj)
if not issubclass(v.dtype.type, (np.bool_, np.number)): if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
v = v.astype(object) obj_array = obj_array.astype(object)
if v.dtype == object: if obj_array.dtype == object:
# scalar ndarray with object data type is very annoying # scalar ndarray with object obj type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in # a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange: # something very strange:
# array([{}, array({}, dtype=object)], dtype=object) # array([{}, array({}, dtype=object)], dtype=object)
if not v.shape: if not obj_array.shape:
v = v.item(0) obj_array = obj_array.item(0)
elif all(isinstance(e, np.ndarray) for e in v.reshape(-1)): elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
return v # various length, np.array([[1], [2, 3], [4, 5, 6]]) return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]])
elif any(isinstance(e, torch.Tensor) for e in v.reshape(-1)): elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.") raise ValueError("Numpy arrays of tensors are not supported yet.")
return v return obj_array
def _create_value( def _create_value(
@ -113,44 +115,45 @@ def _create_value(
def _assert_type_keys(keys: Iterable[str]) -> None: def _assert_type_keys(keys: Iterable[str]) -> None:
assert all(isinstance(e, str) for e in keys), \ assert all(isinstance(key, str) for key in keys), \
f"keys should all be string, but got {keys}" f"keys should all be string, but got {keys}"
def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if isinstance(v, Batch): # most often case if isinstance(obj, Batch): # most often case
return v return obj
elif (isinstance(v, np.ndarray) and elif (isinstance(obj, np.ndarray) and
issubclass(v.dtype.type, (np.bool_, np.number))) or \ issubclass(obj.dtype.type, (np.bool_, np.number))) or \
isinstance(v, torch.Tensor) or v is None: # third often case isinstance(obj, torch.Tensor) or obj is None: # third often case
return v return obj
elif _is_number(v): # second often case, but it is more time-consuming elif _is_number(obj): # second often case, but it is more time-consuming
return np.asanyarray(v) return np.asanyarray(obj)
elif isinstance(v, dict): elif isinstance(obj, dict):
return Batch(v) return Batch(obj)
else: else:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ if not isinstance(obj, np.ndarray) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): isinstance(obj, Collection) and len(obj) > 0 and \
all(isinstance(element, torch.Tensor) for element in obj):
try: try:
return torch.stack(v) # type: ignore return torch.stack(obj) # type: ignore
except RuntimeError as e: except RuntimeError as exception:
raise TypeError( raise TypeError(
"Batch does not support non-stackable iterable" "Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet." " of torch.Tensor as unique value yet."
) from e ) from exception
if _is_batch_set(v): if _is_batch_set(obj):
v = Batch(v) # list of dict / Batch obj = Batch(obj) # list of dict / Batch
else: else:
# None, scalar, normal data list (main case) # None, scalar, normal obj list (main case)
# or an actual list of objects # or an actual list of objects
try: try:
v = _to_array_with_correct_type(v) obj = _to_array_with_correct_type(obj)
except ValueError as e: except ValueError as exception:
raise TypeError( raise TypeError(
"Batch does not support heterogeneous list/" "Batch does not support heterogeneous list/"
"tuple of tensors as unique value yet." "tuple of tensors as unique value yet."
) from e ) from exception
return v return obj
def _alloc_by_keys_diff( def _alloc_by_keys_diff(
@ -189,8 +192,8 @@ class Batch:
if batch_dict is not None: if batch_dict is not None:
if isinstance(batch_dict, (dict, Batch)): if isinstance(batch_dict, (dict, Batch)):
_assert_type_keys(batch_dict.keys()) _assert_type_keys(batch_dict.keys())
for k, v in batch_dict.items(): for batch_key, obj in batch_dict.items():
self.__dict__[k] = _parse_value(v) self.__dict__[batch_key] = _parse_value(obj)
elif _is_batch_set(batch_dict): elif _is_batch_set(batch_dict):
self.stack_(batch_dict) # type: ignore self.stack_(batch_dict) # type: ignore
if len(kwargs) > 0: if len(kwargs) > 0:
@ -214,10 +217,10 @@ class Batch:
Only the actual data are serialized for both efficiency and simplicity. Only the actual data are serialized for both efficiency and simplicity.
""" """
state = {} state = {}
for k, v in self.items(): for batch_key, obj in self.items():
if isinstance(v, Batch): if isinstance(obj, Batch):
v = v.__getstate__() obj = obj.__getstate__()
state[k] = v state[batch_key] = obj
return state return state
def __setstate__(self, state: Dict[str, Any]) -> None: def __setstate__(self, state: Dict[str, Any]) -> None:
@ -234,13 +237,13 @@ class Batch:
return self.__dict__[index] return self.__dict__[index]
batch_items = self.items() batch_items = self.items()
if len(batch_items) > 0: if len(batch_items) > 0:
b = Batch() new_batch = Batch()
for k, v in batch_items: for batch_key, obj in batch_items:
if isinstance(v, Batch) and v.is_empty(): if isinstance(obj, Batch) and obj.is_empty():
b.__dict__[k] = Batch() new_batch.__dict__[batch_key] = Batch()
else: else:
b.__dict__[k] = v[index] new_batch.__dict__[batch_key] = obj[index]
return b return new_batch
else: else:
raise IndexError("Cannot access item from empty Batch object.") raise IndexError("Cannot access item from empty Batch object.")
@ -273,20 +276,20 @@ class Batch:
def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch": def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
"""Algebraic addition with another Batch instance in-place.""" """Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch): if isinstance(other, Batch):
for (k, r), v in zip( for (batch_key, obj), value in zip(
self.__dict__.items(), other.__dict__.values() self.__dict__.items(), other.__dict__.values()
): # TODO are keys consistent? ): # TODO are keys consistent?
if isinstance(r, Batch) and r.is_empty(): if isinstance(obj, Batch) and obj.is_empty():
continue continue
else: else:
self.__dict__[k] += v self.__dict__[batch_key] += value
return self return self
elif _is_number(other): elif _is_number(other):
for k, r in self.items(): for batch_key, obj in self.items():
if isinstance(r, Batch) and r.is_empty(): if isinstance(obj, Batch) and obj.is_empty():
continue continue
else: else:
self.__dict__[k] += other self.__dict__[batch_key] += other
return self return self
else: else:
raise TypeError("Only addition of Batch or number is supported.") raise TypeError("Only addition of Batch or number is supported.")
@ -295,54 +298,54 @@ class Batch:
"""Algebraic addition with another Batch instance out-of-place.""" """Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other) return deepcopy(self).__iadd__(other)
def __imul__(self, val: Union[Number, np.number]) -> "Batch": def __imul__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value in-place.""" """Algebraic multiplication with a scalar value in-place."""
assert _is_number(val), "Only multiplication by a number is supported." assert _is_number(value), "Only multiplication by a number is supported."
for k, r in self.__dict__.items(): for batch_key, obj in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty(): if isinstance(obj, Batch) and obj.is_empty():
continue continue
self.__dict__[k] *= val self.__dict__[batch_key] *= value
return self return self
def __mul__(self, val: Union[Number, np.number]) -> "Batch": def __mul__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value out-of-place.""" """Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val) return deepcopy(self).__imul__(value)
def __itruediv__(self, val: Union[Number, np.number]) -> "Batch": def __itruediv__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value in-place.""" """Algebraic division with a scalar value in-place."""
assert _is_number(val), "Only division by a number is supported." assert _is_number(value), "Only division by a number is supported."
for k, r in self.__dict__.items(): for batch_key, obj in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty(): if isinstance(obj, Batch) and obj.is_empty():
continue continue
self.__dict__[k] /= val self.__dict__[batch_key] /= value
return self return self
def __truediv__(self, val: Union[Number, np.number]) -> "Batch": def __truediv__(self, value: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value out-of-place.""" """Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val) return deepcopy(self).__itruediv__(value)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return str(self).""" """Return str(self)."""
s = self.__class__.__name__ + "(\n" self_str = self.__class__.__name__ + "(\n"
flag = False flag = False
for k, v in self.__dict__.items(): for batch_key, obj in self.__dict__.items():
rpl = "\n" + " " * (6 + len(k)) rpl = "\n" + " " * (6 + len(batch_key))
obj = pprint.pformat(v).replace("\n", rpl) obj_name = pprint.pformat(obj).replace("\n", rpl)
s += f" {k}: {obj},\n" self_str += f" {batch_key}: {obj_name},\n"
flag = True flag = True
if flag: if flag:
s += ")" self_str += ")"
else: else:
s = self.__class__.__name__ + "()" self_str = self.__class__.__name__ + "()"
return s return self_str
def to_numpy(self) -> None: def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place.""" """Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items(): for batch_key, obj in self.items():
if isinstance(v, torch.Tensor): if isinstance(obj, torch.Tensor):
self.__dict__[k] = v.detach().cpu().numpy() self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(v, Batch): elif isinstance(obj, Batch):
v.to_numpy() obj.to_numpy()
def to_torch( def to_torch(
self, self,
@ -353,24 +356,24 @@ class Batch:
if not isinstance(device, torch.device): if not isinstance(device, torch.device):
device = torch.device(device) device = torch.device(device)
for k, v in self.items(): for batch_key, obj in self.items():
if isinstance(v, torch.Tensor): if isinstance(obj, torch.Tensor):
if dtype is not None and v.dtype != dtype or \ if dtype is not None and obj.dtype != dtype or \
v.device.type != device.type or \ obj.device.type != device.type or \
device.index != v.device.index: device.index != obj.device.index:
if dtype is not None: if dtype is not None:
v = v.type(dtype) obj = obj.type(dtype)
self.__dict__[k] = v.to(device) self.__dict__[batch_key] = obj.to(device)
elif isinstance(v, Batch): elif isinstance(obj, Batch):
v.to_torch(dtype, device) obj.to_torch(dtype, device)
else: else:
# ndarray or scalar # ndarray or scalar
if not isinstance(v, np.ndarray): if not isinstance(obj, np.ndarray):
v = np.asanyarray(v) obj = np.asanyarray(obj)
v = torch.from_numpy(v).to(device) obj = torch.from_numpy(obj).to(device)
if dtype is not None: if dtype is not None:
v = v.type(dtype) obj = obj.type(dtype)
self.__dict__[k] = v self.__dict__[batch_key] = obj
def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None: def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None:
"""Private method for Batch.cat_. """Private method for Batch.cat_.
@ -395,50 +398,51 @@ class Batch:
# partial keys will be padded by zeros # partial keys will be padded by zeros
# with the shape of [len, rest_shape] # with the shape of [len, rest_shape]
sum_lens = [0] sum_lens = [0]
for x in lens: for len_ in lens:
sum_lens.append(sum_lens[-1] + x) sum_lens.append(sum_lens[-1] + len_)
# collect non-empty keys # collect non-empty keys
keys_map = [ keys_map = [
set( set(
k for k, v in batch.items() batch_key for batch_key, obj in batch.items()
if not (isinstance(v, Batch) and v.is_empty()) if not (isinstance(obj, Batch) and obj.is_empty())
) for batch in batches ) for batch in batches
] ]
keys_shared = set.intersection(*keys_map) keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared] values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for k, v in zip(keys_shared, values_shared): for key, shared_value in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v): if all(isinstance(element, (dict, Batch)) for element in shared_value):
batch_holder = Batch() batch_holder = Batch()
batch_holder.__cat(v, lens=lens) batch_holder.__cat(shared_value, lens=lens)
self.__dict__[k] = batch_holder self.__dict__[key] = batch_holder
elif all(isinstance(e, torch.Tensor) for e in v): elif all(isinstance(element, torch.Tensor) for element in shared_value):
self.__dict__[k] = torch.cat(v) self.__dict__[key] = torch.cat(shared_value)
else: else:
# cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
# will fail here # will fail here
v = np.concatenate(v) shared_value = np.concatenate(shared_value)
self.__dict__[k] = _to_array_with_correct_type(v) self.__dict__[key] = _to_array_with_correct_type(shared_value)
keys_total = set.union(*[set(b.keys()) for b in batches]) keys_total = set.union(*[set(batch.keys()) for batch in batches])
keys_reserve_or_partial = set.difference(keys_total, keys_shared) keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that are reserved in all batches # keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map)) keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that occur only in some batches, but not all # keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve) keys_partial = keys_reserve_or_partial.difference(keys_reserve)
for k in keys_reserve: for key in keys_reserve:
# reserved keys # reserved keys
self.__dict__[k] = Batch() self.__dict__[key] = Batch()
for k in keys_partial: for key in keys_partial:
for i, e in enumerate(batches): for i, batch in enumerate(batches):
if k not in e.__dict__: if key not in batch.__dict__:
continue continue
val = e.get(k) value = batch.get(key)
if isinstance(val, Batch) and val.is_empty(): if isinstance(value, Batch) and value.is_empty():
continue continue
try: try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
except KeyError: except KeyError:
self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False) self.__dict__[key] = \
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val _create_value(value, sum_lens[-1], stack=False)
self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
"""Concatenate a list of (or one) Batch objects into current batch.""" """Concatenate a list of (or one) Batch objects into current batch."""
@ -446,16 +450,16 @@ class Batch:
batches = [batches] batches = [batches]
# check input format # check input format
batch_list = [] batch_list = []
for b in batches: for batch in batches:
if isinstance(b, dict): if isinstance(batch, dict):
if len(b) > 0: if len(batch) > 0:
batch_list.append(Batch(b)) batch_list.append(Batch(batch))
elif isinstance(b, Batch): elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored # x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty(): if not batch.is_empty():
batch_list.append(b) batch_list.append(batch)
else: else:
raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_") raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
if len(batch_list) == 0: if len(batch_list) == 0:
return return
batches = batch_list batches = batch_list
@ -463,13 +467,15 @@ class Batch:
# x.is_empty(recurse=True) here means x is a nested empty batch # x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and # like Batch(a=Batch), and we have to treat it as length zero and
# keep it. # keep it.
lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches] lens = [
except TypeError as e: 0 if batch.is_empty(recurse=True) else len(batch) for batch in batches
]
except TypeError as exception:
raise ValueError( raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any " "Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the " f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar." "concatenation of scalar."
) from e ) from exception
if not self.is_empty(): if not self.is_empty():
batches = [self] + list(batches) batches = [self] + list(batches)
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
@ -501,16 +507,16 @@ class Batch:
"""Stack a list of Batch object into current batch.""" """Stack a list of Batch object into current batch."""
# check input format # check input format
batch_list = [] batch_list = []
for b in batches: for batch in batches:
if isinstance(b, dict): if isinstance(batch, dict):
if len(b) > 0: if len(batch) > 0:
batch_list.append(Batch(b)) batch_list.append(Batch(batch))
elif isinstance(b, Batch): elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored # x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty(): if not batch.is_empty():
batch_list.append(b) batch_list.append(batch)
else: else:
raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_") raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0: if len(batch_list) == 0:
return return
batches = batch_list batches = batch_list
@ -519,28 +525,31 @@ class Batch:
# collect non-empty keys # collect non-empty keys
keys_map = [ keys_map = [
set( set(
k for k, v in batch.items() batch_key for batch_key, obj in batch.items()
if not (isinstance(v, Batch) and v.is_empty()) if not (isinstance(obj, Batch) and obj.is_empty())
) for batch in batches ) for batch in batches
] ]
keys_shared = set.intersection(*keys_map) keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared] values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for k, v in zip(keys_shared, values_shared): for shared_key, value in zip(keys_shared, values_shared):
if all(isinstance(e, torch.Tensor) for e in v): # second often # second often
self.__dict__[k] = torch.stack(v, axis) if all(isinstance(element, torch.Tensor) for element in value):
elif all(isinstance(e, (Batch, dict)) for e in v): # third often self.__dict__[shared_key] = torch.stack(value, axis)
self.__dict__[k] = Batch.stack(v, axis) # third often
elif all(isinstance(element, (Batch, dict)) for element in value):
self.__dict__[shared_key] = Batch.stack(value, axis)
else: # most often case is np.ndarray else: # most often case is np.ndarray
try: try:
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) self.__dict__[shared_key] = \
_to_array_with_correct_type(np.stack(value, axis))
except ValueError: except ValueError:
warnings.warn( warnings.warn(
"You are using tensors with different shape," "You are using tensors with different shape,"
" fallback to dtype=object by default." " fallback to dtype=object by default."
) )
self.__dict__[k] = np.array(v, dtype=object) self.__dict__[shared_key] = np.array(value, dtype=object)
# all the keys # all the keys
keys_total = set.union(*[set(b.keys()) for b in batches]) keys_total = set.union(*[set(batch.keys()) for batch in batches])
# keys that are reserved in all batches # keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map)) keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that are either partial or reserved # keys that are either partial or reserved
@ -552,21 +561,21 @@ class Batch:
f"Stack of Batch with non-shared keys {keys_partial} is only " f"Stack of Batch with non-shared keys {keys_partial} is only "
f"supported with axis=0, but got axis={axis}!" f"supported with axis=0, but got axis={axis}!"
) )
for k in keys_reserve: for key in keys_reserve:
# reserved keys # reserved keys
self.__dict__[k] = Batch() self.__dict__[key] = Batch()
for k in keys_partial: for key in keys_partial:
for i, e in enumerate(batches): for i, batch in enumerate(batches):
if k not in e.__dict__: if key not in batch.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
continue continue
value = batch.get(key)
if isinstance(value, Batch) and value.is_empty(): # type: ignore
continue # type: ignore
try: try:
self.__dict__[k][i] = val self.__dict__[key][i] = value
except KeyError: except KeyError:
self.__dict__[k] = _create_value(val, len(batches)) self.__dict__[key] = _create_value(value, len(batches))
self.__dict__[k][i] = val self.__dict__[key][i] = value
@staticmethod @staticmethod
def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch": def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch":
@ -620,27 +629,27 @@ class Batch:
), ),
) )
""" """
for k, v in self.items(): for batch_key, obj in self.items():
if isinstance(v, torch.Tensor): # most often case if isinstance(obj, torch.Tensor): # most often case
self.__dict__[k][index] = 0 self.__dict__[batch_key][index] = 0
elif v is None: elif obj is None:
continue continue
elif isinstance(v, np.ndarray): elif isinstance(obj, np.ndarray):
if v.dtype == object: if obj.dtype == object:
self.__dict__[k][index] = None self.__dict__[batch_key][index] = None
else: else:
self.__dict__[k][index] = 0 self.__dict__[batch_key][index] = 0
elif isinstance(v, Batch): elif isinstance(obj, Batch):
self.__dict__[k].empty_(index=index) self.__dict__[batch_key].empty_(index=index)
else: # scalar value else: # scalar value
warnings.warn( warnings.warn(
"You are calling Batch.empty on a NumPy scalar, " "You are calling Batch.empty on a NumPy scalar, "
"which may cause undefined behaviors." "which may cause undefined behaviors."
) )
if _is_number(v): if _is_number(obj):
self.__dict__[k] = v.__class__(0) self.__dict__[batch_key] = obj.__class__(0)
else: else:
self.__dict__[k] = None self.__dict__[batch_key] = None
return self return self
@staticmethod @staticmethod
@ -658,26 +667,26 @@ class Batch:
if batch is None: if batch is None:
self.update(kwargs) self.update(kwargs)
return return
for k, v in batch.items(): for batch_key, obj in batch.items():
self.__dict__[k] = _parse_value(v) self.__dict__[batch_key] = _parse_value(obj)
if kwargs: if kwargs:
self.update(kwargs) self.update(kwargs)
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self).""" """Return len(self)."""
r = [] lens = []
for v in self.__dict__.values(): for obj in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True): if isinstance(obj, Batch) and obj.is_empty(recurse=True):
continue continue
elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0): elif hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
r.append(len(v)) lens.append(len(obj))
else: else:
raise TypeError(f"Object {v} in {self} has no len()") raise TypeError(f"Object {obj} in {self} has no len()")
if len(r) == 0: if len(lens) == 0:
# empty batch has the shape of any, like the tensorflow '?' shape. # empty batch has the shape of any, like the tensorflow '?' shape.
# So it has no length. # So it has no length.
raise TypeError(f"Object {self} has no len()") raise TypeError(f"Object {self} has no len()")
return min(r) return min(lens)
def is_empty(self, recurse: bool = False) -> bool: def is_empty(self, recurse: bool = False) -> bool:
"""Test if a Batch is empty. """Test if a Batch is empty.
@ -710,8 +719,8 @@ class Batch:
if not recurse: if not recurse:
return False return False
return all( return all(
False if not isinstance(x, Batch) else x.is_empty(recurse=True) False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
for x in self.values() for obj in self.values()
) )
@property @property
@ -721,9 +730,9 @@ class Batch:
return [] return []
else: else:
data_shape = [] data_shape = []
for v in self.__dict__.values(): for obj in self.__dict__.values():
try: try:
data_shape.append(list(v.shape)) data_shape.append(list(obj.shape))
except AttributeError: except AttributeError:
data_shape.append([]) data_shape.append([])
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
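For orientation, the ``len``/``shape`` semantics touched above are "the minimum over all stored entries"; a small self-contained example::

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.zeros((4, 84)), act=np.zeros(4))
    assert len(b) == 4      # min over the leading dimensions
    assert b.shape == [4]   # element-wise min over all entry shapes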

View File

@ -69,8 +69,8 @@ class ReplayBuffer:
"""Return self.key.""" """Return self.key."""
try: try:
return self._meta[key] return self._meta[key]
except KeyError as e: except KeyError as exception:
raise AttributeError from e raise AttributeError from exception
def __setstate__(self, state: Dict[str, Any]) -> None: def __setstate__(self, state: Dict[str, Any]) -> None:
"""Unpickling interface. """Unpickling interface.
@ -198,10 +198,10 @@ class ReplayBuffer:
episode_reward is 0. episode_reward is 0.
""" """
# preprocess batch # preprocess batch
b = Batch() new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()): for key in set(self._reserved_keys).intersection(batch.keys()):
b.__dict__[key] = batch[key] new_batch.__dict__[key] = batch[key]
batch = b batch = new_batch
assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
stacked_batch = buffer_ids is not None stacked_batch = buffer_ids is not None
if stacked_batch: if stacked_batch:
@ -315,9 +315,9 @@ class ReplayBuffer:
return Batch.stack(stack, axis=indices.ndim) return Batch.stack(stack, axis=indices.ndim)
else: else:
return np.stack(stack, axis=indices.ndim) return np.stack(stack, axis=indices.ndim)
except IndexError as e: except IndexError as exception:
if not (isinstance(val, Batch) and val.is_empty()): if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch() raise exception # val != Batch()
return Batch() return Batch()
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch: def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:

View File

@ -114,10 +114,10 @@ class ReplayBufferManager(ReplayBuffer):
episode_reward is 0. episode_reward is 0.
""" """
# preprocess batch # preprocess batch
b = Batch() new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()): for key in set(self._reserved_keys).intersection(batch.keys()):
b.__dict__[key] = batch[key] new_batch.__dict__[key] = batch[key]
batch = b batch = new_batch
assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
if self._save_only_last_obs: if self._save_only_last_obs:
batch.obs = batch.obs[:, -1] batch.obs = batch.obs[:, -1]

View File

@ -111,11 +111,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
# and possibly in other cases like structured arrays. # and possibly in other cases like structured arrays.
try: try:
to_hdf5_via_pickle(v, y, k) to_hdf5_via_pickle(v, y, k)
except Exception as e: except Exception as exception:
raise RuntimeError( raise RuntimeError(
f"Attempted to pickle {v.__class__.__name__} due to " f"Attempted to pickle {v.__class__.__name__} due to "
"data type not supported by HDF5 and failed." "data type not supported by HDF5 and failed."
) from e ) from exception
y[k].attrs["__data_type__"] = "pickled_ndarray" y[k].attrs["__data_type__"] = "pickled_ndarray"
elif isinstance(v, (int, float)): elif isinstance(v, (int, float)):
# ints and floats are stored as attributes of groups # ints and floats are stored as attributes of groups
@ -123,11 +123,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
else: # resort to pickle for any other type of object else: # resort to pickle for any other type of object
try: try:
to_hdf5_via_pickle(v, y, k) to_hdf5_via_pickle(v, y, k)
except Exception as e: except Exception as exception:
raise NotImplementedError( raise NotImplementedError(
f"No conversion to HDF5 for object of type '{type(v)}' " f"No conversion to HDF5 for object of type '{type(v)}' "
"implemented and fallback to pickle failed." "implemented and fallback to pickle failed."
) from e ) from exception
y[k].attrs["__data_type__"] = v.__class__.__name__ y[k].attrs["__data_type__"] = v.__class__.__name__

View File

@ -15,8 +15,8 @@ class MultiAgentEnv(ABC, gym.Env):
env = MultiAgentEnv(...) env = MultiAgentEnv(...)
# obs is a dict containing obs, agent_id, and mask # obs is a dict containing obs, agent_id, and mask
obs = env.reset() obs = env.reset()
action = policy(obs) act = policy(obs)
obs, rew, done, info = env.step(action) obs, rew, done, info = env.step(act)
env.close() env.close()
The available action's mask is set to 1, otherwise it is set to 0. Further The available action's mask is set to 1, otherwise it is set to 0. Further

View File

@ -412,10 +412,10 @@ class RayVectorEnv(BaseVectorEnv):
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
try: try:
import ray import ray
except ImportError as e: except ImportError as exception:
raise ImportError( raise ImportError(
"Please install ray to support RayVectorEnv: pip install ray" "Please install ray to support RayVectorEnv: pip install ray"
) from e ) from exception
if not ray.is_initialized(): if not ray.is_initialized():
ray.init() ray.init()
super().__init__(env_fns, RayEnvWorker, **kwargs) super().__init__(env_fns, RayEnvWorker, **kwargs)

View File

@ -98,6 +98,12 @@ class BasePolicy(ABC, nn.Module):
""" """
return act return act
def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
"""Softly update the parameters of target module towards the parameters \
of source module."""
for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
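The helper added here is the standard Polyak averaging update, :math:`\theta_{tgt} \leftarrow \tau \theta_{src} + (1 - \tau) \theta_{tgt}`; the refactored ``sync_weight`` methods further down call it as, e.g.::

    # inside a policy's sync_weight(), as in the BCQ/CQL hunks below
    self.soft_update(self.critic1_old, self.critic1, self.tau)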
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
@ -387,10 +393,10 @@ def _gae_return(
) -> np.ndarray: ) -> np.ndarray:
returns = np.zeros(rew.shape) returns = np.zeros(rew.shape)
delta = rew + v_s_ * gamma - v_s delta = rew + v_s_ * gamma - v_s
m = (1.0 - end_flag) * (gamma * gae_lambda) discount = (1.0 - end_flag) * (gamma * gae_lambda)
gae = 0.0 gae = 0.0
for i in range(len(rew) - 1, -1, -1): for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae gae = delta[i] + discount[i] * gae
returns[i] = gae returns[i] = gae
return returns return returns
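For reference, the loop above evaluates the GAE recursion backwards in time: :math:`\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)` and :math:`\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}`, where the renamed ``discount`` is the :math:`\gamma \lambda` factor zeroed at episode ends.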

View File

@ -40,23 +40,23 @@ class ImitationPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None, state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info) logits, hidden = self.model(batch.obs, state=state, info=batch.info)
if self.action_type == "discrete": if self.action_type == "discrete":
a = logits.max(dim=1)[1] act = logits.max(dim=1)[1]
else: else:
a = logits act = logits
return Batch(logits=logits, act=a, state=h) return Batch(logits=logits, act=act, state=hidden)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.optim.zero_grad() self.optim.zero_grad()
if self.action_type == "continuous": # regression if self.action_type == "continuous": # regression
a = self(batch).act act = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
loss = F.mse_loss(a, a_) # type: ignore loss = F.mse_loss(act, act_target) # type: ignore
elif self.action_type == "discrete": # classification elif self.action_type == "discrete": # classification
a = F.log_softmax(self(batch).logits, dim=-1) act = F.log_softmax(self(batch).logits, dim=-1)
a_ = to_torch(batch.act, dtype=torch.long, device=a.device) act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
loss = F.nll_loss(a, a_) # type: ignore loss = F.nll_loss(act, act_target) # type: ignore
loss.backward() loss.backward()
self.optim.step() self.optim.step()
return {"loss": loss.item()} return {"loss": loss.item()}

View File

@ -105,32 +105,27 @@ class BCQPolicy(BasePolicy):
obs_group: torch.Tensor = to_torch( # type: ignore obs_group: torch.Tensor = to_torch( # type: ignore
batch.obs, device=self.device batch.obs, device=self.device
) )
act = [] act_group = []
for obs in obs_group: for obs in obs_group:
# now obs is (state_dim) # now obs is (state_dim)
obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1) obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
# now obs is (forward_sampled_times, state_dim) # now obs is (forward_sampled_times, state_dim)
# decode(obs) generates action and actor perturbs it # decode(obs) generates action and actor perturbs it
action = self.actor(obs, self.vae.decode(obs)) act = self.actor(obs, self.vae.decode(obs))
# now action is (forward_sampled_times, action_dim) # now action is (forward_sampled_times, action_dim)
q1 = self.critic1(obs, action) q1 = self.critic1(obs, act)
# q1 is (forward_sampled_times, 1) # q1 is (forward_sampled_times, 1)
ind = q1.argmax(0) max_indice = q1.argmax(0)
act.append(action[ind].cpu().data.numpy().flatten()) act_group.append(act[max_indice].cpu().data.numpy().flatten())
act = np.array(act) act_group = np.array(act_group)
return Batch(act=act) return Batch(act=act_group)
def sync_weight(self) -> None: def sync_weight(self) -> None:
"""Soft-update the weight for the target network.""" """Soft-update the weight for the target network."""
for net, net_target in [ self.soft_update(self.critic1_target, self.critic1, self.tau)
[self.critic1, self.critic1_target], [self.critic2, self.critic2_target], self.soft_update(self.critic2_target, self.critic2, self.tau)
[self.actor, self.actor_target] self.soft_update(self.actor_target, self.actor, self.tau)
]:
for param, target_param in zip(net.parameters(), net_target.parameters()):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# batch: obs, act, rew, done, obs_next. (numpy array) # batch: obs, act, rew, done, obs_next. (numpy array)

View File

@ -113,13 +113,8 @@ class CQLPolicy(SACPolicy):
def sync_weight(self) -> None: def sync_weight(self) -> None:
"""Soft-update the weight for the target network.""" """Soft-update the weight for the target network."""
for net, net_old in [ self.soft_update(self.critic1_old, self.critic1, self.tau)
[self.critic1, self.critic1_old], [self.critic2, self.critic2_old] self.soft_update(self.critic2_old, self.critic2, self.tau)
]:
for param, target_param in zip(net.parameters(), net_old.parameters()):
target_param.data.copy_(
self._tau * param.data + (1 - self._tau) * target_param.data
)
def actor_pred(self, obs: torch.Tensor) -> \ def actor_pred(self, obs: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]: Tuple[torch.Tensor, torch.Tensor]:

View File

@ -94,13 +94,10 @@ class DiscreteBCQPolicy(DQNPolicy):
# mask actions for argmax # mask actions for argmax
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
mask = (ratio < self._log_tau).float() mask = (ratio < self._log_tau).float()
action = (q_value - np.inf * mask).argmax(dim=-1) act = (q_value - np.inf * mask).argmax(dim=-1)
return Batch( return Batch(
act=action, act=act, state=state, q_value=q_value, imitation_logits=imitation_logits
state=state,
q_value=q_value,
imitation_logits=imitation_logits
) )
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

View File

@ -57,15 +57,15 @@ class DiscreteCQLPolicy(QRDQNPolicy):
curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1) target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist # calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = ( huber_loss = (
u * (self.tau_hat - dist_diff *
(target_dist - curr_dist).detach().le(0.).float()).abs() (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1) ).sum(-1).mean(1)
qr_loss = (huber_loss * weight).mean() qr_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
# add CQL loss # add CQL loss
q = self.compute_q_value(all_dist, None) q = self.compute_q_value(all_dist, None)
dataset_expec = q.gather(1, act.unsqueeze(1)).mean() dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
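For reference, ``huber_loss`` above is the quantile-regression (quantile Huber) loss :math:`|\hat{\tau} - \mathbb{1}\{\delta < 0\}| \cdot L_\kappa(\delta)`, with :math:`\delta` the difference between target and current quantiles and the :math:`\kappa = 1` Huber term supplied by ``smooth_l1_loss``; the renamed ``dist_diff`` is also reused as the priority weight for the prioritized buffer.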

View File

@ -97,9 +97,9 @@ class DiscreteCRRPolicy(PGPolicy):
target = rew.unsqueeze(1) + self._gamma * expected_target_q target = rew.unsqueeze(1) + self._gamma * expected_target_q
critic_loss = 0.5 * F.mse_loss(qa_t, target) critic_loss = 0.5 * F.mse_loss(qa_t, target)
# Actor loss # Actor loss
a_t, _ = self.actor(batch.obs) act_target, _ = self.actor(batch.obs)
m = Categorical(logits=a_t) dist = Categorical(logits=act_target)
expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True) expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True)
advantage = qa_t - expected_policy_q advantage = qa_t - expected_policy_q
if self._policy_improvement_mode == "binary": if self._policy_improvement_mode == "binary":
actor_loss_coef = (advantage > 0).float() actor_loss_coef = (advantage > 0).float()
@@ -109,7 +109,7 @@ class DiscreteCRRPolicy(PGPolicy):
) )
else: else:
actor_loss_coef = 1.0 # effectively behavior cloning actor_loss_coef = 1.0 # effectively behavior cloning
actor_loss = (-m.log_prob(act) * actor_loss_coef).mean() actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean()
# CQL loss/regularizer # CQL loss/regularizer
min_q_loss = (q_t.logsumexp(1) - qa_t).mean() min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss

View File

@@ -202,13 +202,13 @@ class PSRLPolicy(BasePolicy):
rew_sum = np.zeros((n_s, n_a)) rew_sum = np.zeros((n_s, n_a))
rew_square_sum = np.zeros((n_s, n_a)) rew_square_sum = np.zeros((n_s, n_a))
rew_count = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a))
for b in batch.split(size=1): for minibatch in batch.split(size=1):
obs, act, obs_next = b.obs, b.act, b.obs_next obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
trans_count[obs, act, obs_next] += 1 trans_count[obs, act, obs_next] += 1
rew_sum[obs, act] += b.rew rew_sum[obs, act] += minibatch.rew
rew_square_sum[obs, act] += b.rew**2 rew_square_sum[obs, act] += minibatch.rew**2
rew_count[obs, act] += 1 rew_count[obs, act] += 1
if self._add_done_loop and b.done: if self._add_done_loop and minibatch.done:
# special operation for terminal states: add a self-loop # special operation for terminal states: add a self-loop
trans_count[obs_next, :, obs_next] += 1 trans_count[obs_next, :, obs_next] += 1
rew_count[obs_next, :] += 1 rew_count[obs_next, :] += 1

View File

@@ -85,9 +85,9 @@ class A2CPolicy(PGPolicy):
) -> Batch: ) -> Batch:
v_s, v_s_ = [], [] v_s, v_s_ = [], []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True): for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
v_s.append(self.critic(b.obs)) v_s.append(self.critic(minibatch.obs))
v_s_.append(self.critic(b.obs_next)) v_s_.append(self.critic(minibatch.obs_next))
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
v_s = batch.v_s.cpu().numpy() v_s = batch.v_s.cpu().numpy()
v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
@@ -122,14 +122,15 @@ class A2CPolicy(PGPolicy):
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], [] losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat): for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True): for minibatch in batch.split(batch_size, merge_last=True):
# calculate loss for actor # calculate loss for actor
dist = self(b).dist dist = self(minibatch).dist
log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1) log_prob = dist.log_prob(minibatch.act)
actor_loss = -(log_prob * b.adv).mean() log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
actor_loss = -(log_prob * minibatch.adv).mean()
# calculate loss for critic # calculate loss for critic
value = self.critic(b.obs).flatten() value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(b.returns, value) vf_loss = F.mse_loss(minibatch.returns, value)
# calculate regularization and overall loss # calculate regularization and overall loss
ent_loss = dist.entropy().mean() ent_loss = dist.entropy().mean()
loss = actor_loss + self._weight_vf * vf_loss \ loss = actor_loss + self._weight_vf * vf_loss \
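Every ``minibatch`` in these loops comes from ``Batch.split``, with the same keyword arguments used throughout this commit. A small usage sketch (the toy sizes and the printed values assume ``merge_last`` folds a short final chunk into the previous one)::

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(obs=np.zeros((10, 4)), act=np.zeros(10), returns=np.zeros(10))
    for minibatch in batch.split(4, shuffle=False, merge_last=True):
        print(len(minibatch.obs))  # 4, then 6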

View File

@@ -70,13 +70,13 @@ class C51Policy(DQNPolicy):
def _target_dist(self, batch: Batch) -> torch.Tensor: def _target_dist(self, batch: Batch) -> torch.Tensor:
if self._target: if self._target:
a = self(batch, input="obs_next").act act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits next_dist = self(batch, model="model_old", input="obs_next").logits
else: else:
next_b = self(batch, input="obs_next") next_batch = self(batch, input="obs_next")
a = next_b.act act = next_batch.act
next_dist = next_b.logits next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(a)), a, :] next_dist = next_dist[np.arange(len(act)), act, :]
target_support = batch.returns.clamp(self._v_min, self._v_max) target_support = batch.returns.clamp(self._v_min, self._v_max)
# An amazing trick for calculating the projection gracefully. # An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL # ref: https://github.com/ShangtongZhang/DeepRL
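``next_dist[np.arange(len(act)), act, :]`` is the row-wise gather used by the distributional policies in this commit: for each sample, keep only the distribution of the chosen action. In isolation, with made-up shapes::

    import numpy as np
    import torch

    bsz, n_actions, n_atoms = 4, 3, 5
    dist = torch.arange(bsz * n_actions * n_atoms, dtype=torch.float32)
    dist = dist.view(bsz, n_actions, n_atoms)
    act = np.array([0, 2, 1, 1])                 # one greedy action per row
    chosen = dist[np.arange(len(act)), act, :]   # shape: [bsz, n_atoms]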

View File

@@ -73,7 +73,7 @@ class DDPGPolicy(BasePolicy):
self.critic_old.eval() self.critic_old.eval()
self.critic_optim: torch.optim.Optimizer = critic_optim self.critic_optim: torch.optim.Optimizer = critic_optim
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]" assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau self.tau = tau
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]" assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
self._gamma = gamma self._gamma = gamma
self._noise = exploration_noise self._noise = exploration_noise
@@ -95,10 +95,8 @@ class DDPGPolicy(BasePolicy):
def sync_weight(self) -> None: def sync_weight(self) -> None:
"""Soft-update the weight for the target network.""" """Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): self.soft_update(self.actor_old, self.actor, self.tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) self.soft_update(self.critic_old, self.critic, self.tau)
for o, n in zip(self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n} batch = buffer[indices] # batch.obs_next: s_{t+n}
@@ -139,8 +137,8 @@ class DDPGPolicy(BasePolicy):
""" """
model = getattr(self, model) model = getattr(self, model)
obs = batch[input] obs = batch[input]
actions, h = model(obs, state=state, info=batch.info) actions, hidden = model(obs, state=state, info=batch.info)
return Batch(act=actions, state=h) return Batch(act=actions, state=hidden)
@staticmethod @staticmethod
def _mse_optimizer( def _mse_optimizer(
@@ -163,8 +161,7 @@ class DDPGPolicy(BasePolicy):
td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
batch.weight = td # prio-buffer batch.weight = td # prio-buffer
# actor # actor
action = self(batch).act actor_loss = -self.critic(batch.obs, self(batch).act).mean()
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad() self.actor_optim.zero_grad()
actor_loss.backward() actor_loss.backward()
self.actor_optim.step() self.actor_optim.step()

View File

@@ -76,10 +76,10 @@ class DiscreteSACPolicy(SACPolicy):
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
obs = batch[input] obs = batch[input]
logits, h = self.actor(obs, state=state, info=batch.info) logits, hidden = self.actor(obs, state=state, info=batch.info)
dist = Categorical(logits=logits) dist = Categorical(logits=logits)
act = dist.sample() act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist) return Batch(logits=logits, act=act, state=hidden, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n} batch = buffer[indices] # batch.obs: s_{t+n}

View File

@@ -151,13 +151,13 @@ class DQNPolicy(BasePolicy):
""" """
model = getattr(self, model) model = getattr(self, model)
obs = batch[input] obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, h = model(obs_, state=state, info=batch.info) logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None)) q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"): if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1] self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1]) act = to_numpy(q.max(dim=1)[1])
return Batch(logits=logits, act=act, state=h) return Batch(logits=logits, act=act, state=hidden)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0: if self._target and self._iter % self._freq == 0:
@@ -166,10 +166,10 @@ class DQNPolicy(BasePolicy):
weight = batch.pop("weight", 1.0) weight = batch.pop("weight", 1.0)
q = self(batch).logits q = self(batch).logits
q = q[np.arange(len(q)), batch.act] q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns.flatten(), q) returns = to_torch_as(batch.returns.flatten(), q)
td = r - q td_error = returns - q
loss = (td.pow(2) * weight).mean() loss = (td_error.pow(2) * weight).mean()
batch.weight = td # prio-buffer batch.weight = td_error # prio-buffer
loss.backward() loss.backward()
self.optim.step() self.optim.step()
self._iter += 1 self._iter += 1
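The ``returns``/``td_error`` renames make the weighted TD update easier to follow: ``weight`` is the importance-sampling weight popped from a prioritized buffer (1.0 otherwise), and the per-sample TD error is written back so priorities can be refreshed. The pattern, standalone::

    import torch

    def weighted_td_loss(q, returns, weight):
        td_error = returns - q
        loss = (td_error.pow(2) * weight).mean()
        return loss, td_error  # td_error is stored back as batch.weight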

View File

@@ -60,15 +60,15 @@ class FQFPolicy(QRDQNPolicy):
batch = buffer[indices] # batch.obs_next: s_{t+n} batch = buffer[indices] # batch.obs_next: s_{t+n}
if self._target: if self._target:
result = self(batch, input="obs_next") result = self(batch, input="obs_next")
a, fractions = result.act, result.fractions act, fractions = result.act, result.fractions
next_dist = self( next_dist = self(
batch, model="model_old", input="obs_next", fractions=fractions batch, model="model_old", input="obs_next", fractions=fractions
).logits ).logits
else: else:
next_b = self(batch, input="obs_next") next_batch = self(batch, input="obs_next")
a = next_b.act act = next_batch.act
next_dist = next_b.logits next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(a)), a, :] next_dist = next_dist[np.arange(len(act)), act, :]
return next_dist # shape: [bsz, num_quantiles] return next_dist # shape: [bsz, num_quantiles]
def forward( def forward(
@@ -82,14 +82,17 @@ class FQFPolicy(QRDQNPolicy):
) -> Batch: ) -> Batch:
model = getattr(self, model) model = getattr(self, model)
obs = batch[input] obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs obs_next = obs.obs if hasattr(obs, "obs") else obs
if fractions is None: if fractions is None:
(logits, fractions, quantiles_tau), h = model( (logits, fractions, quantiles_tau), hidden = model(
obs_, propose_model=self.propose_model, state=state, info=batch.info obs_next,
propose_model=self.propose_model,
state=state,
info=batch.info
) )
else: else:
(logits, _, quantiles_tau), h = model( (logits, _, quantiles_tau), hidden = model(
obs_, obs_next,
propose_model=self.propose_model, propose_model=self.propose_model,
fractions=fractions, fractions=fractions,
state=state, state=state,
@@ -106,7 +109,7 @@ class FQFPolicy(QRDQNPolicy):
return Batch( return Batch(
logits=logits, logits=logits,
act=act, act=act,
state=h, state=hidden,
fractions=fractions, fractions=fractions,
quantiles_tau=quantiles_tau quantiles_tau=quantiles_tau
) )
@@ -122,9 +125,9 @@ class FQFPolicy(QRDQNPolicy):
curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1) target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist # calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = ( huber_loss = (
u * ( dist_diff * (
tau_hats.unsqueeze(2) - tau_hats.unsqueeze(2) -
(target_dist - curr_dist).detach().le(0.).float() (target_dist - curr_dist).detach().le(0.).float()
).abs() ).abs()
@@ -132,7 +135,7 @@ class FQFPolicy(QRDQNPolicy):
quantile_loss = (huber_loss * weight).mean() quantile_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
# calculate fraction loss # calculate fraction loss
with torch.no_grad(): with torch.no_grad():
sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]

View File

@@ -73,36 +73,37 @@ class IQNPolicy(QRDQNPolicy):
sample_size = self._sample_size sample_size = self._sample_size
model = getattr(self, model) model = getattr(self, model)
obs = batch[input] obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs obs_next = obs.obs if hasattr(obs, "obs") else obs
(logits, (logits, taus), hidden = model(
taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info) obs_next, sample_size=sample_size, state=state, info=batch.info
)
q = self.compute_q_value(logits, getattr(obs, "mask", None)) q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"): if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1] self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1]) act = to_numpy(q.max(dim=1)[1])
return Batch(logits=logits, act=act, state=h, taus=taus) return Batch(logits=logits, act=act, state=hidden, taus=taus)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0: if self._target and self._iter % self._freq == 0:
self.sync_weight() self.sync_weight()
self.optim.zero_grad() self.optim.zero_grad()
weight = batch.pop("weight", 1.0) weight = batch.pop("weight", 1.0)
out = self(batch) action_batch = self(batch)
curr_dist, taus = out.logits, out.taus curr_dist, taus = action_batch.logits, action_batch.taus
act = batch.act act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1) target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist # calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = ( huber_loss = (
u * dist_diff *
(taus.unsqueeze(2) - (taus.unsqueeze(2) -
(target_dist - curr_dist).detach().le(0.).float()).abs() (target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1) ).sum(-1).mean(1)
loss = (huber_loss * weight).mean() loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
loss.backward() loss.backward()
self.optim.step() self.optim.step()
self._iter += 1 self._iter += 1

View File

@@ -71,8 +71,8 @@ class NPGPolicy(A2CPolicy):
batch = super().process_fn(batch, buffer, indices) batch = super().process_fn(batch, buffer, indices)
old_log_prob = [] old_log_prob = []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True): for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(b).dist.log_prob(b.act)) old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
batch.logp_old = torch.cat(old_log_prob, dim=0) batch.logp_old = torch.cat(old_log_prob, dim=0)
if self._norm_adv: if self._norm_adv:
batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
@@ -83,20 +83,20 @@ class NPGPolicy(A2CPolicy):
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
actor_losses, vf_losses, kls = [], [], [] actor_losses, vf_losses, kls = [], [], []
for _ in range(repeat): for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True): for minibatch in batch.split(batch_size, merge_last=True):
# optimize actor # optimize actor
# direction: calculate vanilla gradient # direction: calculate vanilla gradient
dist = self(b).dist dist = self(minibatch).dist
log_prob = dist.log_prob(b.act) log_prob = dist.log_prob(minibatch.act)
log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
actor_loss = -(log_prob * b.adv).mean() actor_loss = -(log_prob * minibatch.adv).mean()
flat_grads = self._get_flat_grad( flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True actor_loss, self.actor, retain_graph=True
).detach() ).detach()
# direction: calculate natural gradient # direction: calculate natural gradient
with torch.no_grad(): with torch.no_grad():
old_dist = self(b).dist old_dist = self(minibatch).dist
kl = kl_divergence(old_dist, dist).mean() kl = kl_divergence(old_dist, dist).mean()
# calculate first order gradient of kl with respect to theta # calculate first order gradient of kl with respect to theta
@@ -112,13 +112,13 @@ class NPGPolicy(A2CPolicy):
) )
new_flat_params = flat_params + self._step_size * search_direction new_flat_params = flat_params + self._step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params) self._set_from_flat_params(self.actor, new_flat_params)
new_dist = self(b).dist new_dist = self(minibatch).dist
kl = kl_divergence(old_dist, new_dist).mean() kl = kl_divergence(old_dist, new_dist).mean()
# optimize critic # optimize critic
for _ in range(self._optim_critic_iters): for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten() value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(b.returns, value) vf_loss = F.mse_loss(minibatch.returns, value)
self.optim.zero_grad() self.optim.zero_grad()
vf_loss.backward() vf_loss.backward()
self.optim.step() self.optim.step()
@@ -147,14 +147,14 @@ class NPGPolicy(A2CPolicy):
def _conjugate_gradients( def _conjugate_gradients(
self, self,
b: torch.Tensor, minibatch: torch.Tensor,
flat_kl_grad: torch.Tensor, flat_kl_grad: torch.Tensor,
nsteps: int = 10, nsteps: int = 10,
residual_tol: float = 1e-10 residual_tol: float = 1e-10
) -> torch.Tensor: ) -> torch.Tensor:
x = torch.zeros_like(b) x = torch.zeros_like(minibatch)
r, p = b.clone(), b.clone() r, p = minibatch.clone(), minibatch.clone()
# Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0. # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0.
# Change if doing warm start. # Change if doing warm start.
rdotr = r.dot(r) rdotr = r.dot(r)
for _ in range(nsteps): for _ in range(nsteps):
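The diff cuts ``_conjugate_gradients`` off inside its loop. For reference, a self-contained sketch of how a conjugate-gradient solver of this shape typically continues; here ``mvp`` stands in for the Fisher-vector product the method builds from ``flat_kl_grad``, so this is the textbook algorithm, not the repository's exact code::

    import torch

    def conjugate_gradients(mvp, b, nsteps=10, residual_tol=1e-10):
        x = torch.zeros_like(b)
        r, p = b.clone(), b.clone()
        rdotr = r.dot(r)
        for _ in range(nsteps):
            z = mvp(p)
            alpha = rdotr / p.dot(z)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # any SPD matrix works here
    x = conjugate_gradients(lambda v: A @ v, torch.tensor([1.0, 2.0]))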

View File

@@ -107,7 +107,7 @@ class PGPolicy(BasePolicy):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation. more detailed explanation.
""" """
logits, h = self.actor(batch.obs, state=state) logits, hidden = self.actor(batch.obs, state=state)
if isinstance(logits, tuple): if isinstance(logits, tuple):
dist = self.dist_fn(*logits) dist = self.dist_fn(*logits)
else: else:
@@ -119,20 +119,20 @@ class PGPolicy(BasePolicy):
act = logits[0] act = logits[0]
else: else:
act = dist.sample() act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist) return Batch(logits=logits, act=act, state=hidden, dist=dist)
def learn( # type: ignore def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
losses = [] losses = []
for _ in range(repeat): for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True): for minibatch in batch.split(batch_size, merge_last=True):
self.optim.zero_grad() self.optim.zero_grad()
result = self(b) result = self(minibatch)
dist = result.dist dist = result.dist
a = to_torch_as(b.act, result.act) act = to_torch_as(minibatch.act, result.act)
ret = to_torch_as(b.returns, result.act) ret = to_torch_as(minibatch.returns, result.act)
log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1) log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
loss = -(log_prob * ret).mean() loss = -(log_prob * ret).mean()
loss.backward() loss.backward()
self.optim.step() self.optim.step()

View File

@@ -96,8 +96,8 @@ class PPOPolicy(A2CPolicy):
batch.act = to_torch_as(batch.act, batch.v_s) batch.act = to_torch_as(batch.act, batch.v_s)
old_log_prob = [] old_log_prob = []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True): for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(b).dist.log_prob(b.act)) old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
batch.logp_old = torch.cat(old_log_prob, dim=0) batch.logp_old = torch.cat(old_log_prob, dim=0)
return batch return batch
@@ -108,32 +108,35 @@ class PPOPolicy(A2CPolicy):
for step in range(repeat): for step in range(repeat):
if self._recompute_adv and step > 0: if self._recompute_adv and step > 0:
batch = self._compute_returns(batch, self._buffer, self._indices) batch = self._compute_returns(batch, self._buffer, self._indices)
for b in batch.split(batch_size, merge_last=True): for minibatch in batch.split(batch_size, merge_last=True):
# calculate loss for actor # calculate loss for actor
dist = self(b).dist dist = self(minibatch).dist
if self._norm_adv: if self._norm_adv:
mean, std = b.adv.mean(), b.adv.std() mean, std = minibatch.adv.mean(), minibatch.adv.std()
b.adv = (b.adv - mean) / std # per-batch norm minibatch.adv = (minibatch.adv - mean) / std # per-batch norm
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() ratio = (dist.log_prob(minibatch.act) -
minibatch.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
surr1 = ratio * b.adv surr1 = ratio * minibatch.adv
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv surr2 = ratio.clamp(
1.0 - self._eps_clip, 1.0 + self._eps_clip
) * minibatch.adv
if self._dual_clip: if self._dual_clip:
clip1 = torch.min(surr1, surr2) clip1 = torch.min(surr1, surr2)
clip2 = torch.max(clip1, self._dual_clip * b.adv) clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean() clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
else: else:
clip_loss = -torch.min(surr1, surr2).mean() clip_loss = -torch.min(surr1, surr2).mean()
# calculate loss for critic # calculate loss for critic
value = self.critic(b.obs).flatten() value = self.critic(minibatch.obs).flatten()
if self._value_clip: if self._value_clip:
v_clip = b.v_s + (value - v_clip = minibatch.v_s + \
b.v_s).clamp(-self._eps_clip, self._eps_clip) (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2) vf1 = (minibatch.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2) vf2 = (minibatch.returns - v_clip).pow(2)
vf_loss = torch.max(vf1, vf2).mean() vf_loss = torch.max(vf1, vf2).mean()
else: else:
vf_loss = (b.returns - value).pow(2).mean() vf_loss = (minibatch.returns - value).pow(2).mean()
# calculate regularization and overall loss # calculate regularization and overall loss
ent_loss = dist.entropy().mean() ent_loss = dist.entropy().mean()
loss = clip_loss + self._weight_vf * vf_loss \ loss = clip_loss + self._weight_vf * vf_loss \
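After the renames, the clipped surrogate reads almost like the textbook objective. For reference, its core without dual clipping or value clipping (shapes simplified to 1-D tensors; ``eps_clip`` is the usual hyperparameter)::

    import torch

    def ppo_clip_loss(log_prob, logp_old, adv, eps_clip=0.2):
        ratio = (log_prob - logp_old).exp()
        surr1 = ratio * adv
        surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
        return -torch.min(surr1, surr2).mean()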

View File

@@ -56,13 +56,13 @@ class QRDQNPolicy(DQNPolicy):
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n} batch = buffer[indices] # batch.obs_next: s_{t+n}
if self._target: if self._target:
a = self(batch, input="obs_next").act act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits next_dist = self(batch, model="model_old", input="obs_next").logits
else: else:
next_b = self(batch, input="obs_next") next_batch = self(batch, input="obs_next")
a = next_b.act act = next_batch.act
next_dist = next_b.logits next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(a)), a, :] next_dist = next_dist[np.arange(len(act)), act, :]
return next_dist # shape: [bsz, num_quantiles] return next_dist # shape: [bsz, num_quantiles]
def compute_q_value( def compute_q_value(
@@ -80,15 +80,15 @@ class QRDQNPolicy(DQNPolicy):
curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1) target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist # calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = ( huber_loss = (
u * (self.tau_hat - dist_diff *
(target_dist - curr_dist).detach().le(0.).float()).abs() (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
).sum(-1).mean(1) ).sum(-1).mean(1)
loss = (huber_loss * weight).mean() loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
loss.backward() loss.backward()
self.optim.step() self.optim.step()
self._iter += 1 self._iter += 1

View File

@@ -99,10 +99,8 @@ class SACPolicy(DDPGPolicy):
return self return self
def sync_weight(self) -> None: def sync_weight(self) -> None:
for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): self.soft_update(self.critic1_old, self.critic1, self.tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) self.soft_update(self.critic2_old, self.critic2, self.tau)
for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def forward( # type: ignore def forward( # type: ignore
self, self,
@@ -112,7 +110,7 @@ class SACPolicy(DDPGPolicy):
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
obs = batch[input] obs = batch[input]
logits, h = self.actor(obs, state=state, info=batch.info) logits, hidden = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple) assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1) dist = Independent(Normal(*logits), 1)
if self._deterministic_eval and not self.training: if self._deterministic_eval and not self.training:
@@ -134,16 +132,20 @@ class SACPolicy(DDPGPolicy):
action_scale * (1 - squashed_action.pow(2)) + self.__eps action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True) ).sum(-1, keepdim=True)
return Batch( return Batch(
logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob logits=logits,
act=squashed_action,
state=hidden,
dist=dist,
log_prob=log_prob
) )
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n} batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input='obs_next') obs_next_result = self(batch, input="obs_next")
a_ = obs_next_result.act act_ = obs_next_result.act
target_q = torch.min( target_q = torch.min(
self.critic1_old(batch.obs_next, a_), self.critic1_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, act_),
) - self._alpha * obs_next_result.log_prob ) - self._alpha * obs_next_result.log_prob
return target_q return target_q
@@ -159,9 +161,9 @@ class SACPolicy(DDPGPolicy):
# actor # actor
obs_result = self(batch) obs_result = self(batch)
a = obs_result.act act = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten() current_q1a = self.critic1(batch.obs, act).flatten()
current_q2a = self.critic2(batch.obs, a).flatten() current_q2a = self.critic2(batch.obs, act).flatten()
actor_loss = ( actor_loss = (
self._alpha * obs_result.log_prob.flatten() - self._alpha * obs_result.log_prob.flatten() -
torch.min(current_q1a, current_q2a) torch.min(current_q1a, current_q2a)

View File

@@ -89,23 +89,20 @@ class TD3Policy(DDPGPolicy):
return self return self
def sync_weight(self) -> None: def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): self.soft_update(self.critic1_old, self.critic1, self.tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) self.soft_update(self.critic2_old, self.critic2, self.tau)
for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): self.soft_update(self.actor_old, self.actor, self.tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n} batch = buffer[indices] # batch.obs: s_{t+n}
a_ = self(batch, model="actor_old", input="obs_next").act act_ = self(batch, model="actor_old", input="obs_next").act
dev = a_.device noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
if self._noise_clip > 0.0: if self._noise_clip > 0.0:
noise = noise.clamp(-self._noise_clip, self._noise_clip) noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise act_ += noise
target_q = torch.min( target_q = torch.min(
self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_) self.critic1_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, act_),
) )
return target_q return target_q
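The rewritten target-Q block is TD3's target-policy smoothing. In isolation (``policy_noise`` and ``noise_clip`` are the usual hyperparameters; ``act_`` is the target actor's action)::

    import torch

    def smooth_target_action(act_, policy_noise=0.2, noise_clip=0.5):
        noise = torch.randn_like(act_) * policy_noise
        if noise_clip > 0.0:
            noise = noise.clamp(-noise_clip, noise_clip)
        return act_ + noise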

View File

@@ -71,20 +71,21 @@ class TRPOPolicy(NPGPolicy):
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], [] actor_losses, vf_losses, step_sizes, kls = [], [], [], []
for _ in range(repeat): for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True): for minibatch in batch.split(batch_size, merge_last=True):
# optimize actor # optimize actor
# direction: calculate vanilla gradient # direction: calculate vanilla gradient
dist = self(b).dist # TODO could come from batch dist = self(minibatch).dist # TODO could come from batch
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() ratio = (dist.log_prob(minibatch.act) -
minibatch.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
actor_loss = -(ratio * b.adv).mean() actor_loss = -(ratio * minibatch.adv).mean()
flat_grads = self._get_flat_grad( flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True actor_loss, self.actor, retain_graph=True
).detach() ).detach()
# direction: calculate natural gradient # direction: calculate natural gradient
with torch.no_grad(): with torch.no_grad():
old_dist = self(b).dist old_dist = self(minibatch).dist
kl = kl_divergence(old_dist, dist).mean() kl = kl_divergence(old_dist, dist).mean()
# calculate first order gradient of kl with respect to theta # calculate first order gradient of kl with respect to theta
@@ -109,12 +110,13 @@ class TRPOPolicy(NPGPolicy):
new_flat_params = flat_params + step_size * search_direction new_flat_params = flat_params + step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params) self._set_from_flat_params(self.actor, new_flat_params)
# calculate kl and if in bound, loss actually down # calculate kl and if in bound, loss actually down
new_dist = self(b).dist new_dist = self(minibatch).dist
new_dratio = (new_dist.log_prob(b.act) - new_dratio = (
b.logp_old).exp().float() new_dist.log_prob(minibatch.act) - minibatch.logp_old
).exp().float()
new_dratio = new_dratio.reshape(new_dratio.size(0), new_dratio = new_dratio.reshape(new_dratio.size(0),
-1).transpose(0, 1) -1).transpose(0, 1)
new_actor_loss = -(new_dratio * b.adv).mean() new_actor_loss = -(new_dratio * minibatch.adv).mean()
kl = kl_divergence(old_dist, new_dist).mean() kl = kl_divergence(old_dist, new_dist).mean()
if kl < self._delta and new_actor_loss < actor_loss: if kl < self._delta and new_actor_loss < actor_loss:
@@ -133,8 +135,8 @@ class TRPOPolicy(NPGPolicy):
# optimize critic # optimize critic
for _ in range(self._optim_critic_iters): for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten() value = self.critic(minibatch.obs).flatten()
vf_loss = F.mse_loss(b.returns, value) vf_loss = F.mse_loss(minibatch.returns, value)
self.optim.zero_grad() self.optim.zero_grad()
vf_loss.backward() vf_loss.backward()
self.optim.step() self.optim.step()

View File

@@ -87,14 +87,14 @@ class MLP(nn.Module):
self.output_dim = output_dim or hidden_sizes[-1] self.output_dim = output_dim or hidden_sizes[-1]
self.model = nn.Sequential(*model) self.model = nn.Sequential(*model)
def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if self.device is not None: if self.device is not None:
s = torch.as_tensor( obs = torch.as_tensor(
s, obs,
device=self.device, # type: ignore device=self.device, # type: ignore
dtype=torch.float32, dtype=torch.float32,
) )
return self.model(s.flatten(1)) # type: ignore return self.model(obs.flatten(1)) # type: ignore
class Net(nn.Module): class Net(nn.Module):
@@ -187,12 +187,12 @@ class Net(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Any = None, state: Any = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> flatten (inside MLP)-> logits.""" """Mapping: obs -> flatten (inside MLP)-> logits."""
logits = self.model(s) logits = self.model(obs)
bsz = logits.shape[0] bsz = logits.shape[0]
if self.use_dueling: # Dueling DQN if self.use_dueling: # Dueling DQN
q, v = self.Q(logits), self.V(logits) q, v = self.Q(logits), self.V(logits)
@@ -235,38 +235,45 @@ class Recurrent(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None, state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Mapping: s -> flatten -> logits. """Mapping: obs -> flatten -> logits.
In the evaluation mode, s should be with shape ``[bsz, dim]``; in the In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the
training mode, s should be with shape ``[bsz, len, dim]``. See the code training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code
and comment for more detail. and comment for more detail.
""" """
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore obs = torch.as_tensor(
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation) obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than # In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase. # that in the evaluation phase.
if len(s.shape) == 2: if len(obs.shape) == 2:
s = s.unsqueeze(-2) obs = obs.unsqueeze(-2)
s = self.fc1(s) obs = self.fc1(obs)
self.nn.flatten_parameters() self.nn.flatten_parameters()
if state is None: if state is None:
s, (h, c) = self.nn(s) obs, (hidden, cell) = self.nn(obs)
else: else:
# we store the stack data in [bsz, len, ...] format # we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...] # but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn( obs, (hidden, cell) = self.nn(
s, ( obs, (
state["h"].transpose(0, 1).contiguous(), state["hidden"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous() state["cell"].transpose(0, 1).contiguous()
) )
) )
s = self.fc2(s[:, -1]) obs = self.fc2(obs[:, -1])
# please ensure the first dim is batch size: [bsz, len, ...] # please ensure the first dim is batch size: [bsz, len, ...]
return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} return obs, {
"hidden": hidden.transpose(0, 1).detach(),
"cell": cell.transpose(0, 1).detach()
}
class ActorCritic(nn.Module): class ActorCritic(nn.Module):
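The recurrent state dict is now keyed ``"hidden"``/``"cell"`` and kept batch-first, while ``nn.LSTM`` itself expects ``[num_layers, bsz, ...]``; hence the transposes above. A self-contained illustration of that convention with a plain LSTM (sizes are arbitrary)::

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
    x = torch.zeros(4, 5, 8)                        # [bsz, len, dim]
    out, (hidden, cell) = lstm(x)                   # hidden/cell: [num_layers, bsz, 16]
    state = {
        "hidden": hidden.transpose(0, 1).detach(),  # stored as [bsz, num_layers, 16]
        "cell": cell.transpose(0, 1).detach(),
    }
    # feeding the stored state back mirrors the code above
    out, _ = lstm(
        x, (state["hidden"].transpose(0, 1).contiguous(),
            state["cell"].transpose(0, 1).contiguous())
    )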
@@ -299,8 +306,8 @@ class DataParallelNet(nn.Module):
super().__init__() super().__init__()
self.net = nn.DataParallel(net) self.net = nn.DataParallel(net)
def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any, def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any,
**kwargs: Any) -> Tuple[Any, Any]: **kwargs: Any) -> Tuple[Any, Any]:
if not isinstance(s, torch.Tensor): if not isinstance(obs, torch.Tensor):
s = torch.as_tensor(s, dtype=torch.float32) obs = torch.as_tensor(obs, dtype=torch.float32)
return self.net(s=s.cuda(), *args, **kwargs) return self.net(obs=obs.cuda(), *args, **kwargs)

View File

@@ -58,14 +58,14 @@ class Actor(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Any = None, state: Any = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> logits -> action.""" """Mapping: obs -> logits -> action."""
logits, h = self.preprocess(s, state) logits, hidden = self.preprocess(obs, state)
logits = self._max * torch.tanh(self.last(logits)) logits = self._max * torch.tanh(self.last(logits))
return logits, h return logits, hidden
class Critic(nn.Module): class Critic(nn.Module):
@@ -110,24 +110,24 @@ class Critic(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None, act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> torch.Tensor: ) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a).""" """Mapping: (s, a) -> logits -> Q(s, a)."""
s = torch.as_tensor( obs = torch.as_tensor(
s, obs,
device=self.device, # type: ignore device=self.device, # type: ignore
dtype=torch.float32, dtype=torch.float32,
).flatten(1) ).flatten(1)
if a is not None: if act is not None:
a = torch.as_tensor( act = torch.as_tensor(
a, act,
device=self.device, # type: ignore device=self.device, # type: ignore
dtype=torch.float32, dtype=torch.float32,
).flatten(1) ).flatten(1)
s = torch.cat([s, a], dim=1) obs = torch.cat([obs, act], dim=1)
logits, h = self.preprocess(s) logits, hidden = self.preprocess(obs)
logits = self.last(logits) logits = self.last(logits)
return logits return logits
@@ -196,12 +196,12 @@ class ActorProb(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Any = None, state: Any = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]: ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
"""Mapping: s -> logits -> (mu, sigma).""" """Mapping: obs -> logits -> (mu, sigma)."""
logits, h = self.preprocess(s, state) logits, hidden = self.preprocess(obs, state)
mu = self.mu(logits) mu = self.mu(logits)
if not self._unbounded: if not self._unbounded:
mu = self._max * torch.tanh(mu) mu = self._max * torch.tanh(mu)
@@ -252,30 +252,34 @@ class RecurrentActorProb(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None, state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]: ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore obs = torch.as_tensor(
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation) obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than # In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase. # that in the evaluation phase.
if len(s.shape) == 2: if len(obs.shape) == 2:
s = s.unsqueeze(-2) obs = obs.unsqueeze(-2)
self.nn.flatten_parameters() self.nn.flatten_parameters()
if state is None: if state is None:
s, (h, c) = self.nn(s) obs, (hidden, cell) = self.nn(obs)
else: else:
# we store the stack data in [bsz, len, ...] format # we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...] # but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn( obs, (hidden, cell) = self.nn(
s, ( obs, (
state["h"].transpose(0, 1).contiguous(), state["hidden"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous() state["cell"].transpose(0, 1).contiguous()
) )
) )
logits = s[:, -1] logits = obs[:, -1]
mu = self.mu(logits) mu = self.mu(logits)
if not self._unbounded: if not self._unbounded:
mu = self._max * torch.tanh(mu) mu = self._max * torch.tanh(mu)
@@ -287,8 +291,8 @@ class RecurrentActorProb(nn.Module):
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...] # please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), { return (mu, sigma), {
"h": h.transpose(0, 1).detach(), "hidden": hidden.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach() "cell": cell.transpose(0, 1).detach()
} }
@@ -321,28 +325,32 @@ class RecurrentCritic(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None, act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> torch.Tensor: ) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore obs = torch.as_tensor(
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation) obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than # In short, the tensor's shape in the training phase is longer than
# that in the evaluation phase. # that in the evaluation phase.
assert len(s.shape) == 3 assert len(obs.shape) == 3
self.nn.flatten_parameters() self.nn.flatten_parameters()
s, (h, c) = self.nn(s) obs, (hidden, cell) = self.nn(obs)
s = s[:, -1] obs = obs[:, -1]
if a is not None: if act is not None:
a = torch.as_tensor( act = torch.as_tensor(
a, act,
device=self.device, # type: ignore device=self.device, # type: ignore
dtype=torch.float32, dtype=torch.float32,
) )
s = torch.cat([s, a], dim=1) obs = torch.cat([obs, act], dim=1)
s = self.fc2(s) obs = self.fc2(obs)
return s return obs
class Perturbation(nn.Module): class Perturbation(nn.Module):
@@ -381,9 +389,9 @@ class Perturbation(nn.Module):
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
# preprocess_net # preprocess_net
logits = self.preprocess_net(torch.cat([state, action], -1))[0] logits = self.preprocess_net(torch.cat([state, action], -1))[0]
a = self.phi * self.max_action * torch.tanh(logits) noise = self.phi * self.max_action * torch.tanh(logits)
# clip to [-max_action, max_action] # clip to [-max_action, max_action]
return (a + action).clamp(-self.max_action, self.max_action) return (noise + action).clamp(-self.max_action, self.max_action)
class VAE(nn.Module): class VAE(nn.Module):
@@ -434,31 +442,32 @@ class VAE(nn.Module):
self, state: torch.Tensor, action: torch.Tensor self, state: torch.Tensor, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [state, action] -> z , [state, z] -> action # [state, action] -> z , [state, z] -> action
z = self.encoder(torch.cat([state, action], -1)) latent_z = self.encoder(torch.cat([state, action], -1))
# shape of z: (state.shape[:-1], hidden_dim) # shape of z: (state.shape[:-1], hidden_dim)
mean = self.mean(z) mean = self.mean(latent_z)
# Clamped for numerical stability # Clamped for numerical stability
log_std = self.log_std(z).clamp(-4, 15) log_std = self.log_std(latent_z).clamp(-4, 15)
std = torch.exp(log_std) std = torch.exp(log_std)
# shape of mean, std: (state.shape[:-1], latent_dim) # shape of mean, std: (state.shape[:-1], latent_dim)
z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
u = self.decode(state, z) # (state.shape[:-1], action_dim) reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim)
return u, mean, std return reconstruction, mean, std
def decode( def decode(
self, self,
state: torch.Tensor, state: torch.Tensor,
z: Union[torch.Tensor, None] = None latent_z: Union[torch.Tensor, None] = None
) -> torch.Tensor: ) -> torch.Tensor:
# decode(state) -> action # decode(state) -> action
if z is None: if latent_z is None:
# state.shape[0] may be batch_size # state.shape[0] may be batch_size
# latent vector clipped to [-0.5, 0.5] # latent vector clipped to [-0.5, 0.5]
z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \ latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
.to(self.device).clamp(-0.5, 0.5) .to(self.device).clamp(-0.5, 0.5)
# decode z with state! # decode z with state!
return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1))) return self.max_action * \
torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))

View File

@@ -59,16 +59,16 @@ class Actor(nn.Module):
def forward( def forward(
self, self,
s: Union[np.ndarray, torch.Tensor], obs: Union[np.ndarray, torch.Tensor],
state: Any = None, state: Any = None,
info: Dict[str, Any] = {}, info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*).""" r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state) logits, hidden = self.preprocess(obs, state)
logits = self.last(logits) logits = self.last(logits)
if self.softmax_output: if self.softmax_output:
logits = F.softmax(logits, dim=-1) logits = F.softmax(logits, dim=-1)
return logits, h return logits, hidden
class Critic(nn.Module): class Critic(nn.Module):
@@ -114,10 +114,10 @@ class Critic(nn.Module):
) )
def forward( def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any self, obs: Union[np.ndarray, torch.Tensor], **kwargs: Any
) -> torch.Tensor: ) -> torch.Tensor:
"""Mapping: s -> V(s).""" """Mapping: s -> V(s)."""
logits, _ = self.preprocess(s, state=kwargs.get("state", None)) logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits) return self.last(logits)
@@ -199,10 +199,10 @@ class ImplicitQuantileNetwork(Critic):
).to(device) ).to(device)
def forward( # type: ignore def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any self, obs: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
) -> Tuple[Any, torch.Tensor]: ) -> Tuple[Any, torch.Tensor]:
r"""Mapping: s -> Q(s, \*).""" r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None)) logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
# Sample fractions. # Sample fractions.
batch_size = logits.size(0) batch_size = logits.size(0)
taus = torch.rand( taus = torch.rand(
@@ -211,7 +211,7 @@ class ImplicitQuantileNetwork(Critic):
embedding = (logits.unsqueeze(1) * embedding = (logits.unsqueeze(1) *
self.embed_model(taus)).view(batch_size * sample_size, -1) self.embed_model(taus)).view(batch_size * sample_size, -1)
out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
return (out, taus), h return (out, taus), hidden
class FractionProposalNetwork(nn.Module): class FractionProposalNetwork(nn.Module):
@@ -235,17 +235,17 @@ class FractionProposalNetwork(nn.Module):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
def forward( def forward(
self, state_embeddings: torch.Tensor self, obs_embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Calculate (log of) probabilities q_i in the paper. # Calculate (log of) probabilities q_i in the paper.
m = torch.distributions.Categorical(logits=self.net(state_embeddings)) dist = torch.distributions.Categorical(logits=self.net(obs_embeddings))
taus_1_N = torch.cumsum(m.probs, dim=1) taus_1_N = torch.cumsum(dist.probs, dim=1)
# Calculate \tau_i (i=0,...,N). # Calculate \tau_i (i=0,...,N).
taus = F.pad(taus_1_N, (1, 0)) taus = F.pad(taus_1_N, (1, 0))
# Calculate \hat \tau_i (i=0,...,N-1). # Calculate \hat \tau_i (i=0,...,N-1).
tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0 tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
# Calculate entropies of value distributions. # Calculate entropies of value distributions.
entropies = m.entropy() entropies = dist.entropy()
return taus, tau_hats, entropies return taus, tau_hats, entropies
@@ -294,13 +294,13 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
return quantiles return quantiles
def forward( # type: ignore def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor], self, obs: Union[np.ndarray, torch.Tensor],
propose_model: FractionProposalNetwork, propose_model: FractionProposalNetwork,
fractions: Optional[Batch] = None, fractions: Optional[Batch] = None,
**kwargs: Any **kwargs: Any
) -> Tuple[Any, torch.Tensor]: ) -> Tuple[Any, torch.Tensor]:
r"""Mapping: s -> Q(s, \*).""" r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None)) logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
# Propose fractions # Propose fractions
if fractions is None: if fractions is None:
taus, tau_hats, entropies = propose_model(logits.detach()) taus, tau_hats, entropies = propose_model(logits.detach())
@@ -313,7 +313,7 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
if self.training: if self.training:
with torch.no_grad(): with torch.no_grad():
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
return (quantiles, fractions, quantiles_tau), h return (quantiles, fractions, quantiles_tau), hidden
class NoisyLinear(nn.Module): class NoisyLinear(nn.Module):

View File

@@ -31,20 +31,20 @@ class MovAvg(object):
self.banned = [np.inf, np.nan, -np.inf] self.banned = [np.inf, np.nan, -np.inf]
def add( def add(
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor] self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor]
) -> float: ) -> float:
"""Add a scalar into :class:`MovAvg`. """Add a scalar into :class:`MovAvg`.
You can add ``torch.Tensor`` with only one element, a python scalar, or You can add ``torch.Tensor`` with only one element, a python scalar, or
a list of python scalar. a list of python scalar.
""" """
if isinstance(x, torch.Tensor): if isinstance(data_array, torch.Tensor):
x = x.flatten().cpu().numpy() data_array = data_array.flatten().cpu().numpy()
if np.isscalar(x): if np.isscalar(data_array):
x = [x] data_array = [data_array]
for i in x: # type: ignore for number in data_array: # type: ignore
if i not in self.banned: if number not in self.banned:
self.cache.append(i) self.cache.append(number)
if self.size > 0 and len(self.cache) > self.size: if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size:] self.cache = self.cache[-self.size:]
return self.get() return self.get()
@@ -80,10 +80,10 @@ class RunningMeanStd(object):
self.mean, self.var = mean, std self.mean, self.var = mean, std
self.count = 0 self.count = 0
def update(self, x: np.ndarray) -> None: def update(self, data_array: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count.""" """Add a batch of item into RMS with the same shape, modify mean/var/count."""
batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0) batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
batch_count = len(x) batch_count = len(data_array)
delta = batch_mean - self.mean delta = batch_mean - self.mean
total_count = self.count + batch_count total_count = self.count + batch_count
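The ``RunningMeanStd.update`` hunk is cut off here. The usual way such an update finishes is the parallel mean/variance combination sketched below; treat it as the standard formula shown for context, not the repository's exact continuation::

    import numpy as np

    def combine(mean, var, count, batch_mean, batch_var, batch_count):
        # merge (mean, var, count) with a new batch's statistics
        delta = batch_mean - mean
        total_count = count + batch_count
        new_mean = mean + delta * batch_count / total_count
        m_2 = var * count + batch_var * batch_count \
            + np.square(delta) * count * batch_count / total_count
        return new_mean, m_2 / total_count, total_count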