Formalize variable names (#509)

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>

parent bc53ead273
commit c25926dd8f
@@ -350,7 +350,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
     def reset():
         return copy.deepcopy(self.graph)
-    def step(a):
+    def step(action):
         ...
         return copy.deepcopy(self.graph), reward, done, {}
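The ``copy.deepcopy`` is the point of this snippet: if ``reset``/``step`` handed out ``self.graph`` itself, every transition stored in the buffer would alias one mutable object. A minimal self-contained illustration (the ``graph`` dict here is hypothetical, not from the diff):

::

    import copy

    buffer, graph = [], {"weight": 0}
    buffer.append(graph)                 # shallow reference
    graph["weight"] = 1
    assert buffer[0]["weight"] == 1      # the stored "old" state was mutated
    buffer.append(copy.deepcopy(graph))  # independent snapshot
    graph["weight"] = 2
    assert buffer[1]["weight"] == 1      # snapshot unaffected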
@@ -391,13 +391,13 @@ In addition, legal actions in multi-agent RL often vary with timestep (just like
 The above description gives rise to the following formulation of multi-agent RL:

 ::

-    action = policy(state, agent_id, mask)
-    (next_state, next_agent_id, next_mask), reward = env.step(action)
+    act = policy(state, agent_id, mask)
+    (next_state, next_agent_id, next_mask), reward = env.step(act)

 By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL:

 ::

-    action = policy(state_)
-    next_state_, reward = env.step(action)
+    act = policy(state_)
+    next_state_, reward = env.step(act)

 Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
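A minimal sketch of the wrapping trick, consistent with the ``MultiAgentEnv`` docstring changed further down (which says the observation is a dict containing obs, agent_id, and mask); the ``wrap`` helper is illustrative, not from the diff:

::

    def wrap(state, agent_id, mask):
        # fold the triple into one observation, so a standard
        # single-agent policy and collector can be reused unchanged
        return {"obs": state, "agent_id": agent_id, "mask": mask}

    state_ = wrap(state, agent_id, mask)
    act = policy(state_)  # an ordinary single-agent policy call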
@@ -298,14 +298,14 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is
 ::

     # pseudocode, cannot work
-    s = env.reset()
+    obs = env.reset()
     buffer = Buffer(size=10000)
     agent = DQN()
     for i in range(int(1e6)):
-        a = agent.compute_action(s)
-        s_, r, d, _ = env.step(a)
-        buffer.store(s, a, s_, r, d)
-        s = s_
+        act = agent.compute_action(obs)
+        obs_next, rew, done, _ = env.step(act)
+        buffer.store(obs, act, obs_next, rew, done)
+        obs = obs_next
         if i % 1000 == 0:
             b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
             # compute 2-step returns. How?
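The final comment is the crux: with only one-step tuples in the buffer, a 2-step return also needs the following transition. A sketch of the computation, assuming time-ordered storage with array-style ``rew``/``done``/``obs_next`` fields and a hypothetical ``q_target`` bootstrap function:

::

    def two_step_return(buffer, i, gamma, q_target):
        ret = buffer.rew[i]
        if buffer.done[i]:
            return ret                    # episode ended at step i
        ret += gamma * buffer.rew[i + 1]  # next transition, same episode
        if buffer.done[i + 1]:
            return ret
        return ret + gamma**2 * q_target(buffer.obs_next[i + 1])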
@@ -390,14 +390,14 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr
 ::

     # pseudocode, cannot work                         # methods in tianshou
-    s = env.reset()
+    obs = env.reset()
     buffer = Buffer(size=10000)                       # buffer = tianshou.data.ReplayBuffer(size=10000)
     agent = DQN()                                     # policy.__init__(...)
     for i in range(int(1e6)):                         # done in trainer
-        a = agent.compute_action(s)                   # act = policy(batch, ...).act
-        s_, r, d, _ = env.step(a)                     # collector.collect(...)
-        buffer.store(s, a, s_, r, d)                  # collector.collect(...)
-        s = s_                                        # collector.collect(...)
+        act = agent.compute_action(obs)               # act = policy(batch, ...).act
+        obs_next, rew, done, _ = env.step(act)        # collector.collect(...)
+        buffer.store(obs, act, obs_next, rew, done)   # collector.collect(...)
+        obs = obs_next                                # collector.collect(...)
         if i % 1000 == 0:                             # done in trainer
             # the following is done in policy.update(batch_size, buffer)
             b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)  # batch, indices = buffer.sample(batch_size)
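For comparison, the right-hand column assembled into one loop (a sketch; network, optimizer, and env construction are omitted, and the ``DQNPolicy`` arguments are abbreviated):

::

    buffer = tianshou.data.ReplayBuffer(size=10000)
    policy = tianshou.policy.DQNPolicy(net, optim)            # any BasePolicy works
    collector = tianshou.data.Collector(policy, env, buffer)
    for i in range(int(1e6)):
        collector.collect(n_step=1)     # act, step the env, store the transition
        if i % 1000 == 0:
            policy.update(64, buffer)   # sample a batch and learn in one call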
@@ -42,13 +42,13 @@ class DQN(nn.Module):

     def forward(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Any] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
-        r"""Mapping: x -> Q(x, \*)."""
-        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
-        return self.net(x), state
+        r"""Mapping: s -> Q(s, \*)."""
+        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
+        return self.net(obs), state


 class C51(DQN):
@@ -73,15 +73,15 @@ class C51(DQN):

     def forward(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Any] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: x -> Z(x, \*)."""
-        x, state = super().forward(x)
-        x = x.view(-1, self.num_atoms).softmax(dim=-1)
-        x = x.view(-1, self.action_num, self.num_atoms)
-        return x, state
+        obs, state = super().forward(obs)
+        obs = obs.view(-1, self.num_atoms).softmax(dim=-1)
+        obs = obs.view(-1, self.action_num, self.num_atoms)
+        return obs, state


 class Rainbow(DQN):
@@ -127,22 +127,22 @@ class Rainbow(DQN):

     def forward(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: x -> Z(x, \*)."""
-        x, state = super().forward(x)
-        q = self.Q(x)
+        obs, state = super().forward(obs)
+        q = self.Q(obs)
         q = q.view(-1, self.action_num, self.num_atoms)
         if self._is_dueling:
-            v = self.V(x)
+            v = self.V(obs)
             v = v.view(-1, 1, self.num_atoms)
             logits = q - q.mean(dim=1, keepdim=True) + v
         else:
             logits = q
-        y = logits.softmax(dim=2)
-        return y, state
+        probs = logits.softmax(dim=2)
+        return probs, state


 class QRDQN(DQN):
@@ -168,11 +168,11 @@ class QRDQN(DQN):

     def forward(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Any] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: x -> Z(x, \*)."""
-        x, state = super().forward(x)
-        x = x.view(-1, self.action_num, self.num_quantiles)
-        return x, state
+        obs, state = super().forward(obs)
+        obs = obs.view(-1, self.action_num, self.num_quantiles)
+        return obs, state
@@ -56,16 +56,16 @@ class Wrapper(gym.Wrapper):
         self.rm_done = rm_done

     def step(self, action):
-        r = 0.0
+        rew_sum = 0.0
         for _ in range(self.action_repeat):
-            obs, reward, done, info = self.env.step(action)
+            obs, rew, done, info = self.env.step(action)
             # remove done reward penalty
             if not done or not self.rm_done:
-                r = r + reward
+                rew_sum = rew_sum + rew
             if done:
                 break
         # scale reward
-        return obs, self.reward_scale * r, done, info
+        return obs, self.reward_scale * rew_sum, done, info


 def test_sac_bipedal(args=get_args()):
@@ -86,10 +86,10 @@ class MyTestEnv(gym.Env):

     def _get_reward(self):
         """Generate a non-scalar reward if ma_rew is True."""
-        x = int(self.done)
+        end_flag = int(self.done)
         if self.ma_rew > 0:
-            return [x] * self.ma_rew
-        return x
+            return [end_flag] * self.ma_rew
+        return end_flag

     def _get_state(self):
         """Generate state(observation) of MyTestEnv"""
@@ -32,10 +32,12 @@ def test_replaybuffer(size=10, bufsize=20):
     assert str(buf) == buf.__class__.__name__ + '()'
     obs = env.reset()
     action_list = [1] * 5 + [0] * 10 + [1] * 10
-    for i, a in enumerate(action_list):
-        obs_next, rew, done, info = env.step(a)
+    for i, act in enumerate(action_list):
+        obs_next, rew, done, info = env.step(act)
         buf.add(
-            Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)
+            Batch(
+                obs=obs, act=[act], rew=rew, done=done, obs_next=obs_next, info=info
+            )
         )
         obs = obs_next
         assert len(buf) == min(bufsize, i + 1)
@@ -220,11 +222,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
     buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
     obs = env.reset()
     action_list = [1] * 5 + [0] * 10 + [1] * 10
-    for i, a in enumerate(action_list):
-        obs_next, rew, done, info = env.step(a)
+    for i, act in enumerate(action_list):
+        obs_next, rew, done, info = env.step(act)
         batch = Batch(
             obs=obs,
-            act=a,
+            act=act,
             rew=rew,
             done=done,
             obs_next=obs_next,
@@ -331,20 +331,20 @@ def test_collector_with_ma():
     policy = MyPolicy()
     c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
     # n_step=3 will collect a full episode
-    r = c0.collect(n_step=3)['rews']
-    assert len(r) == 0
-    r = c0.collect(n_episode=2)['rews']
-    assert r.shape == (2, 4) and np.all(r == 1)
+    rew = c0.collect(n_step=3)['rews']
+    assert len(rew) == 0
+    rew = c0.collect(n_episode=2)['rews']
+    assert rew.shape == (2, 4) and np.all(rew == 1)
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
     envs = DummyVectorEnv(env_fns)
     c1 = Collector(
         policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
         Logger.single_preprocess_fn
     )
-    r = c1.collect(n_step=12)['rews']
-    assert r.shape == (2, 4) and np.all(r == 1), r
-    r = c1.collect(n_episode=8)['rews']
-    assert r.shape == (8, 4) and np.all(r == 1)
+    rew = c1.collect(n_step=12)['rews']
+    assert rew.shape == (2, 4) and np.all(rew == 1), rew
+    rew = c1.collect(n_episode=8)['rews']
+    assert rew.shape == (8, 4) and np.all(rew == 1)
     batch, _ = c1.buffer.sample(10)
     print(batch)
     c0.buffer.update(c1.buffer)
@@ -446,8 +446,8 @@ def test_collector_with_ma():
         policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
         Logger.single_preprocess_fn
     )
-    r = c2.collect(n_episode=10)['rews']
-    assert r.shape == (10, 4) and np.all(r == 1)
+    rew = c2.collect(n_episode=10)['rews']
+    assert rew.shape == (10, 4) and np.all(rew == 1)
     batch, _ = c2.buffer.sample(10)
@@ -58,22 +58,22 @@ def test_async_env(size=10000, num=8, sleep=0.1):
     # should be smaller
     action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
     current_idx_start = 0
-    action = action_list[:num]
+    act = action_list[:num]
     env_ids = list(range(num))
     o = []
     spent_time = time.time()
     while current_idx_start < len(action_list):
-        A, B, C, D = v.step(action=action, id=env_ids)
+        A, B, C, D = v.step(action=act, id=env_ids)
         b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
         env_ids = b.info.env_id
         o.append(b)
-        current_idx_start += len(action)
+        current_idx_start += len(act)
         # len of action may be smaller than len(A) in the end
-        action = action_list[current_idx_start:current_idx_start + len(A)]
+        act = action_list[current_idx_start:current_idx_start + len(A)]
         # truncate env_ids with the first terms
         # typically len(env_ids) == len(A) == len(action), except for the
         # last batch when actions are not enough
-        env_ids = env_ids[:len(action)]
+        env_ids = env_ids[:len(act)]
     spent_time = time.time() - spent_time
     Batch.cat(o)
     v.close()
@@ -142,11 +142,11 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
     returns = np.zeros_like(indices, dtype=float)
     buf_len = len(buffer)
     for i in range(len(indices)):
-        flag, r = False, 0.
+        flag, rew = False, 0.
         real_step_n = nstep
         for n in range(nstep):
             idx = (indices[i] + n) % buf_len
-            r += buffer.rew[idx] * gamma**n
+            rew += buffer.rew[idx] * gamma**n
             if buffer.done[idx]:
                 if not (
                     hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx]
@@ -156,8 +156,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
                 break
         if not flag:
             idx = (indices[i] + real_step_n - 1) % buf_len
-            r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
-        returns[i] = r
+            rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
+        returns[i] = rew
     return returns
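In formula form, the reference value computed here for each sampled transition is the :math:`n`-step return

.. math::

    G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n Q_{\text{target}}(s_{t+n}),

where the bootstrap term is dropped (``flag`` is set) if a real terminal state, as opposed to a ``TimeLimit.truncated`` one, is reached within the :math:`n` steps.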
@@ -41,7 +41,7 @@ def gomoku(args=get_args()):
         return TicTacToeEnv(args.board_size, args.win_size)

     test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
-    for r in range(args.self_play_round):
+    for round in range(args.self_play_round):
         rews = []
         agent_learn.set_eps(0.0)
         # compute the reward over previous learner
@@ -66,12 +66,12 @@ def gomoku(args=get_args()):
             # previous learner can only be used for forward
             agent.forward = opponent.forward
             args.model_save_path = os.path.join(
-                args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth'
+                args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
             )
             result, agent_learn = train_agent(
                 args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
             )
-            print(f'round_{r}_epoch_{epoch}')
+            print(f'round_{round}_epoch_{epoch}')
             pprint.pprint(result)
             learnt_agent = deepcopy(agent_learn)
             learnt_agent.set_eps(0.0)
@@ -11,17 +11,18 @@ import torch
 IndexType = Union[slice, int, np.ndarray, List[int]]


-def _is_batch_set(data: Any) -> bool:
+def _is_batch_set(obj: Any) -> bool:
     # Batch set is a list/tuple of dict/Batch objects,
     # or 1-D np.ndarray with object type,
     # where each element is a dict/Batch object
-    if isinstance(data, np.ndarray):  # most often case
-        # "for e in data" will just unpack the first dimension,
-        # but data.tolist() will flatten ndarray of objects
-        # so do not use data.tolist()
-        return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data)
-    elif isinstance(data, (list, tuple)):
-        if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
+    if isinstance(obj, np.ndarray):  # most often case
+        # "for element in obj" will just unpack the first dimension,
+        # but obj.tolist() will flatten ndarray of objects
+        # so do not use obj.tolist()
+        return obj.dtype == object and \
+            all(isinstance(element, (dict, Batch)) for element in obj)
+    elif isinstance(obj, (list, tuple)):
+        if len(obj) > 0 and all(isinstance(element, (dict, Batch)) for element in obj):
             return True
     return False
@@ -48,28 +49,29 @@ def _is_number(value: Any) -> bool:
     return isinstance(value, (Number, np.number, np.bool_))


-def _to_array_with_correct_type(v: Any) -> np.ndarray:
-    if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)):
-        return v  # most often case
+def _to_array_with_correct_type(obj: Any) -> np.ndarray:
+    if isinstance(obj, np.ndarray) and \
+            issubclass(obj.dtype.type, (np.bool_, np.number)):
+        return obj  # most often case
     # convert the value to np.ndarray
-    # convert to object data type if neither bool nor number
+    # convert to object obj type if neither bool nor number
     # raises an exception if array's elements are tensors themselves
-    v = np.asanyarray(v)
-    if not issubclass(v.dtype.type, (np.bool_, np.number)):
-        v = v.astype(object)
-    if v.dtype == object:
-        # scalar ndarray with object data type is very annoying
+    obj_array = np.asanyarray(obj)
+    if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
+        obj_array = obj_array.astype(object)
+    if obj_array.dtype == object:
+        # scalar ndarray with object obj type is very annoying
         # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
         # a is not array([{}, {}], dtype=object), and a[0]={} results in
         # something very strange:
         # array([{}, array({}, dtype=object)], dtype=object)
-        if not v.shape:
-            v = v.item(0)
-        elif all(isinstance(e, np.ndarray) for e in v.reshape(-1)):
-            return v  # various length, np.array([[1], [2, 3], [4, 5, 6]])
-        elif any(isinstance(e, torch.Tensor) for e in v.reshape(-1)):
+        if not obj_array.shape:
+            obj_array = obj_array.item(0)
+        elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
+            return obj_array  # various length, np.array([[1], [2, 3], [4, 5, 6]])
+        elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
             raise ValueError("Numpy arrays of tensors are not supported yet.")
-    return v
+    return obj_array


 def _create_value(
@@ -113,44 +115,45 @@ def _create_value(


 def _assert_type_keys(keys: Iterable[str]) -> None:
-    assert all(isinstance(e, str) for e in keys), \
+    assert all(isinstance(key, str) for key in keys), \
         f"keys should all be string, but got {keys}"


-def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
-    if isinstance(v, Batch):  # most often case
-        return v
-    elif (isinstance(v, np.ndarray) and
-          issubclass(v.dtype.type, (np.bool_, np.number))) or \
-            isinstance(v, torch.Tensor) or v is None:  # third often case
-        return v
-    elif _is_number(v):  # second often case, but it is more time-consuming
-        return np.asanyarray(v)
-    elif isinstance(v, dict):
-        return Batch(v)
+def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
+    if isinstance(obj, Batch):  # most often case
+        return obj
+    elif (isinstance(obj, np.ndarray) and
+          issubclass(obj.dtype.type, (np.bool_, np.number))) or \
+            isinstance(obj, torch.Tensor) or obj is None:  # third often case
+        return obj
+    elif _is_number(obj):  # second often case, but it is more time-consuming
+        return np.asanyarray(obj)
+    elif isinstance(obj, dict):
+        return Batch(obj)
     else:
-        if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
-                len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
+        if not isinstance(obj, np.ndarray) and \
+                isinstance(obj, Collection) and len(obj) > 0 and \
+                all(isinstance(element, torch.Tensor) for element in obj):
             try:
-                return torch.stack(v)  # type: ignore
-            except RuntimeError as e:
+                return torch.stack(obj)  # type: ignore
+            except RuntimeError as exception:
                 raise TypeError(
                     "Batch does not support non-stackable iterable"
                     " of torch.Tensor as unique value yet."
-                ) from e
-        if _is_batch_set(v):
-            v = Batch(v)  # list of dict / Batch
+                ) from exception
+        if _is_batch_set(obj):
+            obj = Batch(obj)  # list of dict / Batch
         else:
-            # None, scalar, normal data list (main case)
+            # None, scalar, normal obj list (main case)
             # or an actual list of objects
             try:
-                v = _to_array_with_correct_type(v)
-            except ValueError as e:
+                obj = _to_array_with_correct_type(obj)
+            except ValueError as exception:
                 raise TypeError(
                     "Batch does not support heterogeneous list/"
                     "tuple of tensors as unique value yet."
-                ) from e
-    return v
+                ) from exception
+    return obj


 def _alloc_by_keys_diff(
@@ -189,8 +192,8 @@ class Batch:
         if batch_dict is not None:
             if isinstance(batch_dict, (dict, Batch)):
                 _assert_type_keys(batch_dict.keys())
-                for k, v in batch_dict.items():
-                    self.__dict__[k] = _parse_value(v)
+                for batch_key, obj in batch_dict.items():
+                    self.__dict__[batch_key] = _parse_value(obj)
             elif _is_batch_set(batch_dict):
                 self.stack_(batch_dict)  # type: ignore
         if len(kwargs) > 0:
@@ -214,10 +217,10 @@ class Batch:
         Only the actual data are serialized for both efficiency and simplicity.
         """
         state = {}
-        for k, v in self.items():
-            if isinstance(v, Batch):
-                v = v.__getstate__()
-            state[k] = v
+        for batch_key, obj in self.items():
+            if isinstance(obj, Batch):
+                obj = obj.__getstate__()
+            state[batch_key] = obj
         return state

     def __setstate__(self, state: Dict[str, Any]) -> None:
@@ -234,13 +237,13 @@ class Batch:
             return self.__dict__[index]
         batch_items = self.items()
         if len(batch_items) > 0:
-            b = Batch()
-            for k, v in batch_items:
-                if isinstance(v, Batch) and v.is_empty():
-                    b.__dict__[k] = Batch()
+            new_batch = Batch()
+            for batch_key, obj in batch_items:
+                if isinstance(obj, Batch) and obj.is_empty():
+                    new_batch.__dict__[batch_key] = Batch()
                 else:
-                    b.__dict__[k] = v[index]
-            return b
+                    new_batch.__dict__[batch_key] = obj[index]
+            return new_batch
         else:
             raise IndexError("Cannot access item from empty Batch object.")
@@ -273,20 +276,20 @@ class Batch:
     def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
         """Algebraic addition with another Batch instance in-place."""
         if isinstance(other, Batch):
-            for (k, r), v in zip(
+            for (batch_key, obj), value in zip(
                 self.__dict__.items(), other.__dict__.values()
             ):  # TODO are keys consistent?
-                if isinstance(r, Batch) and r.is_empty():
+                if isinstance(obj, Batch) and obj.is_empty():
                     continue
                 else:
-                    self.__dict__[k] += v
+                    self.__dict__[batch_key] += value
             return self
         elif _is_number(other):
-            for k, r in self.items():
-                if isinstance(r, Batch) and r.is_empty():
+            for batch_key, obj in self.items():
+                if isinstance(obj, Batch) and obj.is_empty():
                     continue
                 else:
-                    self.__dict__[k] += other
+                    self.__dict__[batch_key] += other
             return self
         else:
             raise TypeError("Only addition of Batch or number is supported.")
@@ -295,54 +298,54 @@ class Batch:
         """Algebraic addition with another Batch instance out-of-place."""
         return deepcopy(self).__iadd__(other)

-    def __imul__(self, val: Union[Number, np.number]) -> "Batch":
+    def __imul__(self, value: Union[Number, np.number]) -> "Batch":
         """Algebraic multiplication with a scalar value in-place."""
-        assert _is_number(val), "Only multiplication by a number is supported."
-        for k, r in self.__dict__.items():
-            if isinstance(r, Batch) and r.is_empty():
+        assert _is_number(value), "Only multiplication by a number is supported."
+        for batch_key, obj in self.__dict__.items():
+            if isinstance(obj, Batch) and obj.is_empty():
                 continue
-            self.__dict__[k] *= val
+            self.__dict__[batch_key] *= value
         return self

-    def __mul__(self, val: Union[Number, np.number]) -> "Batch":
+    def __mul__(self, value: Union[Number, np.number]) -> "Batch":
         """Algebraic multiplication with a scalar value out-of-place."""
-        return deepcopy(self).__imul__(val)
+        return deepcopy(self).__imul__(value)

-    def __itruediv__(self, val: Union[Number, np.number]) -> "Batch":
+    def __itruediv__(self, value: Union[Number, np.number]) -> "Batch":
         """Algebraic division with a scalar value in-place."""
-        assert _is_number(val), "Only division by a number is supported."
-        for k, r in self.__dict__.items():
-            if isinstance(r, Batch) and r.is_empty():
+        assert _is_number(value), "Only division by a number is supported."
+        for batch_key, obj in self.__dict__.items():
+            if isinstance(obj, Batch) and obj.is_empty():
                 continue
-            self.__dict__[k] /= val
+            self.__dict__[batch_key] /= value
         return self

-    def __truediv__(self, val: Union[Number, np.number]) -> "Batch":
+    def __truediv__(self, value: Union[Number, np.number]) -> "Batch":
         """Algebraic division with a scalar value out-of-place."""
-        return deepcopy(self).__itruediv__(val)
+        return deepcopy(self).__itruediv__(value)

     def __repr__(self) -> str:
         """Return str(self)."""
-        s = self.__class__.__name__ + "(\n"
+        self_str = self.__class__.__name__ + "(\n"
         flag = False
-        for k, v in self.__dict__.items():
-            rpl = "\n" + " " * (6 + len(k))
-            obj = pprint.pformat(v).replace("\n", rpl)
-            s += f"    {k}: {obj},\n"
+        for batch_key, obj in self.__dict__.items():
+            rpl = "\n" + " " * (6 + len(batch_key))
+            obj_name = pprint.pformat(obj).replace("\n", rpl)
+            self_str += f"    {batch_key}: {obj_name},\n"
             flag = True
         if flag:
-            s += ")"
+            self_str += ")"
         else:
-            s = self.__class__.__name__ + "()"
-        return s
+            self_str = self.__class__.__name__ + "()"
+        return self_str

     def to_numpy(self) -> None:
         """Change all torch.Tensor to numpy.ndarray in-place."""
-        for k, v in self.items():
-            if isinstance(v, torch.Tensor):
-                self.__dict__[k] = v.detach().cpu().numpy()
-            elif isinstance(v, Batch):
-                v.to_numpy()
+        for batch_key, obj in self.items():
+            if isinstance(obj, torch.Tensor):
+                self.__dict__[batch_key] = obj.detach().cpu().numpy()
+            elif isinstance(obj, Batch):
+                obj.to_numpy()

     def to_torch(
         self,
|
||||
if not isinstance(device, torch.device):
|
||||
device = torch.device(device)
|
||||
|
||||
for k, v in self.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
if dtype is not None and v.dtype != dtype or \
|
||||
v.device.type != device.type or \
|
||||
device.index != v.device.index:
|
||||
for batch_key, obj in self.items():
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if dtype is not None and obj.dtype != dtype or \
|
||||
obj.device.type != device.type or \
|
||||
device.index != obj.device.index:
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v.to(device)
|
||||
elif isinstance(v, Batch):
|
||||
v.to_torch(dtype, device)
|
||||
obj = obj.type(dtype)
|
||||
self.__dict__[batch_key] = obj.to(device)
|
||||
elif isinstance(obj, Batch):
|
||||
obj.to_torch(dtype, device)
|
||||
else:
|
||||
# ndarray or scalar
|
||||
if not isinstance(v, np.ndarray):
|
||||
v = np.asanyarray(v)
|
||||
v = torch.from_numpy(v).to(device)
|
||||
if not isinstance(obj, np.ndarray):
|
||||
obj = np.asanyarray(obj)
|
||||
obj = torch.from_numpy(obj).to(device)
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v
|
||||
obj = obj.type(dtype)
|
||||
self.__dict__[batch_key] = obj
|
||||
|
||||
def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None:
|
||||
"""Private method for Batch.cat_.
|
||||
@@ -395,50 +398,51 @@ class Batch:
         # partial keys will be padded by zeros
         # with the shape of [len, rest_shape]
         sum_lens = [0]
-        for x in lens:
-            sum_lens.append(sum_lens[-1] + x)
+        for len_ in lens:
+            sum_lens.append(sum_lens[-1] + len_)
         # collect non-empty keys
         keys_map = [
             set(
-                k for k, v in batch.items()
-                if not (isinstance(v, Batch) and v.is_empty())
+                batch_key for batch_key, obj in batch.items()
+                if not (isinstance(obj, Batch) and obj.is_empty())
             ) for batch in batches
         ]
         keys_shared = set.intersection(*keys_map)
-        values_shared = [[e[k] for e in batches] for k in keys_shared]
-        for k, v in zip(keys_shared, values_shared):
-            if all(isinstance(e, (dict, Batch)) for e in v):
+        values_shared = [[batch[key] for batch in batches] for key in keys_shared]
+        for key, shared_value in zip(keys_shared, values_shared):
+            if all(isinstance(element, (dict, Batch)) for element in shared_value):
                 batch_holder = Batch()
-                batch_holder.__cat(v, lens=lens)
-                self.__dict__[k] = batch_holder
-            elif all(isinstance(e, torch.Tensor) for e in v):
-                self.__dict__[k] = torch.cat(v)
+                batch_holder.__cat(shared_value, lens=lens)
+                self.__dict__[key] = batch_holder
+            elif all(isinstance(element, torch.Tensor) for element in shared_value):
+                self.__dict__[key] = torch.cat(shared_value)
             else:
                 # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
                 # will fail here
-                v = np.concatenate(v)
-                self.__dict__[k] = _to_array_with_correct_type(v)
-        keys_total = set.union(*[set(b.keys()) for b in batches])
+                shared_value = np.concatenate(shared_value)
+                self.__dict__[key] = _to_array_with_correct_type(shared_value)
+        keys_total = set.union(*[set(batch.keys()) for batch in batches])
         keys_reserve_or_partial = set.difference(keys_total, keys_shared)
         # keys that are reserved in all batches
         keys_reserve = set.difference(keys_total, set.union(*keys_map))
         # keys that occur only in some batches, but not all
         keys_partial = keys_reserve_or_partial.difference(keys_reserve)
-        for k in keys_reserve:
+        for key in keys_reserve:
             # reserved keys
-            self.__dict__[k] = Batch()
-        for k in keys_partial:
-            for i, e in enumerate(batches):
-                if k not in e.__dict__:
+            self.__dict__[key] = Batch()
+        for key in keys_partial:
+            for i, batch in enumerate(batches):
+                if key not in batch.__dict__:
                     continue
-                val = e.get(k)
-                if isinstance(val, Batch) and val.is_empty():
+                value = batch.get(key)
+                if isinstance(value, Batch) and value.is_empty():
                     continue
                 try:
-                    self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
                 except KeyError:
-                    self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False)
-                    self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    self.__dict__[key] = \
+                        _create_value(value, sum_lens[-1], stack=False)
+                    self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value

     def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
         """Concatenate a list of (or one) Batch objects into current batch."""
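The zero-padding of partial keys mentioned at the top of this hunk is easiest to see on a tiny example (a sketch of the expected behaviour):

::

    from tianshou.data import Batch
    import numpy as np

    b1 = Batch(a=np.ones((2, 2)))                # key "b" missing here
    b2 = Batch(a=np.ones((3, 2)), b=np.ones(3))
    cat = Batch.cat([b1, b2])
    # shared key "a": concatenated to shape (5, 2)
    # partial key "b": zero-padded over b1's rows -> [0., 0., 1., 1., 1.]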
@@ -446,16 +450,16 @@ class Batch:
             batches = [batches]
         # check input format
         batch_list = []
-        for b in batches:
-            if isinstance(b, dict):
-                if len(b) > 0:
-                    batch_list.append(Batch(b))
-            elif isinstance(b, Batch):
+        for batch in batches:
+            if isinstance(batch, dict):
+                if len(batch) > 0:
+                    batch_list.append(Batch(batch))
+            elif isinstance(batch, Batch):
                 # x.is_empty() means that x is Batch() and should be ignored
-                if not b.is_empty():
-                    batch_list.append(b)
+                if not batch.is_empty():
+                    batch_list.append(batch)
             else:
-                raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
+                raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
         if len(batch_list) == 0:
             return
         batches = batch_list
@@ -463,13 +467,15 @@ class Batch:
             # x.is_empty(recurse=True) here means x is a nested empty batch
             # like Batch(a=Batch), and we have to treat it as length zero and
             # keep it.
-            lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches]
-        except TypeError as e:
+            lens = [
+                0 if batch.is_empty(recurse=True) else len(batch) for batch in batches
+            ]
+        except TypeError as exception:
             raise ValueError(
                 "Batch.cat_ meets an exception. Maybe because there is any "
                 f"scalar in {batches} but Batch.cat_ does not support the "
                 "concatenation of scalar."
-            ) from e
+            ) from exception
         if not self.is_empty():
             batches = [self] + list(batches)
             lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
@@ -501,16 +507,16 @@ class Batch:
         """Stack a list of Batch object into current batch."""
         # check input format
         batch_list = []
-        for b in batches:
-            if isinstance(b, dict):
-                if len(b) > 0:
-                    batch_list.append(Batch(b))
-            elif isinstance(b, Batch):
+        for batch in batches:
+            if isinstance(batch, dict):
+                if len(batch) > 0:
+                    batch_list.append(Batch(batch))
+            elif isinstance(batch, Batch):
                 # x.is_empty() means that x is Batch() and should be ignored
-                if not b.is_empty():
-                    batch_list.append(b)
+                if not batch.is_empty():
+                    batch_list.append(batch)
             else:
-                raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_")
+                raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
         if len(batch_list) == 0:
             return
         batches = batch_list
@@ -519,28 +525,31 @@ class Batch:
         # collect non-empty keys
         keys_map = [
             set(
-                k for k, v in batch.items()
-                if not (isinstance(v, Batch) and v.is_empty())
+                batch_key for batch_key, obj in batch.items()
+                if not (isinstance(obj, Batch) and obj.is_empty())
             ) for batch in batches
         ]
         keys_shared = set.intersection(*keys_map)
-        values_shared = [[e[k] for e in batches] for k in keys_shared]
-        for k, v in zip(keys_shared, values_shared):
-            if all(isinstance(e, torch.Tensor) for e in v):  # second often
-                self.__dict__[k] = torch.stack(v, axis)
-            elif all(isinstance(e, (Batch, dict)) for e in v):  # third often
-                self.__dict__[k] = Batch.stack(v, axis)
+        values_shared = [[batch[key] for batch in batches] for key in keys_shared]
+        for shared_key, value in zip(keys_shared, values_shared):
+            # second often
+            if all(isinstance(element, torch.Tensor) for element in value):
+                self.__dict__[shared_key] = torch.stack(value, axis)
+            # third often
+            elif all(isinstance(element, (Batch, dict)) for element in value):
+                self.__dict__[shared_key] = Batch.stack(value, axis)
             else:  # most often case is np.ndarray
                 try:
-                    self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
+                    self.__dict__[shared_key] = \
+                        _to_array_with_correct_type(np.stack(value, axis))
                 except ValueError:
                     warnings.warn(
                         "You are using tensors with different shape,"
                         " fallback to dtype=object by default."
                     )
-                    self.__dict__[k] = np.array(v, dtype=object)
+                    self.__dict__[shared_key] = np.array(value, dtype=object)
         # all the keys
-        keys_total = set.union(*[set(b.keys()) for b in batches])
+        keys_total = set.union(*[set(batch.keys()) for batch in batches])
         # keys that are reserved in all batches
         keys_reserve = set.difference(keys_total, set.union(*keys_map))
         # keys that are either partial or reserved
@@ -552,21 +561,21 @@ class Batch:
                 f"Stack of Batch with non-shared keys {keys_partial} is only "
                 f"supported with axis=0, but got axis={axis}!"
             )
-        for k in keys_reserve:
+        for key in keys_reserve:
             # reserved keys
-            self.__dict__[k] = Batch()
-        for k in keys_partial:
-            for i, e in enumerate(batches):
-                if k not in e.__dict__:
+            self.__dict__[key] = Batch()
+        for key in keys_partial:
+            for i, batch in enumerate(batches):
+                if key not in batch.__dict__:
                     continue
-                val = e.get(k)
-                if isinstance(val, Batch) and val.is_empty():
-                    continue
+                value = batch.get(key)
+                if isinstance(value, Batch) and value.is_empty():  # type: ignore
+                    continue  # type: ignore
                 try:
-                    self.__dict__[k][i] = val
+                    self.__dict__[key][i] = value
                 except KeyError:
-                    self.__dict__[k] = _create_value(val, len(batches))
-                    self.__dict__[k][i] = val
+                    self.__dict__[key] = _create_value(value, len(batches))
+                    self.__dict__[key][i] = value

     @staticmethod
     def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch":
@@ -620,27 +629,27 @@ class Batch:
             ),
         )
         """
-        for k, v in self.items():
-            if isinstance(v, torch.Tensor):  # most often case
-                self.__dict__[k][index] = 0
-            elif v is None:
+        for batch_key, obj in self.items():
+            if isinstance(obj, torch.Tensor):  # most often case
+                self.__dict__[batch_key][index] = 0
+            elif obj is None:
                 continue
-            elif isinstance(v, np.ndarray):
-                if v.dtype == object:
-                    self.__dict__[k][index] = None
+            elif isinstance(obj, np.ndarray):
+                if obj.dtype == object:
+                    self.__dict__[batch_key][index] = None
                 else:
-                    self.__dict__[k][index] = 0
-            elif isinstance(v, Batch):
-                self.__dict__[k].empty_(index=index)
+                    self.__dict__[batch_key][index] = 0
+            elif isinstance(obj, Batch):
+                self.__dict__[batch_key].empty_(index=index)
             else:  # scalar value
                 warnings.warn(
                     "You are calling Batch.empty on a NumPy scalar, "
                     "which may cause undefined behaviors."
                 )
-                if _is_number(v):
-                    self.__dict__[k] = v.__class__(0)
+                if _is_number(obj):
+                    self.__dict__[batch_key] = obj.__class__(0)
                 else:
-                    self.__dict__[k] = None
+                    self.__dict__[batch_key] = None
         return self

     @staticmethod
@@ -658,26 +667,26 @@ class Batch:
         if batch is None:
             self.update(kwargs)
             return
-        for k, v in batch.items():
-            self.__dict__[k] = _parse_value(v)
+        for batch_key, obj in batch.items():
+            self.__dict__[batch_key] = _parse_value(obj)
         if kwargs:
             self.update(kwargs)

     def __len__(self) -> int:
         """Return len(self)."""
-        r = []
-        for v in self.__dict__.values():
-            if isinstance(v, Batch) and v.is_empty(recurse=True):
+        lens = []
+        for obj in self.__dict__.values():
+            if isinstance(obj, Batch) and obj.is_empty(recurse=True):
                 continue
-            elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0):
-                r.append(len(v))
+            elif hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
+                lens.append(len(obj))
             else:
-                raise TypeError(f"Object {v} in {self} has no len()")
-        if len(r) == 0:
+                raise TypeError(f"Object {obj} in {self} has no len()")
+        if len(lens) == 0:
             # empty batch has the shape of any, like the tensorflow '?' shape.
             # So it has no length.
             raise TypeError(f"Object {self} has no len()")
-        return min(r)
+        return min(lens)

     def is_empty(self, recurse: bool = False) -> bool:
         """Test if a Batch is empty.
@@ -710,8 +719,8 @@ class Batch:
         if not recurse:
             return False
         return all(
-            False if not isinstance(x, Batch) else x.is_empty(recurse=True)
-            for x in self.values()
+            False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
+            for obj in self.values()
         )

     @property
@@ -721,9 +730,9 @@ class Batch:
             return []
         else:
             data_shape = []
-            for v in self.__dict__.values():
+            for obj in self.__dict__.values():
                 try:
-                    data_shape.append(list(v.shape))
+                    data_shape.append(list(obj.shape))
                 except AttributeError:
                     data_shape.append([])
             return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
@@ -69,8 +69,8 @@ class ReplayBuffer:
         """Return self.key."""
         try:
             return self._meta[key]
-        except KeyError as e:
-            raise AttributeError from e
+        except KeyError as exception:
+            raise AttributeError from exception

     def __setstate__(self, state: Dict[str, Any]) -> None:
         """Unpickling interface.
@@ -198,10 +198,10 @@ class ReplayBuffer:
         episode_reward is 0.
         """
         # preprocess batch
-        b = Batch()
+        new_batch = Batch()
         for key in set(self._reserved_keys).intersection(batch.keys()):
-            b.__dict__[key] = batch[key]
-        batch = b
+            new_batch.__dict__[key] = batch[key]
+        batch = new_batch
         assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
         stacked_batch = buffer_ids is not None
         if stacked_batch:
@@ -315,9 +315,9 @@ class ReplayBuffer:
                 return Batch.stack(stack, axis=indices.ndim)
             else:
                 return np.stack(stack, axis=indices.ndim)
-        except IndexError as e:
+        except IndexError as exception:
             if not (isinstance(val, Batch) and val.is_empty()):
-                raise e  # val != Batch()
+                raise exception  # val != Batch()
             return Batch()

     def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
@@ -114,10 +114,10 @@ class ReplayBufferManager(ReplayBuffer):
         episode_reward is 0.
         """
         # preprocess batch
-        b = Batch()
+        new_batch = Batch()
         for key in set(self._reserved_keys).intersection(batch.keys()):
-            b.__dict__[key] = batch[key]
-        batch = b
+            new_batch.__dict__[key] = batch[key]
+        batch = new_batch
         assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
         if self._save_only_last_obs:
             batch.obs = batch.obs[:, -1]
@@ -111,11 +111,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
             # and possibly in other cases like structured arrays.
             try:
                 to_hdf5_via_pickle(v, y, k)
-            except Exception as e:
+            except Exception as exception:
                 raise RuntimeError(
                     f"Attempted to pickle {v.__class__.__name__} due to "
                     "data type not supported by HDF5 and failed."
-                ) from e
+                ) from exception
             y[k].attrs["__data_type__"] = "pickled_ndarray"
         elif isinstance(v, (int, float)):
             # ints and floats are stored as attributes of groups
@@ -123,11 +123,11 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
         else:  # resort to pickle for any other type of object
             try:
                 to_hdf5_via_pickle(v, y, k)
-            except Exception as e:
+            except Exception as exception:
                 raise NotImplementedError(
                     f"No conversion to HDF5 for object of type '{type(v)}' "
                     "implemented and fallback to pickle failed."
-                ) from e
+                ) from exception
             y[k].attrs["__data_type__"] = v.__class__.__name__
tianshou/env/maenv.py
@@ -15,8 +15,8 @@ class MultiAgentEnv(ABC, gym.Env):
         env = MultiAgentEnv(...)
         # obs is a dict containing obs, agent_id, and mask
         obs = env.reset()
-        action = policy(obs)
-        obs, rew, done, info = env.step(action)
+        act = policy(obs)
+        obs, rew, done, info = env.step(act)
         env.close()

     The available action's mask is set to 1, otherwise it is set to 0. Further
tianshou/env/venvs.py
@@ -412,10 +412,10 @@ class RayVectorEnv(BaseVectorEnv):
     def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
         try:
             import ray
-        except ImportError as e:
+        except ImportError as exception:
             raise ImportError(
                 "Please install ray to support RayVectorEnv: pip install ray"
-            ) from e
+            ) from exception
         if not ray.is_initialized():
             ray.init()
         super().__init__(env_fns, RayEnvWorker, **kwargs)
@@ -98,6 +98,12 @@ class BasePolicy(ABC, nn.Module):
         """
         return act

+    def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
+        """Softly update the parameters of target module towards the parameters \
+        of source module."""
+        for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
+            tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
+
     @abstractmethod
     def forward(
         self,
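The new ``soft_update`` helper is the usual Polyak averaging step :math:`\theta_{\text{tgt}} \leftarrow \tau\,\theta_{\text{src}} + (1 - \tau)\,\theta_{\text{tgt}}`. Each subclass's ``sync_weight`` then collapses to one call per target network, e.g. (a sketch mirroring the DDPG change further down):

::

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.actor_old, self.actor, self.tau)
        self.soft_update(self.critic_old, self.critic, self.tau)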
@@ -387,10 +393,10 @@ def _gae_return(
 ) -> np.ndarray:
     returns = np.zeros(rew.shape)
     delta = rew + v_s_ * gamma - v_s
-    m = (1.0 - end_flag) * (gamma * gae_lambda)
+    discount = (1.0 - end_flag) * (gamma * gae_lambda)
     gae = 0.0
     for i in range(len(rew) - 1, -1, -1):
-        gae = delta[i] + m[i] * gae
+        gae = delta[i] + discount[i] * gae
         returns[i] = gae
     return returns
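For reference, this helper implements the standard GAE backward recursion

.. math::

    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
    \hat{A}_t = \delta_t + (1 - d_t)\,\gamma\lambda\,\hat{A}_{t+1},

where :math:`d_t` is ``end_flag[t]``; the renamed ``discount`` is exactly the per-step factor :math:`(1 - d_t)\gamma\lambda`.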
@@ -40,23 +40,23 @@ class ImitationPolicy(BasePolicy):
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         **kwargs: Any,
     ) -> Batch:
-        logits, h = self.model(batch.obs, state=state, info=batch.info)
+        logits, hidden = self.model(batch.obs, state=state, info=batch.info)
         if self.action_type == "discrete":
-            a = logits.max(dim=1)[1]
+            act = logits.max(dim=1)[1]
         else:
-            a = logits
-        return Batch(logits=logits, act=a, state=h)
+            act = logits
+        return Batch(logits=logits, act=act, state=hidden)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
         if self.action_type == "continuous":  # regression
-            a = self(batch).act
-            a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
-            loss = F.mse_loss(a, a_)  # type: ignore
+            act = self(batch).act
+            act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
+            loss = F.mse_loss(act, act_target)  # type: ignore
         elif self.action_type == "discrete":  # classification
-            a = F.log_softmax(self(batch).logits, dim=-1)
-            a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
-            loss = F.nll_loss(a, a_)  # type: ignore
+            act = F.log_softmax(self(batch).logits, dim=-1)
+            act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
+            loss = F.nll_loss(act, act_target)  # type: ignore
         loss.backward()
         self.optim.step()
         return {"loss": loss.item()}
@@ -105,32 +105,27 @@ class BCQPolicy(BasePolicy):
         obs_group: torch.Tensor = to_torch(  # type: ignore
             batch.obs, device=self.device
         )
-        act = []
+        act_group = []
         for obs in obs_group:
             # now obs is (state_dim)
             obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
             # now obs is (forward_sampled_times, state_dim)

             # decode(obs) generates action and actor perturbs it
-            action = self.actor(obs, self.vae.decode(obs))
+            act = self.actor(obs, self.vae.decode(obs))
             # now action is (forward_sampled_times, action_dim)
-            q1 = self.critic1(obs, action)
+            q1 = self.critic1(obs, act)
             # q1 is (forward_sampled_times, 1)
-            ind = q1.argmax(0)
-            act.append(action[ind].cpu().data.numpy().flatten())
-        act = np.array(act)
-        return Batch(act=act)
+            max_indice = q1.argmax(0)
+            act_group.append(act[max_indice].cpu().data.numpy().flatten())
+        act_group = np.array(act_group)
+        return Batch(act=act_group)

     def sync_weight(self) -> None:
         """Soft-update the weight for the target network."""
-        for net, net_target in [
-            [self.critic1, self.critic1_target], [self.critic2, self.critic2_target],
-            [self.actor, self.actor_target]
-        ]:
-            for param, target_param in zip(net.parameters(), net_target.parameters()):
-                target_param.data.copy_(
-                    self.tau * param.data + (1 - self.tau) * target_param.data
-                )
+        self.soft_update(self.critic1_target, self.critic1, self.tau)
+        self.soft_update(self.critic2_target, self.critic2, self.tau)
+        self.soft_update(self.actor_target, self.actor, self.tau)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         # batch: obs, act, rew, done, obs_next. (numpy array)
@@ -113,13 +113,8 @@ class CQLPolicy(SACPolicy):

     def sync_weight(self) -> None:
         """Soft-update the weight for the target network."""
-        for net, net_old in [
-            [self.critic1, self.critic1_old], [self.critic2, self.critic2_old]
-        ]:
-            for param, target_param in zip(net.parameters(), net_old.parameters()):
-                target_param.data.copy_(
-                    self._tau * param.data + (1 - self._tau) * target_param.data
-                )
+        self.soft_update(self.critic1_old, self.critic1, self.tau)
+        self.soft_update(self.critic2_old, self.critic2, self.tau)

     def actor_pred(self, obs: torch.Tensor) -> \
             Tuple[torch.Tensor, torch.Tensor]:
@@ -94,13 +94,10 @@ class DiscreteBCQPolicy(DQNPolicy):
         # mask actions for argmax
         ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
         mask = (ratio < self._log_tau).float()
-        action = (q_value - np.inf * mask).argmax(dim=-1)
+        act = (q_value - np.inf * mask).argmax(dim=-1)

         return Batch(
-            act=action,
-            state=state,
-            q_value=q_value,
-            imitation_logits=imitation_logits
+            act=act, state=state, q_value=q_value, imitation_logits=imitation_logits
         )

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -57,15 +57,15 @@ class DiscreteCQLPolicy(QRDQNPolicy):
         curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u * (self.tau_hat -
-                 (target_dist - curr_dist).detach().le(0.).float()).abs()
+            dist_diff *
+            (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
         ).sum(-1).mean(1)
         qr_loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         # add CQL loss
         q = self.compute_q_value(all_dist, None)
         dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
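For reference, ``dist_diff`` is the elementwise Huber distance :math:`L_\kappa(\theta'_j - \theta_i)` between target and current quantiles, and the term being assembled is the quantile-regression Huber loss of QR-DQN,

.. math::

    \mathcal{L}_{QR} = \mathbb{E}\Big[\,\big|\hat{\tau}_i - \mathbb{1}\{\theta'_j - \theta_i < 0\}\big| \, L_\kappa(\theta'_j - \theta_i)\Big],

with :math:`\hat{\tau}_i` the quantile midpoints stored in ``self.tau_hat``.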
@@ -97,9 +97,9 @@ class DiscreteCRRPolicy(PGPolicy):
         target = rew.unsqueeze(1) + self._gamma * expected_target_q
         critic_loss = 0.5 * F.mse_loss(qa_t, target)
         # Actor loss
-        a_t, _ = self.actor(batch.obs)
-        m = Categorical(logits=a_t)
-        expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True)
+        act_target, _ = self.actor(batch.obs)
+        dist = Categorical(logits=act_target)
+        expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True)
         advantage = qa_t - expected_policy_q
         if self._policy_improvement_mode == "binary":
             actor_loss_coef = (advantage > 0).float()
@@ -109,7 +109,7 @@ class DiscreteCRRPolicy(PGPolicy):
             )
         else:
             actor_loss_coef = 1.0  # effectively behavior cloning
-        actor_loss = (-m.log_prob(act) * actor_loss_coef).mean()
+        actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean()
         # CQL loss/regularizer
         min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
         loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss
@@ -202,13 +202,13 @@ class PSRLPolicy(BasePolicy):
         rew_sum = np.zeros((n_s, n_a))
         rew_square_sum = np.zeros((n_s, n_a))
         rew_count = np.zeros((n_s, n_a))
-        for b in batch.split(size=1):
-            obs, act, obs_next = b.obs, b.act, b.obs_next
+        for minibatch in batch.split(size=1):
+            obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
             trans_count[obs, act, obs_next] += 1
-            rew_sum[obs, act] += b.rew
-            rew_square_sum[obs, act] += b.rew**2
+            rew_sum[obs, act] += minibatch.rew
+            rew_square_sum[obs, act] += minibatch.rew**2
             rew_count[obs, act] += 1
-            if self._add_done_loop and b.done:
+            if self._add_done_loop and minibatch.done:
                 # special operation for terminal states: add a self-loop
                 trans_count[obs_next, :, obs_next] += 1
                 rew_count[obs_next, :] += 1
@@ -85,9 +85,9 @@ class A2CPolicy(PGPolicy):
     ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
+            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                v_s.append(self.critic(minibatch.obs))
+                v_s_.append(self.critic(minibatch.obs_next))
         batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
         v_s = batch.v_s.cpu().numpy()
         v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
@@ -122,14 +122,15 @@ class A2CPolicy(PGPolicy):
     ) -> Dict[str, List[float]]:
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
-                dist = self(b).dist
-                log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
-                actor_loss = -(log_prob * b.adv).mean()
+                dist = self(minibatch).dist
+                log_prob = dist.log_prob(minibatch.act)
+                log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
+                actor_loss = -(log_prob * minibatch.adv).mean()
                 # calculate loss for critic
-                value = self.critic(b.obs).flatten()
-                vf_loss = F.mse_loss(b.returns, value)
+                value = self.critic(minibatch.obs).flatten()
+                vf_loss = F.mse_loss(minibatch.returns, value)
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = actor_loss + self._weight_vf * vf_loss \
@@ -70,13 +70,13 @@ class C51Policy(DQNPolicy):

     def _target_dist(self, batch: Batch) -> torch.Tensor:
         if self._target:
-            a = self(batch, input="obs_next").act
+            act = self(batch, input="obs_next").act
             next_dist = self(batch, model="model_old", input="obs_next").logits
         else:
-            next_b = self(batch, input="obs_next")
-            a = next_b.act
-            next_dist = next_b.logits
-        next_dist = next_dist[np.arange(len(a)), a, :]
+            next_batch = self(batch, input="obs_next")
+            act = next_batch.act
+            next_dist = next_batch.logits
+        next_dist = next_dist[np.arange(len(act)), act, :]
         target_support = batch.returns.clamp(self._v_min, self._v_max)
         # An amazing trick for calculating the projection gracefully.
         # ref: https://github.com/ShangtongZhang/DeepRL
@@ -73,7 +73,7 @@ class DDPGPolicy(BasePolicy):
         self.critic_old.eval()
         self.critic_optim: torch.optim.Optimizer = critic_optim
         assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
-        self._tau = tau
+        self.tau = tau
         assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
         self._gamma = gamma
         self._noise = exploration_noise
@@ -95,10 +95,8 @@ class DDPGPolicy(BasePolicy):

     def sync_weight(self) -> None:
         """Soft-update the weight for the target network."""
-        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic_old.parameters(), self.critic.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
+        self.soft_update(self.actor_old, self.actor, self.tau)
+        self.soft_update(self.critic_old, self.critic, self.tau)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs_next: s_{t+n}
@@ -139,8 +137,8 @@ class DDPGPolicy(BasePolicy):
         """
         model = getattr(self, model)
         obs = batch[input]
-        actions, h = model(obs, state=state, info=batch.info)
-        return Batch(act=actions, state=h)
+        actions, hidden = model(obs, state=state, info=batch.info)
+        return Batch(act=actions, state=hidden)

     @staticmethod
     def _mse_optimizer(
@@ -163,8 +161,7 @@ class DDPGPolicy(BasePolicy):
         td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         batch.weight = td  # prio-buffer
         # actor
-        action = self(batch).act
-        actor_loss = -self.critic(batch.obs, action).mean()
+        actor_loss = -self.critic(batch.obs, self(batch).act).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
|
@ -76,10 +76,10 @@ class DiscreteSACPolicy(SACPolicy):
         **kwargs: Any,
     ) -> Batch:
         obs = batch[input]
-        logits, h = self.actor(obs, state=state, info=batch.info)
+        logits, hidden = self.actor(obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+        return Batch(logits=logits, act=act, state=hidden, dist=dist)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs: s_{t+n}
@ -151,13 +151,13 @@ class DQNPolicy(BasePolicy):
         """
         model = getattr(self, model)
         obs = batch[input]
-        obs_ = obs.obs if hasattr(obs, "obs") else obs
-        logits, h = model(obs_, state=state, info=batch.info)
+        obs_next = obs.obs if hasattr(obs, "obs") else obs
+        logits, hidden = model(obs_next, state=state, info=batch.info)
         q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
         act = to_numpy(q.max(dim=1)[1])
-        return Batch(logits=logits, act=act, state=h)
+        return Batch(logits=logits, act=act, state=hidden)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
@ -166,10 +166,10 @@ class DQNPolicy(BasePolicy):
         weight = batch.pop("weight", 1.0)
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch_as(batch.returns.flatten(), q)
-        td = r - q
-        loss = (td.pow(2) * weight).mean()
-        batch.weight = td  # prio-buffer
+        returns = to_torch_as(batch.returns.flatten(), q)
+        td_error = returns - q
+        loss = (td_error.pow(2) * weight).mean()
+        batch.weight = td_error  # prio-buffer
         loss.backward()
         self.optim.step()
         self._iter += 1
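The ``weight`` popped here is the importance-sampling weight supplied by a prioritized replay buffer, and the TD error written back through ``batch.weight`` becomes the sample's new priority. A minimal sketch of that loop (simplified buffer API; the priority-update method name is illustrative, not tianshou's exact signature)::

    batch, indices = buffer.sample(batch_size)  # batch.weight ~ (N * P(i)) ** -beta
    td_error = returns - q                      # per-sample TD error, as above
    loss = (td_error.pow(2) * weight).mean()    # weights undo the sampling bias
    # write |td_error| back as the new priority (hypothetical method name)
    buffer.update_priority(indices, td_error.abs().detach())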
@ -60,15 +60,15 @@ class FQFPolicy(QRDQNPolicy):
         batch = buffer[indices]  # batch.obs_next: s_{t+n}
         if self._target:
             result = self(batch, input="obs_next")
-            a, fractions = result.act, result.fractions
+            act, fractions = result.act, result.fractions
             next_dist = self(
                 batch, model="model_old", input="obs_next", fractions=fractions
             ).logits
         else:
-            next_b = self(batch, input="obs_next")
-            a = next_b.act
-            next_dist = next_b.logits
-        next_dist = next_dist[np.arange(len(a)), a, :]
+            next_batch = self(batch, input="obs_next")
+            act = next_batch.act
+            next_dist = next_batch.logits
+        next_dist = next_dist[np.arange(len(act)), act, :]
         return next_dist  # shape: [bsz, num_quantiles]
 
     def forward(
@ -82,14 +82,17 @@ class FQFPolicy(QRDQNPolicy):
     ) -> Batch:
         model = getattr(self, model)
         obs = batch[input]
-        obs_ = obs.obs if hasattr(obs, "obs") else obs
+        obs_next = obs.obs if hasattr(obs, "obs") else obs
         if fractions is None:
-            (logits, fractions, quantiles_tau), h = model(
-                obs_, propose_model=self.propose_model, state=state, info=batch.info
+            (logits, fractions, quantiles_tau), hidden = model(
+                obs_next,
+                propose_model=self.propose_model,
+                state=state,
+                info=batch.info
             )
         else:
-            (logits, _, quantiles_tau), h = model(
-                obs_,
+            (logits, _, quantiles_tau), hidden = model(
+                obs_next,
                 propose_model=self.propose_model,
                 fractions=fractions,
                 state=state,
@ -106,7 +109,7 @@ class FQFPolicy(QRDQNPolicy):
         return Batch(
             logits=logits,
             act=act,
-            state=h,
+            state=hidden,
             fractions=fractions,
             quantiles_tau=quantiles_tau
         )
@ -122,9 +125,9 @@ class FQFPolicy(QRDQNPolicy):
         curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u * (
+            dist_diff * (
                 tau_hats.unsqueeze(2) -
                 (target_dist - curr_dist).detach().le(0.).float()
             ).abs()
@ -132,7 +135,7 @@ class FQFPolicy(QRDQNPolicy):
         quantile_loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         # calculate fraction loss
         with torch.no_grad():
             sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
@ -73,36 +73,37 @@ class IQNPolicy(QRDQNPolicy):
         sample_size = self._sample_size
         model = getattr(self, model)
         obs = batch[input]
-        obs_ = obs.obs if hasattr(obs, "obs") else obs
-        (logits,
-         taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info)
+        obs_next = obs.obs if hasattr(obs, "obs") else obs
+        (logits, taus), hidden = model(
+            obs_next, sample_size=sample_size, state=state, info=batch.info
+        )
         q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
         act = to_numpy(q.max(dim=1)[1])
-        return Batch(logits=logits, act=act, state=h, taus=taus)
+        return Batch(logits=logits, act=act, state=hidden, taus=taus)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
         weight = batch.pop("weight", 1.0)
-        out = self(batch)
-        curr_dist, taus = out.logits, out.taus
+        action_batch = self(batch)
+        curr_dist, taus = action_batch.logits, action_batch.taus
         act = batch.act
         curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u *
+            dist_diff *
             (taus.unsqueeze(2) -
              (target_dist - curr_dist).detach().le(0.).float()).abs()
         ).sum(-1).mean(1)
         loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         loss.backward()
         self.optim.step()
         self._iter += 1
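``dist_diff`` renames the ``u`` of the quantile-regression Huber loss shared by QRDQN, IQN and FQF: an elementwise smooth-L1 distance between predicted and target quantiles, weighted asymmetrically by the quantile fractions. Stated standalone (hypothetical helper; shapes as in the comments above)::

    import torch.nn.functional as F

    def quantile_huber_loss(curr_dist, target_dist, taus):
        # curr_dist: [bsz, N, 1], target_dist: [bsz, 1, M], taus: [bsz, N, 1]
        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        # asymmetric weight |tau - 1{target <= current}|: over- and
        # under-estimation are penalised differently per quantile
        sign = (target_dist - curr_dist).detach().le(0.).float()
        return (dist_diff * (taus - sign).abs()).sum(-1).mean(1)  # [bsz]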
@ -71,8 +71,8 @@ class NPGPolicy(A2CPolicy):
         batch = super().process_fn(batch, buffer, indices)
         old_log_prob = []
         with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
+            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
         if self._norm_adv:
             batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
@ -83,20 +83,20 @@ class NPGPolicy(A2CPolicy):
     ) -> Dict[str, List[float]]:
         actor_losses, vf_losses, kls = [], [], []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate vanilla gradient
-                dist = self(b).dist
-                log_prob = dist.log_prob(b.act)
+                dist = self(minibatch).dist
+                log_prob = dist.log_prob(minibatch.act)
                 log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
-                actor_loss = -(log_prob * b.adv).mean()
+                actor_loss = -(log_prob * minibatch.adv).mean()
                 flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True
                 ).detach()
 
                 # direction: calculate natural gradient
                 with torch.no_grad():
-                    old_dist = self(b).dist
+                    old_dist = self(minibatch).dist
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
@ -112,13 +112,13 @@ class NPGPolicy(A2CPolicy):
                 )
                 new_flat_params = flat_params + self._step_size * search_direction
                 self._set_from_flat_params(self.actor, new_flat_params)
-                new_dist = self(b).dist
+                new_dist = self(minibatch).dist
                 kl = kl_divergence(old_dist, new_dist).mean()
 
                 # optimize critic
                 for _ in range(self._optim_critic_iters):
-                    value = self.critic(b.obs).flatten()
-                    vf_loss = F.mse_loss(b.returns, value)
+                    value = self.critic(minibatch.obs).flatten()
+                    vf_loss = F.mse_loss(minibatch.returns, value)
                     self.optim.zero_grad()
                     vf_loss.backward()
                     self.optim.step()
@ -147,14 +147,14 @@ class NPGPolicy(A2CPolicy):
 
     def _conjugate_gradients(
         self,
-        b: torch.Tensor,
+        minibatch: torch.Tensor,
         flat_kl_grad: torch.Tensor,
         nsteps: int = 10,
        residual_tol: float = 1e-10
     ) -> torch.Tensor:
-        x = torch.zeros_like(b)
-        r, p = b.clone(), b.clone()
-        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
+        x = torch.zeros_like(minibatch)
+        r, p = minibatch.clone(), minibatch.clone()
+        # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0.
         # Change if doing warm start.
         rdotr = r.dot(r)
         for _ in range(nsteps):
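``_conjugate_gradients`` solves ``Hx = g`` for the natural-gradient direction without ever materialising the Fisher matrix ``H``, using only matrix-vector products. The classic loop it begins here, sketched standalone (``mvp`` stands for the Hessian-vector-product closure)::

    import torch

    def conjugate_gradients(mvp, g, nsteps=10, residual_tol=1e-10):
        # solve H x = g, where mvp(v) returns H @ v and H is symmetric PD
        x = torch.zeros_like(g)
        r, p = g.clone(), g.clone()          # residual and search direction
        rdotr = r.dot(r)
        for _ in range(nsteps):
            z = mvp(p)
            alpha = rdotr / p.dot(z)         # exact step size along p
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + (new_rdotr / rdotr) * p  # next conjugate direction
            rdotr = new_rdotr
        return x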
@ -107,7 +107,7 @@ class PGPolicy(BasePolicy):
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
         more detailed explanation.
         """
-        logits, h = self.actor(batch.obs, state=state)
+        logits, hidden = self.actor(batch.obs, state=state)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@ -119,20 +119,20 @@ class PGPolicy(BasePolicy):
             act = logits[0]
         else:
             act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+        return Batch(logits=logits, act=act, state=hidden, dist=dist)
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses = []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 self.optim.zero_grad()
-                result = self(b)
+                result = self(minibatch)
                 dist = result.dist
-                a = to_torch_as(b.act, result.act)
-                ret = to_torch_as(b.returns, result.act)
-                log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1)
+                act = to_torch_as(minibatch.act, result.act)
+                ret = to_torch_as(minibatch.returns, result.act)
+                log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                 loss = -(log_prob * ret).mean()
                 loss.backward()
                 self.optim.step()
@ -96,8 +96,8 @@ class PPOPolicy(A2CPolicy):
         batch.act = to_torch_as(batch.act, batch.v_s)
         old_log_prob = []
         with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
+            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
         return batch
@ -108,32 +108,35 @@ class PPOPolicy(A2CPolicy):
         for step in range(repeat):
             if self._recompute_adv and step > 0:
                 batch = self._compute_returns(batch, self._buffer, self._indices)
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
-                dist = self(b).dist
+                dist = self(minibatch).dist
                 if self._norm_adv:
-                    mean, std = b.adv.mean(), b.adv.std()
-                    b.adv = (b.adv - mean) / std  # per-batch norm
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
+                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                ratio = (dist.log_prob(minibatch.act) -
+                         minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
-                surr1 = ratio * b.adv
-                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
+                surr1 = ratio * minibatch.adv
+                surr2 = ratio.clamp(
+                    1.0 - self._eps_clip, 1.0 + self._eps_clip
+                ) * minibatch.adv
                 if self._dual_clip:
                     clip1 = torch.min(surr1, surr2)
-                    clip2 = torch.max(clip1, self._dual_clip * b.adv)
-                    clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
+                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                 else:
                     clip_loss = -torch.min(surr1, surr2).mean()
                 # calculate loss for critic
-                value = self.critic(b.obs).flatten()
+                value = self.critic(minibatch.obs).flatten()
                 if self._value_clip:
-                    v_clip = b.v_s + (value -
-                                      b.v_s).clamp(-self._eps_clip, self._eps_clip)
-                    vf1 = (b.returns - value).pow(2)
-                    vf2 = (b.returns - v_clip).pow(2)
+                    v_clip = minibatch.v_s + \
+                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+                    vf1 = (minibatch.returns - value).pow(2)
+                    vf2 = (minibatch.returns - v_clip).pow(2)
                     vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = (b.returns - value).pow(2).mean()
+                    vf_loss = (minibatch.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
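The ``_dual_clip`` branch above is the dual-clip PPO variant: for negative advantages the clipped surrogate is additionally bounded below by ``c * adv`` (with ``c > 1``), so a single badly off-policy sample cannot dominate the update. In isolation::

    # clip1: standard PPO clipped surrogate
    clip1 = torch.min(ratio * adv, ratio.clamp(1 - eps_clip, 1 + eps_clip) * adv)
    # clip2: extra lower bound, only applied where the advantage is negative
    clip2 = torch.max(clip1, c * adv)
    clip_loss = -torch.where(adv < 0, clip2, clip1).mean()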
@ -56,13 +56,13 @@ class QRDQNPolicy(DQNPolicy):
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs_next: s_{t+n}
         if self._target:
-            a = self(batch, input="obs_next").act
+            act = self(batch, input="obs_next").act
             next_dist = self(batch, model="model_old", input="obs_next").logits
         else:
-            next_b = self(batch, input="obs_next")
-            a = next_b.act
-            next_dist = next_b.logits
-        next_dist = next_dist[np.arange(len(a)), a, :]
+            next_batch = self(batch, input="obs_next")
+            act = next_batch.act
+            next_dist = next_batch.logits
+        next_dist = next_dist[np.arange(len(act)), act, :]
         return next_dist  # shape: [bsz, num_quantiles]
 
     def compute_q_value(
@ -80,15 +80,15 @@ class QRDQNPolicy(DQNPolicy):
         curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
         target_dist = batch.returns.unsqueeze(1)
         # calculate each element's difference between curr_dist and target_dist
-        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
         huber_loss = (
-            u * (self.tau_hat -
-                 (target_dist - curr_dist).detach().le(0.).float()).abs()
+            dist_diff *
+            (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
         ).sum(-1).mean(1)
         loss = (huber_loss * weight).mean()
         # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
         # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
-        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
+        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
         loss.backward()
         self.optim.step()
         self._iter += 1
@ -99,10 +99,8 @@ class SACPolicy(DDPGPolicy):
         return self
 
     def sync_weight(self) -> None:
-        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
+        self.soft_update(self.critic1_old, self.critic1, self.tau)
+        self.soft_update(self.critic2_old, self.critic2, self.tau)
 
     def forward(  # type: ignore
         self,
@ -112,7 +110,7 @@ class SACPolicy(DDPGPolicy):
         **kwargs: Any,
     ) -> Batch:
         obs = batch[input]
-        logits, h = self.actor(obs, state=state, info=batch.info)
+        logits, hidden = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self._deterministic_eval and not self.training:
@ -134,16 +132,20 @@ class SACPolicy(DDPGPolicy):
             action_scale * (1 - squashed_action.pow(2)) + self.__eps
         ).sum(-1, keepdim=True)
         return Batch(
-            logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
+            logits=logits,
+            act=squashed_action,
+            state=hidden,
+            dist=dist,
+            log_prob=log_prob
         )
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs: s_{t+n}
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
+        obs_next_result = self(batch, input="obs_next")
+        act_ = obs_next_result.act
         target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
+            self.critic1_old(batch.obs_next, act_),
+            self.critic2_old(batch.obs_next, act_),
         ) - self._alpha * obs_next_result.log_prob
         return target_q
@ -159,9 +161,9 @@ class SACPolicy(DDPGPolicy):
 
         # actor
         obs_result = self(batch)
-        a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).flatten()
-        current_q2a = self.critic2(batch.obs, a).flatten()
+        act = obs_result.act
+        current_q1a = self.critic1(batch.obs, act).flatten()
+        current_q2a = self.critic2(batch.obs, act).flatten()
         actor_loss = (
             self._alpha * obs_result.log_prob.flatten() -
             torch.min(current_q1a, current_q2a)
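The hunk ends inside the expression; SAC's entropy-regularised actor objective presumably closes by averaging over the batch, roughly (a sketch of the continuation, not quoted from the commit)::

    actor_loss = (
        self._alpha * obs_result.log_prob.flatten() -  # temperature * log pi
        torch.min(current_q1a, current_q2a)            # pessimistic twin-critic value
    ).mean()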
@ -89,23 +89,20 @@ class TD3Policy(DDPGPolicy):
         return self
 
     def sync_weight(self) -> None:
-        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
-        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
-            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
+        self.soft_update(self.critic1_old, self.critic1, self.tau)
+        self.soft_update(self.critic2_old, self.critic2, self.tau)
+        self.soft_update(self.actor_old, self.actor, self.tau)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
         batch = buffer[indices]  # batch.obs: s_{t+n}
-        a_ = self(batch, model="actor_old", input="obs_next").act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+        act_ = self(batch, model="actor_old", input="obs_next").act
+        noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
         if self._noise_clip > 0.0:
             noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
+        act_ += noise
         target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)
+            self.critic1_old(batch.obs_next, act_),
+            self.critic2_old(batch.obs_next, act_),
         )
         return target_q
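For context, these lines implement TD3's two target-side tricks: target policy smoothing (clipped Gaussian noise on the target action regularises the value estimate) and clipped double-Q (the minimum of two target critics counters overestimation bias). Reduced to its essentials::

    # smooth the target action with clipped noise
    noise = (torch.randn_like(act_) * policy_noise).clamp(-noise_clip, noise_clip)
    # evaluate both target critics and keep the pessimistic one
    target_q = torch.min(critic1_old(obs_next, act_ + noise),
                         critic2_old(obs_next, act_ + noise))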
@ -71,20 +71,21 @@ class TRPOPolicy(NPGPolicy):
     ) -> Dict[str, List[float]]:
         actor_losses, vf_losses, step_sizes, kls = [], [], [], []
         for _ in range(repeat):
-            for b in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate vanilla gradient
-                dist = self(b).dist  # TODO could come from batch
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                dist = self(minibatch).dist  # TODO could come from batch
+                ratio = (dist.log_prob(minibatch.act) -
+                         minibatch.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
-                actor_loss = -(ratio * b.adv).mean()
+                actor_loss = -(ratio * minibatch.adv).mean()
                 flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True
                 ).detach()
 
                 # direction: calculate natural gradient
                 with torch.no_grad():
-                    old_dist = self(b).dist
+                    old_dist = self(minibatch).dist
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
@ -109,12 +110,13 @@ class TRPOPolicy(NPGPolicy):
                     new_flat_params = flat_params + step_size * search_direction
                     self._set_from_flat_params(self.actor, new_flat_params)
                     # check that KL stays within the bound and the loss decreases
-                    new_dist = self(b).dist
-                    new_dratio = (new_dist.log_prob(b.act) -
-                                  b.logp_old).exp().float()
+                    new_dist = self(minibatch).dist
+                    new_dratio = (
+                        new_dist.log_prob(minibatch.act) - minibatch.logp_old
+                    ).exp().float()
                     new_dratio = new_dratio.reshape(new_dratio.size(0),
                                                     -1).transpose(0, 1)
-                    new_actor_loss = -(new_dratio * b.adv).mean()
+                    new_actor_loss = -(new_dratio * minibatch.adv).mean()
                     kl = kl_divergence(old_dist, new_dist).mean()
 
                     if kl < self._delta and new_actor_loss < actor_loss:
@ -133,8 +135,8 @@ class TRPOPolicy(NPGPolicy):
 
                 # optimize critic
                 for _ in range(self._optim_critic_iters):
-                    value = self.critic(b.obs).flatten()
-                    vf_loss = F.mse_loss(b.returns, value)
+                    value = self.critic(minibatch.obs).flatten()
+                    vf_loss = F.mse_loss(minibatch.returns, value)
                     self.optim.zero_grad()
                     vf_loss.backward()
                     self.optim.step()
@ -87,14 +87,14 @@ class MLP(nn.Module):
         self.output_dim = output_dim or hidden_sizes[-1]
         self.model = nn.Sequential(*model)
 
-    def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+    def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
         if self.device is not None:
-            s = torch.as_tensor(
-                s,
+            obs = torch.as_tensor(
+                obs,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             )
-        return self.model(s.flatten(1))  # type: ignore
+        return self.model(obs.flatten(1))  # type: ignore
 
 
 class Net(nn.Module):
@ -187,12 +187,12 @@ class Net(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
-        """Mapping: s -> flatten (inside MLP)-> logits."""
-        logits = self.model(s)
+        """Mapping: obs -> flatten (inside MLP)-> logits."""
+        logits = self.model(obs)
         bsz = logits.shape[0]
         if self.use_dueling:  # Dueling DQN
             q, v = self.Q(logits), self.V(logits)
@ -235,38 +235,45 @@ class Recurrent(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Dict[str, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
-        """Mapping: s -> flatten -> logits.
+        """Mapping: obs -> flatten -> logits.
 
-        In the evaluation mode, s should be with shape ``[bsz, dim]``; in the
-        training mode, s should be with shape ``[bsz, len, dim]``. See the code
+        In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the
+        training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code
         and comment for more detail.
         """
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in the training phase is longer than that
         # in the evaluation phase.
-        if len(s.shape) == 2:
-            s = s.unsqueeze(-2)
-        s = self.fc1(s)
+        if len(obs.shape) == 2:
+            obs = obs.unsqueeze(-2)
+        obs = self.fc1(obs)
         self.nn.flatten_parameters()
         if state is None:
-            s, (h, c) = self.nn(s)
+            obs, (hidden, cell) = self.nn(obs)
         else:
             # we store the stack data in [bsz, len, ...] format
             # but pytorch rnn needs [len, bsz, ...]
-            s, (h, c) = self.nn(
-                s, (
-                    state["h"].transpose(0, 1).contiguous(),
-                    state["c"].transpose(0, 1).contiguous()
+            obs, (hidden, cell) = self.nn(
+                obs, (
+                    state["hidden"].transpose(0, 1).contiguous(),
+                    state["cell"].transpose(0, 1).contiguous()
                 )
             )
-        s = self.fc2(s[:, -1])
+        obs = self.fc2(obs[:, -1])
         # please ensure the first dim is batch size: [bsz, len, ...]
-        return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
+        return obs, {
+            "hidden": hidden.transpose(0, 1).detach(),
+            "cell": cell.transpose(0, 1).detach()
+        }
 
 
 class ActorCritic(nn.Module):
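The renamed state keys (``"hidden"``/``"cell"``) are what callers feed back in on the next step. A minimal usage sketch (shapes taken from the comments above; constructor arguments assumed)::

    import numpy as np

    net = Recurrent(layer_num=1, state_shape=4, action_shape=2)
    obs = np.zeros((3, 4), dtype=np.float32)  # [bsz, dim] -> evaluation mode
    logits, state = net(obs)                  # state == {"hidden": ..., "cell": ...}
    logits, state = net(obs, state=state)     # pass the dict back on the next call
    assert state["hidden"].shape[0] == 3      # first dim is the batch size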
@ -299,8 +306,8 @@ class DataParallelNet(nn.Module):
         super().__init__()
         self.net = nn.DataParallel(net)
 
-    def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any,
+    def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any,
                 **kwargs: Any) -> Tuple[Any, Any]:
-        if not isinstance(s, torch.Tensor):
-            s = torch.as_tensor(s, dtype=torch.float32)
-        return self.net(s=s.cuda(), *args, **kwargs)
+        if not isinstance(obs, torch.Tensor):
+            obs = torch.as_tensor(obs, dtype=torch.float32)
+        return self.net(obs=obs.cuda(), *args, **kwargs)
@ -58,14 +58,14 @@ class Actor(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
-        """Mapping: s -> logits -> action."""
-        logits, h = self.preprocess(s, state)
+        """Mapping: obs -> logits -> action."""
+        logits, hidden = self.preprocess(obs, state)
         logits = self._max * torch.tanh(self.last(logits))
-        return logits, h
+        return logits, hidden
 
 
 class Critic(nn.Module):
@ -110,24 +110,24 @@ class Critic(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
-        a: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        obs: Union[np.ndarray, torch.Tensor],
+        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> torch.Tensor:
         """Mapping: (s, a) -> logits -> Q(s, a)."""
-        s = torch.as_tensor(
-            s,
+        obs = torch.as_tensor(
+            obs,
             device=self.device,  # type: ignore
             dtype=torch.float32,
         ).flatten(1)
-        if a is not None:
-            a = torch.as_tensor(
-                a,
+        if act is not None:
+            act = torch.as_tensor(
+                act,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             ).flatten(1)
-            s = torch.cat([s, a], dim=1)
-        logits, h = self.preprocess(s)
+            obs = torch.cat([obs, act], dim=1)
+        logits, hidden = self.preprocess(obs)
         logits = self.last(logits)
         return logits
@ -196,12 +196,12 @@ class ActorProb(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
-        """Mapping: s -> logits -> (mu, sigma)."""
-        logits, h = self.preprocess(s, state)
+        """Mapping: obs -> logits -> (mu, sigma)."""
+        logits, hidden = self.preprocess(obs, state)
         mu = self.mu(logits)
         if not self._unbounded:
             mu = self._max * torch.tanh(mu)
@ -252,30 +252,34 @@ class RecurrentActorProb(nn.Module):
 
     def forward(
        self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Optional[Dict[str, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
         """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in the training phase is longer than that
         # in the evaluation phase.
-        if len(s.shape) == 2:
-            s = s.unsqueeze(-2)
+        if len(obs.shape) == 2:
+            obs = obs.unsqueeze(-2)
         self.nn.flatten_parameters()
         if state is None:
-            s, (h, c) = self.nn(s)
+            obs, (hidden, cell) = self.nn(obs)
         else:
             # we store the stack data in [bsz, len, ...] format
             # but pytorch rnn needs [len, bsz, ...]
-            s, (h, c) = self.nn(
-                s, (
-                    state["h"].transpose(0, 1).contiguous(),
-                    state["c"].transpose(0, 1).contiguous()
+            obs, (hidden, cell) = self.nn(
+                obs, (
+                    state["hidden"].transpose(0, 1).contiguous(),
+                    state["cell"].transpose(0, 1).contiguous()
                 )
             )
-        logits = s[:, -1]
+        logits = obs[:, -1]
         mu = self.mu(logits)
         if not self._unbounded:
             mu = self._max * torch.tanh(mu)
@ -287,8 +291,8 @@ class RecurrentActorProb(nn.Module):
         sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
         # please ensure the first dim is batch size: [bsz, len, ...]
         return (mu, sigma), {
-            "h": h.transpose(0, 1).detach(),
-            "c": c.transpose(0, 1).detach()
+            "hidden": hidden.transpose(0, 1).detach(),
+            "cell": cell.transpose(0, 1).detach()
         }
@ -321,28 +325,32 @@ class RecurrentCritic(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
-        a: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        obs: Union[np.ndarray, torch.Tensor],
+        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
         info: Dict[str, Any] = {},
     ) -> torch.Tensor:
         """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
-        s = torch.as_tensor(s, device=self.device, dtype=torch.float32)  # type: ignore
-        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+        obs = torch.as_tensor(
+            obs,
+            device=self.device,  # type: ignore
+            dtype=torch.float32,
+        )
+        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in the training phase is longer than that
         # in the evaluation phase.
-        assert len(s.shape) == 3
+        assert len(obs.shape) == 3
         self.nn.flatten_parameters()
-        s, (h, c) = self.nn(s)
-        s = s[:, -1]
-        if a is not None:
-            a = torch.as_tensor(
-                a,
+        obs, (hidden, cell) = self.nn(obs)
+        obs = obs[:, -1]
+        if act is not None:
+            act = torch.as_tensor(
+                act,
                 device=self.device,  # type: ignore
                 dtype=torch.float32,
             )
-            s = torch.cat([s, a], dim=1)
-        s = self.fc2(s)
-        return s
+            obs = torch.cat([obs, act], dim=1)
+        obs = self.fc2(obs)
+        return obs
 
 
 class Perturbation(nn.Module):
@ -381,9 +389,9 @@ class Perturbation(nn.Module):
     def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         # preprocess_net
         logits = self.preprocess_net(torch.cat([state, action], -1))[0]
-        a = self.phi * self.max_action * torch.tanh(logits)
+        noise = self.phi * self.max_action * torch.tanh(logits)
         # clip to [-max_action, max_action]
-        return (a + action).clamp(-self.max_action, self.max_action)
+        return (noise + action).clamp(-self.max_action, self.max_action)
class VAE(nn.Module):
@ -434,31 +442,32 @@ class VAE(nn.Module):
         self, state: torch.Tensor, action: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # [state, action] -> z , [state, z] -> action
-        z = self.encoder(torch.cat([state, action], -1))
+        latent_z = self.encoder(torch.cat([state, action], -1))
         # shape of z: (state.shape[:-1], hidden_dim)
 
-        mean = self.mean(z)
+        mean = self.mean(latent_z)
         # Clamped for numerical stability
-        log_std = self.log_std(z).clamp(-4, 15)
+        log_std = self.log_std(latent_z).clamp(-4, 15)
         std = torch.exp(log_std)
         # shape of mean, std: (state.shape[:-1], latent_dim)
 
-        z = mean + std * torch.randn_like(std)  # (state.shape[:-1], latent_dim)
+        latent_z = mean + std * torch.randn_like(std)  # (state.shape[:-1], latent_dim)
 
-        u = self.decode(state, z)  # (state.shape[:-1], action_dim)
-        return u, mean, std
+        reconstruction = self.decode(state, latent_z)  # (state.shape[:-1], action_dim)
+        return reconstruction, mean, std
 
     def decode(
         self,
         state: torch.Tensor,
-        z: Union[torch.Tensor, None] = None
+        latent_z: Union[torch.Tensor, None] = None
     ) -> torch.Tensor:
         # decode(state) -> action
-        if z is None:
+        if latent_z is None:
             # state.shape[0] may be batch_size
             # latent vector clipped to [-0.5, 0.5]
-            z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
+            latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
                 .to(self.device).clamp(-0.5, 0.5)
 
         # decode z with state!
-        return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))
+        return self.max_action * \
+            torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
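Downstream, the three return values feed the usual VAE objective: reconstruction against the dataset action plus a KL term against a unit Gaussian. A sketch of how an offline-RL algorithm such as BCQ typically consumes them (standard VAE algebra; the 0.5 weighting is an assumption, not taken from this commit)::

    import torch
    import torch.nn.functional as F

    reconstruction, mean, std = vae(state, action)
    recon_loss = F.mse_loss(reconstruction, action)
    # closed-form KL(N(mean, std^2) || N(0, 1)) for a diagonal Gaussian
    kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + 0.5 * kl_loss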
@ -59,16 +59,16 @@ class Actor(nn.Module):
 
     def forward(
         self,
-        s: Union[np.ndarray, torch.Tensor],
+        obs: Union[np.ndarray, torch.Tensor],
         state: Any = None,
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state)
+        logits, hidden = self.preprocess(obs, state)
         logits = self.last(logits)
         if self.softmax_output:
             logits = F.softmax(logits, dim=-1)
-        return logits, h
+        return logits, hidden
 
 
 class Critic(nn.Module):
@ -114,10 +114,10 @@ class Critic(nn.Module):
         )
 
     def forward(
-        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
+        self, obs: Union[np.ndarray, torch.Tensor], **kwargs: Any
     ) -> torch.Tensor:
         """Mapping: s -> V(s)."""
-        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
+        logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
         return self.last(logits)
@ -199,10 +199,10 @@ class ImplicitQuantileNetwork(Critic):
         ).to(device)
 
     def forward(  # type: ignore
-        self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
+        self, obs: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
     ) -> Tuple[Any, torch.Tensor]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state=kwargs.get("state", None))
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
         # Sample fractions.
         batch_size = logits.size(0)
         taus = torch.rand(
@ -211,7 +211,7 @@ class ImplicitQuantileNetwork(Critic):
         embedding = (logits.unsqueeze(1) *
                      self.embed_model(taus)).view(batch_size * sample_size, -1)
         out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
-        return (out, taus), h
+        return (out, taus), hidden
 
 
 class FractionProposalNetwork(nn.Module):
@ -235,17 +235,17 @@ class FractionProposalNetwork(nn.Module):
         self.embedding_dim = embedding_dim
 
     def forward(
-        self, state_embeddings: torch.Tensor
+        self, obs_embeddings: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Calculate (log of) probabilities q_i in the paper.
-        m = torch.distributions.Categorical(logits=self.net(state_embeddings))
-        taus_1_N = torch.cumsum(m.probs, dim=1)
+        dist = torch.distributions.Categorical(logits=self.net(obs_embeddings))
+        taus_1_N = torch.cumsum(dist.probs, dim=1)
         # Calculate \tau_i (i=0,...,N).
         taus = F.pad(taus_1_N, (1, 0))
         # Calculate \hat \tau_i (i=0,...,N-1).
         tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
         # Calculate entropies of value distributions.
-        entropies = m.entropy()
+        entropies = dist.entropy()
         return taus, tau_hats, entropies
@ -294,13 +294,13 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
         return quantiles
 
     def forward(  # type: ignore
-        self, s: Union[np.ndarray, torch.Tensor],
+        self, obs: Union[np.ndarray, torch.Tensor],
         propose_model: FractionProposalNetwork,
         fractions: Optional[Batch] = None,
         **kwargs: Any
     ) -> Tuple[Any, torch.Tensor]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, h = self.preprocess(s, state=kwargs.get("state", None))
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
         # Propose fractions
         if fractions is None:
             taus, tau_hats, entropies = propose_model(logits.detach())
@ -313,7 +313,7 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
         if self.training:
             with torch.no_grad():
                 quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
-        return (quantiles, fractions, quantiles_tau), h
+        return (quantiles, fractions, quantiles_tau), hidden
 
 
 class NoisyLinear(nn.Module):
@ -31,20 +31,20 @@ class MovAvg(object):
         self.banned = [np.inf, np.nan, -np.inf]
 
     def add(
-        self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
+        self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor]
     ) -> float:
         """Add a scalar into :class:`MovAvg`.
 
         You can add ``torch.Tensor`` with only one element, a python scalar, or
         a list of python scalars.
         """
-        if isinstance(x, torch.Tensor):
-            x = x.flatten().cpu().numpy()
-        if np.isscalar(x):
-            x = [x]
-        for i in x:  # type: ignore
-            if i not in self.banned:
-                self.cache.append(i)
+        if isinstance(data_array, torch.Tensor):
+            data_array = data_array.flatten().cpu().numpy()
+        if np.isscalar(data_array):
+            data_array = [data_array]
+        for number in data_array:  # type: ignore
+            if number not in self.banned:
+                self.cache.append(number)
         if self.size > 0 and len(self.cache) > self.size:
             self.cache = self.cache[-self.size:]
         return self.get()
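Usage is unchanged by the rename; for example::

    import numpy as np
    import torch

    stat = MovAvg(size=100)
    stat.add(torch.tensor([1.0]))  # a single-element tensor
    stat.add(2)                    # a python scalar
    stat.add([3, 4, np.inf])       # inf/nan entries are filtered out
    print(stat.get())              # mean of 1, 2, 3, 4 -> 2.5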
@ -80,10 +80,10 @@ class RunningMeanStd(object):
         self.mean, self.var = mean, std
         self.count = 0
 
-    def update(self, x: np.ndarray) -> None:
+    def update(self, data_array: np.ndarray) -> None:
         """Add a batch of items into RMS with the same shape, modify mean/var/count."""
-        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
-        batch_count = len(x)
+        batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
+        batch_count = len(data_array)
 
         delta = batch_mean - self.mean
         total_count = self.count + batch_count
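The hunk stops just before the merge step. ``delta`` and ``total_count`` feed the standard parallel mean/variance update (Chan et al.); a sketch consistent with the lines above, though not the literal continuation::

    new_mean = self.mean + delta * batch_count / total_count
    m_a = self.var * self.count    # sum of squared deviations, old data
    m_b = batch_var * batch_count  # sum of squared deviations, new batch
    m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
    self.mean, self.var = new_mean, m_2 / total_count
    self.count = total_count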