Robust conversion from/to numpy/pytorch (#63)
* Enable converting Batch data back to torch.
* Add torch converter to collector.
* Fix
* Move to_numpy/to_torch conversion into a dedicated utils.py.
* Use to_numpy/to_torch to convert arrays.
* fix lint
* fix
* Add unit test to check Batch conversion from/to numpy.
* Fix Batch over Batch.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent b5093ecb56
commit 8af7196a9a
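For orientation, a minimal sketch (not part of the commit) of the round trip this change enables, assuming the to_numpy/to_torch helpers and the Batch.to_torch/Batch.to_numpy methods added in the diff below:

    # Illustrative only: in-place conversion of a (possibly nested) Batch.
    import numpy as np
    import torch

    from tianshou.data import Batch, to_numpy, to_torch

    batch = Batch(obs=np.zeros((4, 3)), info=Batch(step=np.arange(4)))
    batch.to_torch(dtype=torch.float, device='cpu')  # in place, recurses into nested Batch
    assert isinstance(batch.obs, torch.Tensor)
    batch.to_numpy()                                  # back to numpy, also in place
    assert isinstance(batch.info.step, np.ndarray)

    # The free functions accept tensors, arrays, dicts and Batch objects alike.
    act = to_numpy(torch.zeros(3))
    rew = to_torch(np.ones(3), dtype=torch.float, device='cpu')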
@@ -1,4 +1,5 @@
 import pytest
+import torch
 import numpy as np
 
 from tianshou.data import Batch
@@ -29,6 +30,20 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_from_to_numpy_without_copy():
+    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
+    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
+    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    batch.to_torch()
+    assert isinstance(batch["a"], torch.Tensor)
+    assert isinstance(batch["b"]["c"], torch.Tensor)
+    batch.to_numpy()
+    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
+    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    assert a_mem_addr_new == a_mem_addr_orig
+    assert c_mem_addr_new == c_mem_addr_orig
+
+
 if __name__ == '__main__':
     test_batch()
     test_batch_over_batch()
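A note on why the test above compares raw data pointers via __array_interface__: on CPU, torch.from_numpy() and Tensor.numpy() share the underlying buffer, so the to_torch()/to_numpy() round trip is expected to be copy-free. A standalone illustration (not from the diff):

    import numpy as np
    import torch

    a = np.ones(5)
    t = torch.from_numpy(a)   # view over the same memory, no copy
    b = t.numpy()             # view again, still the same buffer
    assert np.shares_memory(a, b)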
@@ -1,12 +1,15 @@
 from tianshou.data.batch import Batch
+from tianshou.data.utils import to_numpy, to_torch
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
 
 __all__ = [
     'Batch',
+    'to_numpy',
+    'to_torch',
     'ReplayBuffer',
     'ListReplayBuffer',
     'PrioritizedReplayBuffer',
-    'Collector',
+    'Collector'
 ]
@@ -141,13 +141,31 @@ class Batch:
             return self.__getattr__(k)
         return d
 
-    def to_numpy(self) -> np.ndarray:
+    def to_numpy(self) -> None:
         """Change all torch.Tensor to numpy.ndarray. This is an inplace
         operation.
         """
         for k, v in self.__dict__.items():
             if isinstance(v, torch.Tensor):
                 self.__dict__[k] = v.cpu().numpy()
+            elif isinstance(v, Batch):
+                v.to_numpy()
+
+    def to_torch(self,
+                 dtype: Optional[torch.dtype] = None,
+                 device: Union[str, int] = 'cpu'
+                 ) -> None:
+        """Change all numpy.ndarray to torch.Tensor. This is an inplace
+        operation.
+        """
+        for k, v in self.__dict__.items():
+            if isinstance(v, np.ndarray):
+                v = torch.from_numpy(v).to(device)
+                if dtype is not None:
+                    v = v.type(dtype)
+                self.__dict__[k] = v
+            elif isinstance(v, Batch):
+                v.to_torch()
 
     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Union, Optional, Callable
 from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 
 
 class Collector(object):
@@ -201,21 +201,6 @@ class Collector(object):
         elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] = 0
 
-    def _to_numpy(self, x: Union[
-            torch.Tensor, dict, Batch, np.ndarray]) -> None:
-        """Return an object without torch.Tensor."""
-        if isinstance(x, torch.Tensor):
-            return x.cpu().numpy()
-        elif isinstance(x, dict):
-            for k in x:
-                if isinstance(x[k], torch.Tensor):
-                    x[k] = x[k].cpu().numpy()
-            return x
-        elif isinstance(x, Batch):
-            x.to_numpy()
-            return x
-        return x
-
     def collect(self,
                 n_step: int = 0,
                 n_episode: Union[int, List[int]] = 0,
@@ -270,9 +255,9 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(batch, self.state)
             self.state = result.get('state', None)
-            self._policy = self._to_numpy(result.policy) \
+            self._policy = to_numpy(result.policy) \
                 if hasattr(result, 'policy') else [{}] * self.env_num
-            self._act = self._to_numpy(result.act)
+            self._act = to_numpy(result.act)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if not self._multi_env:
tianshou/data/utils.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+import torch
+import numpy as np
+from typing import Union, Optional
+
+from tianshou.data import Batch
+
+
+def to_numpy(x: Union[
+        torch.Tensor, dict, Batch, np.ndarray]) -> Union[
+            dict, Batch, np.ndarray]:
+    """Return an object without torch.Tensor."""
+    if isinstance(x, torch.Tensor):
+        x = x.detach().cpu().numpy()
+    elif isinstance(x, dict):
+        for k, v in x.items():
+            x[k] = to_numpy(v)
+    elif isinstance(x, Batch):
+        x.to_numpy()
+    return x
+
+
+def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
+             dtype: Optional[torch.dtype] = None,
+             device: Union[str, int] = 'cpu'
+             ) -> Union[dict, Batch, torch.Tensor]:
+    """Return an object without np.ndarray."""
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x).to(device)
+        if dtype is not None:
+            x = x.type(dtype)
+    elif isinstance(x, dict):
+        for k, v in x.items():
+            x[k] = to_torch(v, dtype, device)
+    elif isinstance(x, Batch):
+        x.to_torch()
+    return x
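The helpers above also walk plain dicts recursively, which is what the collector change above and the policy changes below rely on when converting policy outputs and batch fields. An assumed usage example (not part of the commit):

    import numpy as np
    import torch

    from tianshou.data import to_numpy, to_torch

    result = {'act': torch.tensor([0, 1]), 'extra': {'logits': torch.zeros(2, 4)}}
    result = to_numpy(result)   # every tensor becomes an ndarray, nested dicts included
    assert isinstance(result['act'], np.ndarray)
    assert isinstance(result['extra']['logits'], np.ndarray)

    back = to_torch(result, dtype=torch.float, device='cpu')
    assert back['act'].dtype == torch.float32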
@@ -3,7 +3,7 @@ import numpy as np
 import torch.nn.functional as F
 from typing import Dict, Union, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import BasePolicy
 
 
@@ -46,11 +46,11 @@ class ImitationPolicy(BasePolicy):
         self.optim.zero_grad()
         if self.mode == 'continuous':
             a = self(batch).act
-            a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.float, device=a.device)
             loss = F.mse_loss(a, a_)
         elif self.mode == 'discrete':  # classification
             a = self(batch).logits
-            a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)
         loss.backward()
         self.optim.step()
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, List, Union, Optional
 
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
 
 
 class A2CPolicy(PGPolicy):
@@ -64,7 +64,7 @@ class A2CPolicy(PGPolicy):
         v_ = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False):
-                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+                v_.append(to_numpy(self.critic(b.obs_next)))
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
@@ -106,8 +106,8 @@ class A2CPolicy(PGPolicy):
             self.optim.zero_grad()
             dist = self(b).dist
             v = self.critic(b.obs)
-            a = torch.tensor(b.act, device=v.device)
-            r = torch.tensor(b.returns, device=v.device)
+            a = to_torch(b.act, device=v.device)
+            r = to_torch(b.returns, device=v.device)
             a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
             vf_loss = F.mse_loss(r[:, None], v)
             ent_loss = dist.entropy().mean()
@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 
 from tianshou.policy import BasePolicy
 # from tianshou.exploration import OUNoise
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch
 
 
 class DDPGPolicy(BasePolicy):
@@ -135,7 +135,7 @@ class DDPGPolicy(BasePolicy):
             eps = self._eps
         if eps > 0:
             # noise = np.random.normal(0, eps, size=logits.shape)
-            # logits += torch.tensor(noise, device=logits.device)
+            # logits += to_torch(noise, device=logits.device)
             # noise = self.noise(logits.shape, eps)
             logits += torch.randn(
                 size=logits.shape, device=logits.device) * eps
@@ -147,10 +147,10 @@ class DDPGPolicy(BasePolicy):
         target_q = self.critic_old(batch.obs_next, self(
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
-        rew = torch.tensor(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
+        rew = to_torch(batch.rew,
+                       dtype=torch.float, device=dev)[:, None]
+        done = to_torch(batch.done,
+                        dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
@@ -5,7 +5,8 @@ import torch.nn.functional as F
 from typing import Dict, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
+    to_torch, to_numpy
 
 
 class DQNPolicy(BasePolicy):
@@ -96,12 +97,12 @@ class DQNPolicy(BasePolicy):
             target_q = self(
                 terminal_data, model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
-                target_q = target_q.detach().cpu().numpy()
+                target_q = to_numpy(target_q)
             target_q = target_q[np.arange(len(a)), a]
         else:
             target_q = self(terminal_data, input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
-                target_q = target_q.detach().cpu().numpy()
+                target_q = to_numpy(target_q)
             target_q = target_q.max(axis=1)
         target_q[gammas != self._n_step] = 0
         returns += (self._gamma ** gammas) * target_q
@@ -111,11 +112,11 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
-            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+            r = to_torch(r, device=q.device, dtype=q.dtype)
         td = r - q
-        buffer.update_weight(indice, td.detach().cpu().numpy())
-        impt_weight = torch.tensor(batch.impt_weight,
-                                   device=q.device, dtype=torch.float)
+        buffer.update_weight(indice, to_numpy(td))
+        impt_weight = to_torch(batch.impt_weight,
+                               device=q.device, dtype=torch.float)
         loss = (td.pow(2) * impt_weight).mean()
         if not hasattr(batch, 'loss'):
             batch.loss = loss
@@ -147,7 +148,7 @@ class DQNPolicy(BasePolicy):
         model = getattr(self, model)
         obs = getattr(batch, input)
         q, h = model(obs, state=state, info=batch.info)
-        act = q.max(dim=1)[1].detach().cpu().numpy()
+        act = to_numpy(q.max(dim=1)[1])
         # add eps to act
         if eps is None:
             eps = self.eps
@@ -168,7 +169,7 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
-            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+            r = to_torch(r, device=q.device, dtype=q.dtype)
         loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
@@ -3,7 +3,7 @@ import numpy as np
 from typing import Dict, List, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch
 
 
 class PGPolicy(BasePolicy):
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
-            a = torch.tensor(b.act, device=dist.logits.device)
-            r = torch.tensor(b.returns, device=dist.logits.device)
+            a = to_torch(b.act, device=dist.logits.device)
+            r = to_torch(b.returns, device=dist.logits.device)
             loss = -(dist.log_prob(a) * r).sum()
             loss.backward()
             self.optim.step()
@@ -4,7 +4,7 @@ from torch import nn
 from typing import Dict, List, Tuple, Union, Optional
 
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
 
 
 class PPOPolicy(PGPolicy):
@@ -88,7 +88,7 @@ class PPOPolicy(PGPolicy):
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False):
                 v_.append(self.critic(b.obs_next))
-        v_ = torch.cat(v_, dim=0).cpu().numpy()
+        v_ = to_numpy(torch.cat(v_, dim=0))
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
 
@@ -129,12 +129,12 @@ class PPOPolicy(PGPolicy):
             for b in batch.split(batch_size, shuffle=False):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
-                    torch.tensor(b.act, device=v[0].device)))
+                    to_torch(b.act, device=v[0].device)))
         batch.v = torch.cat(v, dim=0)  # old value
         dev = batch.v.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = torch.tensor(
+        batch.returns = to_torch(
             batch.returns, dtype=torch.float, device=dev
         ).reshape(batch.v.shape)
         if self._rew_norm:
@@ -4,7 +4,7 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
 
@@ -110,15 +110,15 @@ class SACPolicy(DDPGPolicy):
         obs_next_result = self(batch, input='obs_next')
         a_ = obs_next_result.act
         dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_),
         ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
+        rew = to_torch(batch.rew,
+                       dtype=torch.float, device=dev)[:, None]
+        done = to_torch(batch.done,
+                        dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
@@ -3,7 +3,7 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
 
 
@@ -112,10 +112,10 @@ class TD3Policy(DDPGPolicy):
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
+        rew = to_torch(batch.rew,
+                       dtype=torch.float, device=dev)[:, None]
+        done = to_torch(batch.done,
+                        dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)