2024-10-20 17:01:07 +08:00
|
|
|
"""
|
|
|
|
|
Training Strategy
|
|
|
|
|
Supervised learning, imitation learning, policy testing, and rollout collection for the YOPO policy.
|
|
|
|
|
"""
|
|
|
|
|
import time
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch as th
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq, TrainFrequencyUnit
|
2024-10-20 23:40:36 +08:00
|
|
|
from stable_baselines3.common.utils import should_collect_more_steps, get_schedule_fn, configure_logger, update_learning_rate
|
2024-10-20 17:01:07 +08:00
|
|
|
from stable_baselines3.common.vec_env import VecEnv
|
|
|
|
|
from stable_baselines3.common.utils import get_device
|
|
|
|
|
|
|
|
|
|
# -----------
|
|
|
|
|
from flightpolicy.yopo.yopo_policy import YopoPolicy
|
|
|
|
|
from flightpolicy.yopo.dataloader import YopoDataset
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from flightpolicy.yopo.primitive_utils import transform, rotate, transform_inv, rotate_inv
|
|
|
|
|
from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive
|
|
|
|
|
from flightpolicy.yopo.buffers import ReplayBuffer
|
|
|
|
|
from ruamel.yaml import YAML
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class YopoAlgorithm:
|
|
|
|
|
def __init__(
        self,
        env=None,
        learning_rate=0.001,
        is_imitation=False,
        buffer_size=1_000_000,
        learning_starts=100,
        batch_size=256,
        unselect=0.0,
        loss_weight=None,
        train_freq=(1, "step"),
        change_env_freq=-1,
        gradient_steps=1,
        policy_kwargs=None,
        tensorboard_log=None,
        verbose=0,
        max_grad_norm=10,
):
    """
    Training driver for the YOPO policy (supervised learning, imitation learning, testing).

    :param env: vectorized flight environment exposing ``observation_dim``, ``action_dim`` and
        ``num_envs``. Required despite the ``None`` default.
    :param learning_rate: learning rate, turned into a schedule with ``get_schedule_fn``.
    :param is_imitation: if True, a replay buffer is created and ``train_freq`` is converted
        to a ``TrainFreq`` object.
    :param buffer_size: replay-buffer capacity (imitation learning only).
    :param learning_starts: training starts once more than this many env steps were collected.
    :param batch_size: samples per gradient step (data-loader / replay-buffer batch size).
    :param unselect: fraction of the highest costs per sample zeroed by ``cost_filter``;
        values outside (0, 1) disable filtering.
    :param loss_weight: weights ``[w_cost, w_score]`` of the two loss terms. Must hold two
        entries before training; ``None`` gives an empty list (the previous default).
    :param train_freq: rollout length between training phases, ``n`` or ``(n, "step"|"episode")``.
    :param change_env_freq: respawn the map every this many iterations; ``<= 0`` disables it.
    :param gradient_steps: gradient steps per training phase; negative means "as many as env
        steps collected in the rollout", 0 disables training.
    :param policy_kwargs: extra keyword arguments forwarded to ``YopoPolicy``.
    :param tensorboard_log: tensorboard log directory passed to ``configure_logger``.
    :param verbose: verbosity passed to ``configure_logger``.
    :param max_grad_norm: gradient-norm clipping threshold.
    :raises ValueError: if ``env`` is None.
    :raises KeyError: if the ``FLIGHTMARE_PATH`` environment variable is not set.
    """
    if env is None:
        raise ValueError("YopoAlgorithm requires an environment: `env` must not be None")

    # env
    self.observation_dim = env.observation_dim
    self.action_dim = env.action_dim
    self.n_envs = env.num_envs
    self.env = env
    # training
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.max_grad_norm = max_grad_norm
    self.unselect = unselect
    # Copy so instances never share (or alias) a default / caller-owned list.
    self.loss_weight = [] if loss_weight is None else list(loss_weight)
    self.device = get_device('auto')
    self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
    # imitation learning
    self.is_imitation = is_imitation
    self.buffer_size = buffer_size
    self.train_freq = train_freq
    self.change_env_freq = change_env_freq
    self.learning_starts = learning_starts
    self.gradient_steps = gradient_steps
    self.freq_reset = False
    self.replay_buffer = None
    # logger
    self.verbose = verbose
    self.tensorboard_log = tensorboard_log
    self.logger = configure_logger(self.verbose, self.tensorboard_log, "YOPO")
    # trajectory: lattice parameters come from Flightmare's traj_opt.yaml
    cfg_path = os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml"
    with open(cfg_path, 'r') as cfg_file:  # close the handle once the YAML is parsed
        cfg = YAML().load(cfg_file)
    self.lattice_space = LatticeParam(cfg)
    self.lattice_primitive = LatticePrimitive(self.lattice_space)

    self._setup_model()
|
|
|
|
|
|
|
|
|
|
def _setup_model(self):
    """
    Build the learning-rate schedule, the replay buffer (imitation learning only) and the
    YOPO policy network, moved to ``self.device``.
    """
    self.lr_schedule = get_schedule_fn(self.learning_rate)

    # Replay buffer (pos, quat, vel, acc, depth) is only needed when imitation learning.
    if self.is_imitation and self.replay_buffer is None:
        depth_size = (self.env.network_width, self.env.network_height)
        self.replay_buffer = ReplayBuffer(
            self.buffer_size,
            self.observation_dim,
            depth_size,
            device=self.device,
            n_envs=self.n_envs,
        )

    print("Loading Network...")
    self.policy = YopoPolicy(
        observation_dim=self.observation_dim,
        action_dim=self.action_dim,
        lattice_space=self.lattice_space,
        lattice_primitive=self.lattice_primitive,
        lr_schedule=self.lr_schedule,
        train_env=self.env,
        device=self.device,
        **self.policy_kwargs
    )
    self.policy = self.policy.to(self.device)
    print("Network Loaded!")

    # train_freq may still be an int or a tuple; normalise it to a TrainFreq object.
    if self.is_imitation:
        self._convert_train_freq()
|
|
|
|
|
|
|
|
|
|
def supervised_learning(self, epoch, log_interval):
    """
    Train the policy offline on the pre-collected ``YopoDataset``.

    For every batch the environment is put into the recorded state, goal and map so that
    ``self.policy.inference`` can evaluate the predicted trajectories; the returned costs serve
    both as the cost loss and (detached) as labels for the predicted scores.

    :param epoch: number of passes over the dataset.
    :param log_interval: ``(log_every, save_every)`` counted in gradient updates, or None to
        disable logging and checkpointing.
    """
    self.policy.set_training_mode(True)
    data_loader = DataLoader(YopoDataset(), batch_size=self.batch_size, shuffle=True, num_workers=0)

    n_updates = 0
    start_time = time.time()
    for epoch_ in range(epoch):
        cost_losses = []  # Performance (score) of prediction
        score_losses = []  # Accuracy of the predicted score
        for step, (depth, pos, quat, obs_b, map_id) in enumerate(data_loader):  # obs: body frame
            # Skip incomplete batches: the env holds exactly batch_size instances.
            if depth.shape[0] != self.batch_size:  # batch size == num of env
                continue
            n_updates = n_updates + 1
            depth = depth.to(self.device)
            obs_b = obs_b.numpy()

            # obs_b columns: 0:3 velocity (body), 3:6 acceleration (body), 6:9 goal direction (body).
            # Restore the recorded world-frame state and a goal 10 units along the goal direction.
            goal_dir = obs_b[:, 6:9]
            goal_w = transform(quat.numpy(), pos.numpy(), 10 * goal_dir)  # Rwb * g_b + t_wb
            vel_w = rotate(quat.numpy(), obs_b[:, 0:3])
            acc_w = rotate(quat.numpy(), obs_b[:, 3:6])
            self.env.setState(pos.numpy(), vel_w, acc_w, quat.numpy())
            self.env.setGoal(goal_w)
            self.env.setMapID(map_id.numpy())

            # Normalise vel/acc, then rotate into every primitive frame for the network input.
            obs_b[:, 0:6] = self.normalize_obs(obs_b[:, 0:6])
            obs_norm_input = self.prapare_input_observation(obs_b)
            obs_norm_input = obs_norm_input.to(self.device)
            endstate_score_predictions, cost_labels = self.policy.inference(depth, obs_norm_input)
            # Score targets must not propagate gradients back through the cost.
            score_labels = cost_labels.clone().detach()
            # Logged cost is taken before filtering.
            cost_labels_record = th.mean(cost_labels)
            cost_labels_filtered = self.cost_filter(cost_labels)

            cost_loss = th.mean(cost_labels_filtered)
            # NOTE(review): channel 9 is assumed to hold the predicted score — confirm against YopoPolicy.
            score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
            loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
            cost_losses.append(self.loss_weight[0] * cost_labels_record.item())
            score_losses.append(self.loss_weight[1] * score_loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

            # Periodic console/tensorboard logging; the running averages restart after each dump.
            if log_interval is not None and n_updates % log_interval[0] == 0:
                self.logger.record("time/epoch", epoch_, exclude="tensorboard")
                self.logger.record("time/steps", n_updates, exclude="tensorboard")
                self.logger.record("time/batch_fps", log_interval[0] / (time.time() - start_time),
                                   exclude="tensorboard")
                self.logger.record("train/trajectory_cost", np.mean(cost_losses))
                self.logger.record("train/score_loss", np.mean(score_losses))
                self.logger.dump(step=n_updates)
                cost_losses = []
                score_losses = []
                start_time = time.time()

            # Periodic checkpoint: weights plus constructor parameters of the policy.
            if log_interval is not None and n_updates % log_interval[1] == 0:
                policy_path = self.logger.get_dir() + "/Policy"
                os.makedirs(policy_path, exist_ok=True)
                path = policy_path + "/epoch{}_iter{}.pth".format(epoch_, step)
                th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path)
|
|
|
|
|
|
|
|
|
|
def imitation_learning(
        self,
        total_timesteps,
        log_interval,
        reset_num_timesteps=True,
):
    """
    Online imitation learning: alternate rollouts of the current policy (states stored in the
    replay buffer) with gradient updates on states sampled from that buffer.

    :param total_timesteps: total number of env steps (summed over all envs) to collect.
    :param log_interval: ``(log_every, save_every)`` counted in iterations, or None.
    :param reset_num_timesteps: restart the step/update counters before training.
    """
    # 0. setup learn and init the first observation
    self._setup_learn(total_timesteps, reset_num_timesteps)

    while self.num_timesteps < total_timesteps:
        # 1. Rollout and Collect Data into Buffer
        rollout = self.collect_rollouts(
            self.env,
            train_freq=self.train_freq,
            replay_buffer=self.replay_buffer
        )
        if rollout.continue_training is False:
            break

        # 2. Train the Policy
        if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
            # If no `gradient_steps` is specified, do as many gradients steps as steps performed during the rollout
            gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
            if gradient_steps > 0:  # Special case when the user passes `gradient_steps=0`
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
                # train() overwrites env state/goal/map id; start the next rollout from a fresh reset.
                self.reset_state()

        # One iteration == one rollout of train_freq.frequency steps in every env.
        iteration = int(self.num_timesteps / (self.train_freq.frequency * self.env.num_envs))

        # 3. reset the environment
        if self.change_env_freq > 0 and iteration % self.change_env_freq == 0:
            self.env.spawnTreesAndSavePointcloud()
            self._map_id = self._map_id + 1
            self.reset_state()

        # 4. print the log and save weight
        if log_interval is not None and iteration % log_interval[0] == 0:
            self._dump_logs()
        if log_interval is not None and iteration % log_interval[1] == 0:
            policy_path = self.logger.get_dir() + "/Policy"
            os.makedirs(policy_path, exist_ok=True)
            path = policy_path + "/epoch0_iter{}.pth".format(iteration)
            th.save({"state_dict": self.policy.state_dict(), "data": self.policy.get_constructor_parameters()}, path)
|
|
|
|
|
|
|
|
|
|
def test_policy(self, num_rollouts: int = 10):
    """
    Fly the current policy and print the average cost of each rollout.

    Each rollout starts at x = -20 with a random y in [0, 40], at height 2, with zero velocity
    and acceleration and identity attitude. It flies toward a goal at x = 20 with a random
    y in [0, 40] and height 2, ending when the env reports done or after 400 steps.
    The depth image is shown with OpenCV at every step. Written for a single environment.

    :param num_rollouts: number of rollouts to run.
    """
    max_ep_length = 400
    self.policy.set_training_mode(False)

    for n_roll in range(num_rollouts):
        obs, done, ep_len = self.env.reset(), False, 0
        costs = []
        # Randomly initialize the position and goal on the map.
        random_y_goal = 20 * random.uniform(-1, 1) + 20
        random_y = 20 * random.uniform(-1, 1) + 20
        goal_w = np.array([[20, random_y_goal, 2]])
        # obs layout: p_wb (3), v_b (3), a_b (3), q_wb (4, [w, x, y, z])
        obs = np.array([[-20, random_y, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
        self.env.setGoal(goal_w)
        self.env.setState(np.array([[-20, random_y, 2]]), np.array([[0, 0, 0]]),
                          np.array([[0, 0, 0]]), np.array([[1, 0, 0, 0]]))
        self.env.render()

        while not (done or (ep_len >= max_ep_length)):
            depth = self.env.getDepthImage()
            depth_vis = cv2.resize(depth[0][0], (320, 180))
            cv2.imshow("depth", depth_vis)
            cv2.waitKey(10)
            depth = th.from_numpy(depth).to(self.device)

            # transform observation to body frame
            quat_bw = -obs[:, 9:13]  # inv of quat: [w, -x, -y, -z]
            quat_bw[:, 0] = -quat_bw[:, 0]
            goal_vec_w = goal_w - obs[:, 0:3]
            # Normalise each row on its own (a whole-matrix norm is only correct for one row),
            # matching _sample_action and train.
            goal_dir_w = goal_vec_w / np.linalg.norm(goal_vec_w, axis=1)[:, np.newaxis]
            goal_dir_b = rotate(quat_bw, goal_dir_w)
            vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])
            obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))

            obs_norm_input = self.prapare_input_observation(obs_norm_b)
            obs_norm_input = obs_norm_input.to(self.device)

            endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input)
            endstate_pred = endstate_pred.cpu().numpy()
            # obs: p_wb, v_b, a_b, q_wb; endstate_pred: pva in body frame
            obs, rew, done = self.env.step(endstate_pred)

            costs.append(rew)
            ep_len += 1
        # The while loop always runs at least once, so len(costs) > 0 here.
        print("round ", n_roll, ", total steps:", len(costs), ", avg cost:", sum(costs) / len(costs))
|
|
|
|
|
|
|
|
|
|
def train(self, gradient_steps: int, batch_size: int) -> None:
    """
    Imitation learning: sample data from the replay buffer and train the Policy.

    For every gradient step a batch of stored states is replayed into the environment
    (state, goal, map id) so that ``self.policy.inference`` can compute trajectory costs,
    which are used as the cost loss and, detached, as score labels. The env state is left
    modified afterwards; the caller is expected to reset it.

    :param gradient_steps: number of optimizer steps.
    :param batch_size: number of samples per step.
    """
    self.policy.set_training_mode(True)  # Switch to train mode (this affects batch norm / dropout)
    update_learning_rate(self.policy.optimizer, self.lr_schedule(self._current_progress_remaining))

    cost_losses = []  # weighted mean trajectory cost per step (for logging)
    score_losses = []  # weighted score loss per step (for logging)
    for _ in range(gradient_steps):
        # Sample replay buffer
        replay_data = self.replay_buffer.sample(batch_size)
        depth = th.from_numpy(replay_data.depths).to(self.device)
        # observation columns: 0:3 position (world), 3:9 vel/acc (body), 9:13 quaternion q_wb
        pos = replay_data.observations[:, 0:3]
        vel_acc_b = replay_data.observations[:, 3:9]
        quat_wb = replay_data.observations[:, 9:13]
        goal_w = replay_data.goals
        map_id = replay_data.map_id

        # Unit goal direction per sample; rotate_inv presumably maps world -> body (confirm).
        goal_dir_w = (goal_w - pos) / np.linalg.norm(goal_w - pos, axis=1)[:, np.newaxis]
        goal_dir_b = rotate_inv(quat_wb, goal_dir_w)
        vel_w = rotate(quat_wb, vel_acc_b[:, 0:3])
        acc_w = rotate(quat_wb, vel_acc_b[:, 3:6])
        # Put the env into the sampled state so inference evaluates costs for it.
        self.env.setState(pos, vel_w, acc_w, quat_wb)
        self.env.setGoal(goal_w)
        self.env.setMapID(map_id)

        vel_acc_norm_b = self.normalize_obs(vel_acc_b)
        obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))
        obs_norm_input = self.prapare_input_observation(obs_norm_b)
        obs_norm_input = obs_norm_input.to(self.device)
        endstate_score_predictions, cost_labels = self.policy.inference(depth, obs_norm_input)
        # Score targets must not propagate gradients back through the cost.
        score_labels = cost_labels.clone().detach()

        # Logged cost is taken before filtering.
        cost_labels_record = th.mean(cost_labels)
        cost_labels_filtered = self.cost_filter(cost_labels)

        cost_loss = th.mean(cost_labels_filtered)
        # NOTE(review): channel 9 is assumed to hold the predicted score — confirm against YopoPolicy.
        score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
        loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
        cost_losses.append(self.loss_weight[0] * cost_labels_record.item())
        score_losses.append(self.loss_weight[1] * score_loss.item())

        # Optimize the policy
        self.policy.optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm
        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy.optimizer.step()

    # Increase update counter
    self._n_updates += gradient_steps
    self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    self.logger.record("train/trajectory_cost", np.mean(cost_losses))
    self.logger.record("train/score_loss", np.mean(score_losses))
|
|
|
|
|
|
|
|
|
|
def collect_rollouts(
        self,
        env,
        train_freq,
        replay_buffer,
) -> RolloutReturn:
    """
    Roll out the current policy and store every visited state in the replay buffer.

    Only the state the action was computed from (obs, goal, depth, map id) is stored;
    actions and rewards are not. Envs that finish get a new random goal.

    :param env: vectorized environment (must be a ``VecEnv``).
    :param train_freq: ``TrainFreq`` giving how many steps/episodes to collect.
    :param replay_buffer: buffer receiving the transitions.
    :return: ``RolloutReturn(steps collected summed over envs, finished episodes, continue_training)``.
    """
    self.policy.set_training_mode(False)
    num_collected_steps, num_collected_episodes = 0, 0
    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert train_freq.frequency > 0, "Should at least collect one step or episode."
    if env.num_envs > 1:
        assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
    continue_training = True

    while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
        # 1. pred endstate used latest policy
        sampled_endstate = self._sample_action()

        # 2. perform action and get new observation
        new_obs, rewards, dones = env.step(sampled_endstate)

        self.num_timesteps += env.num_envs
        num_collected_steps += 1

        # 3. store the last obs, depth, and goal
        self._store_transition(replay_buffer)
        # Progress in [0, 1] used by train() for the learning-rate schedule.
        self._current_progress_remaining = 1.0 - float(self.num_timesteps) / float(self._total_timesteps)

        # 4. update the obs, depth, and reset the goal for the done-env
        self._last_obs = new_obs
        self._last_depth = env.getDepthImage()

        for idx, done in enumerate(dones):
            if done:
                num_collected_episodes += 1
                # reset goal for the 'done' env
                self._last_goal[idx] = self.get_random_goal(self._last_obs[idx])

    return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
|
|
|
|
|
|
|
|
|
|
def prapare_input_observation(self, obs):
    """
    Rotate the body-frame observation into the frame of every lattice primitive.

    Columns 0:3 (velocity), 3:6 (acceleration) and 6:9 (goal direction) of ``obs`` are each
    rotated by R^T of the primitive's rotation, producing one observation per grid cell. The
    result is presumably concatenated with the spatially aligned depth features downstream,
    for translational invariance. Primitive ids are assigned starting from the last row and
    last column of the grid. Channels beyond 9, if any, are filled with 1.

    :param obs: array of shape (batch, C) with C >= 9.
    :return: float32 CPU tensor of shape (batch, C, vertical_num, horizon_num).
    """
    n_rows = self.lattice_space.vertical_num
    n_cols = self.lattice_space.horizon_num
    grid = np.ones((obs.shape[0], n_rows, n_cols, obs.shape[1]), dtype=np.float32)
    body_vectors = (obs[:, 0:3], obs[:, 3:6], obs[:, 6:9])  # v_b, a_b, g_b

    prim_id = 0
    for row in reversed(range(n_rows)):
        for col in reversed(range(n_cols)):
            rot_pb = self.lattice_primitive.getRotation(prim_id).T
            for slot, vec_b in enumerate(body_vectors):
                grid[:, row, col, 3 * slot:3 * slot + 3] = np.dot(rot_pb, vec_b.T).T
            prim_id += 1

    return th.from_numpy(grid.transpose(0, 3, 1, 2))
|
|
|
|
|
|
|
|
|
|
def unnormalize_obs(self, vel_acc_norm):
    """
    Undo ``normalize_obs``: scale velocity by ``vel_max`` and acceleration by ``acc_max``.

    :param vel_acc_norm: 2-D array; columns 0:3 are normalised velocity and 3:6 normalised
        acceleration (any further columns are dropped).
    :return: new (N, 6) array ``[v * vel_max, a * acc_max]``.
    """
    limits = self.lattice_space
    velocity_part = vel_acc_norm[:, 0:3]
    accel_part = vel_acc_norm[:, 3:6]
    return np.hstack((velocity_part * limits.vel_max, accel_part * limits.acc_max))
|
|
|
|
|
|
|
|
|
|
def normalize_obs(self, vel_acc):
    """
    Scale velocity by ``1 / vel_max`` and acceleration by ``1 / acc_max``.

    :param vel_acc: 2-D array; columns 0:3 are velocity and 3:6 acceleration
        (any further columns are dropped).
    :return: new (N, 6) array ``[v / vel_max, a / acc_max]``.
    """
    limits = self.lattice_space
    velocity_part = vel_acc[:, 0:3]
    accel_part = vel_acc[:, 3:6]
    return np.hstack((velocity_part / limits.vel_max, accel_part / limits.acc_max))
|
|
|
|
|
|
|
|
|
|
def cost_filter(self, costs_):
    """
    Zero the highest costs of every sample so they do not contribute to the cost loss.

    Along the last dimension (n candidates), the ``int(n * self.unselect)`` largest entries
    are set to 0. Filtering is disabled when ``unselect`` is outside (0, 1); the input is
    then returned unchanged.

    The result is computed out of place, so the caller's tensor is no longer mutated. This
    also works on leaf tensors that require grad. Gradients still flow to the kept entries.

    :param costs_: tensor of costs, candidates along the last dimension.
    :return: filtered tensor (or ``costs_`` itself when nothing is filtered).
    """
    if self.unselect <= 0 or self.unselect >= 1:
        return costs_

    # filter the negative samples (the highest-cost candidates)
    num_unselect = int(costs_.size(-1) * self.unselect)
    if num_unselect == 0:
        return costs_
    _, indices = th.topk(costs_, num_unselect, dim=-1)
    return costs_.scatter(-1, indices, 0.0)
|
|
|
|
|
|
2024-10-20 23:40:36 +08:00
|
|
|
def _setup_learn(self, total_timesteps, reset_num_timesteps=True):
    """
    Prepare counters and the first observation before imitation learning.

    :param total_timesteps: total env steps of this run (used for the progress-remaining
        value that drives the learning-rate schedule).
    :param reset_num_timesteps: if True, restart the sampling and update counters from 0.
        If False they are kept. On a fresh instance they are initialised to 0 instead of
        raising AttributeError.
    """
    # reset the time info
    self.start_time = time.time()
    if reset_num_timesteps or not hasattr(self, "num_timesteps"):
        self.num_timesteps = 0  # steps of sampling
    if reset_num_timesteps or not hasattr(self, "_n_updates"):
        self._n_updates = 0  # steps of policy updating
    self._total_timesteps = total_timesteps
    self._num_timesteps_at_start = self.num_timesteps

    # ----------------- Init the First Observation -----------------
    self._last_obs = self.env.reset()
    self._last_depth = self.env.getDepthImage()
    num_envs = self.env.num_envs
    self._last_goal = np.zeros([num_envs, 3], dtype=np.float32)
    for i in range(num_envs):
        self._last_goal[i] = self.get_random_goal(self._last_obs[i])
    # Map counter, incremented by imitation_learning whenever the map is respawned.
    self._map_id = np.zeros((num_envs, 1), dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
def _sample_action(self) -> np.ndarray:
    """
    Predict an end state for every env with the current policy.

    Inputs are read from ``self._last_obs`` (rows ``[p(3), v_b(3), a_b(3), q_wb(4)]``),
    ``self._last_goal`` (world-frame goal per env) and ``self._last_depth``.

    :return: the end-state predictions of ``self.policy.predict`` as a NumPy array.
    """
    state = self._last_obs.copy()
    goal_w = self._last_goal.copy()
    depth_t = th.from_numpy(self._last_depth).to(self.device)

    # Conjugate of q_wb = [w, x, y, z] is [w, -x, -y, -z] (the body <- world rotation).
    q_bw = -state[:, 9:13]
    q_bw[:, 0] *= -1

    vel_acc_norm_b = self.normalize_obs(state[:, 3:9])
    to_goal_w = goal_w - state[:, 0:3]
    unit_goal_w = to_goal_w / np.linalg.norm(to_goal_w, axis=1)[:, np.newaxis]
    goal_dir_b = rotate(q_bw, unit_goal_w)

    net_input = self.prapare_input_observation(np.hstack((vel_acc_norm_b, goal_dir_b)))
    net_input = net_input.to(self.device)

    endstate_pred, _score = self.policy.predict(depth_t, net_input)
    return endstate_pred.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
def _dump_logs(self) -> None:
    """
    Record timing and progress statistics (console only) and flush the logger at the
    current ``num_timesteps``.
    """
    elapsed = time.time() - self.start_time
    steps_this_run = self.num_timesteps - self._num_timesteps_at_start
    entries = (
        ("time/fps", int(steps_this_run / (elapsed + 1e-8))),
        ("time/minute_elapsed", int(elapsed / 60)),
        ("time/total_timesteps", self.num_timesteps),
        ("train/map_id", self._map_id[0][0]),
    )
    for key, value in entries:
        self.logger.record(key, value, exclude="tensorboard")

    # Pass the number of timesteps for tensorboard
    self.logger.dump(step=self.num_timesteps)
|
|
|
|
|
|
|
|
|
|
def _store_transition(self, replay_buffer):
    """
    Push deep copies of the current obs, goal, depth image and map id into ``replay_buffer``,
    so later in-place updates of these attributes cannot alter stored data.
    """
    snapshot = [deepcopy(item) for item in (self._last_obs, self._last_goal, self._last_depth, self._map_id)]
    replay_buffer.add(*snapshot)
|
2024-10-20 17:01:07 +08:00
|
|
|
|
|
|
|
|
def get_random_goal(self, uav_state=None):
    """
    Sample a random world-frame goal position.

    ``self.env.world_box`` is indexed as ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

    * ``uav_state is None``: uniform sample over a box twice the world's x/y extent, around
      the world centre, with z within +-1 of the mid height. The goal may lie outside the
      world; use ``np.random.uniform(-0.5, 0.5, 3)`` to stay strictly inside.
    * otherwise: a goal 50 units away, in a direction drawn around the body x axis
      (helps imitation learning). Its z is resampled within +-1 of the world's mid height.
      ``uav_state`` is read as position ``[0:3]`` and quaternion ``[9:]``.

    :param uav_state: optional single-UAV state row.
    :return: goal position as an array of 3 values.
    """
    box = self.env.world_box

    if uav_state is not None:
        attitude_wb = uav_state[9:].copy()
        position_w = uav_state[0:3].copy()
        direction_b = np.random.randn(3) + np.array([2, 0, 0])
        direction_b = direction_b / np.linalg.norm(direction_b)
        goal_w = transform(attitude_wb, position_w, 50 * direction_b)
        goal_w[2] = np.random.uniform(-1, 1) + (box[5] + box[2]) / 2
        return goal_w

    center = np.array([box[3] + box[0], box[4] + box[1], box[5] + box[2]]) / 2
    extent = np.array([box[3] - box[0], box[4] - box[1], 1.0])
    return np.random.uniform(-1, 1, 3) * extent + center
|
|
|
|
|
|
|
|
|
|
def reset_state(self):
    """
    Restore a fresh rollout state after training. ``train`` sets env state and map id
    manually, which would otherwise leak into the next rollout (cost, controller, rendering).

    Sets map id -1 for every env (NOTE(review): presumably "use the env's current map"
    — confirm in the env). It then resets the env, refreshes the cached obs and depth image,
    and draws a new goal per env in place in ``self._last_goal``.
    """
    env = self.env
    n_envs = env.num_envs
    env.setMapID(-np.ones((n_envs, 1)))
    self._last_obs = env.reset()
    self._last_depth = env.getDepthImage()
    for k in range(n_envs):
        self._last_goal[k] = self.get_random_goal(self._last_obs[k])
|
|
|
|
|
|
|
|
|
|
def _convert_train_freq(self) -> None:
    """
    Normalise ``self.train_freq`` into a ``TrainFreq(frequency, unit)`` object.

    Accepts an existing ``TrainFreq`` (left untouched), a bare value (meaning "every N
    steps"), or a ``(frequency, unit)`` tuple with unit ``"step"`` or ``"episode"``.

    :raises ValueError: if the unit is unknown or the frequency is not an int.
    """
    if isinstance(self.train_freq, TrainFreq):
        return

    spec = self.train_freq
    # A bare value means "every `spec` steps"; positivity is checked later (collect_rollouts).
    if not isinstance(spec, tuple):
        spec = (spec, "step")

    frequency, raw_unit = spec[0], spec[1]
    try:
        unit = TrainFrequencyUnit(raw_unit)
    except ValueError:
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{raw_unit}'!")

    if not isinstance(frequency, int):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {frequency}")

    self.train_freq = TrainFreq(frequency, unit)
|