""" Training Strategy supervised learning, imitation learning, testing, rollout """ import time from copy import deepcopy import os import random import cv2 import numpy as np import torch as th from torch.nn import functional as F from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq, TrainFrequencyUnit from stable_baselines3.common.utils import should_collect_more_steps, get_schedule_fn, configure_logger, update_learning_rate from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.utils import get_device # ----------- from flightpolicy.yopo.yopo_policy import YopoPolicy from flightpolicy.yopo.dataloader import YopoDataset from torch.utils.data import DataLoader from flightpolicy.yopo.primitive_utils import transform, rotate, transform_inv, rotate_inv from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive from flightpolicy.yopo.buffers import ReplayBuffer from ruamel.yaml import YAML class YopoAlgorithm: def __init__( self, env=None, learning_rate=0.001, is_imitation=False, buffer_size=1_000_000, learning_starts=100, batch_size=256, unselect=0.0, loss_weight=[], train_freq=(1, "step"), change_env_freq=-1, gradient_steps=1, policy_kwargs=None, tensorboard_log=None, verbose=0, max_grad_norm=10, ): # env self.observation_dim = env.observation_dim self.action_dim = env.action_dim self.n_envs = env.num_envs self.env = env # training self.learning_rate = learning_rate self.batch_size = batch_size self.max_grad_norm = max_grad_norm self.unselect = unselect self.loss_weight = loss_weight self.device = get_device('auto') self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs # imitation learning self.is_imitation = is_imitation self.buffer_size = buffer_size self.train_freq = train_freq self.change_env_freq = change_env_freq self.learning_starts = learning_starts self.gradient_steps = gradient_steps self.freq_reset = False self.replay_buffer = None # logger self.verbose = verbose self.tensorboard_log = tensorboard_log self.logger = configure_logger(self.verbose, self.tensorboard_log, "YOPO") # trajectory cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r')) self.lattice_space = LatticeParam(cfg) self.lattice_primitive = LatticePrimitive(self.lattice_space) self._setup_model() def _setup_model(self): self.lr_schedule = get_schedule_fn(self.learning_rate) # buffer: pos, quat, vel, acc, depth if self.replay_buffer is None and self.is_imitation: self.replay_buffer = ReplayBuffer( self.buffer_size, self.observation_dim, (self.env.network_width, self.env.network_height), device=self.device, n_envs=self.n_envs, ) print("Loading Network...") self.policy = YopoPolicy( observation_dim=self.observation_dim, action_dim=self.action_dim, lattice_space=self.lattice_space, lattice_primitive=self.lattice_primitive, lr_schedule=self.lr_schedule, train_env=self.env, device=self.device, **self.policy_kwargs ) self.policy = self.policy.to(self.device) print("Network Loaded!") if self.is_imitation: self._convert_train_freq() def supervised_learning(self, epoch, log_interval): self.policy.set_training_mode(True) data_loader = DataLoader(YopoDataset(), batch_size=self.batch_size, shuffle=True, num_workers=0) n_updates = 0 start_time = time.time() for epoch_ in range(epoch): cost_losses = [] # Performance (score) of prediction score_losses = [] # Accuracy of the predicted score for step, (depth, pos, quat, obs_b, map_id) in enumerate(data_loader): # obs: body frame if depth.shape[0] != 
                if depth.shape[0] != self.batch_size:  # batch size == num of envs
                    continue
                n_updates = n_updates + 1
                depth = depth.to(self.device)
                obs_b = obs_b.numpy()
                goal_dir = obs_b[:, 6:9]
                goal_w = transform(quat.numpy(), pos.numpy(), 10 * goal_dir)  # Rwb * g_b + t_wb
                vel_w = rotate(quat.numpy(), obs_b[:, 0:3])
                acc_w = rotate(quat.numpy(), obs_b[:, 3:6])
                self.env.setState(pos.numpy(), vel_w, acc_w, quat.numpy())
                self.env.setGoal(goal_w)
                self.env.setMapID(map_id.numpy())

                obs_b[:, 0:6] = self.normalize_obs(obs_b[:, 0:6])
                obs_norm_input = self.prapare_input_observation(obs_b)
                obs_norm_input = obs_norm_input.to(self.device)

                endstate_score_predictions, cost_labels = self.policy.inference(depth, obs_norm_input)
                score_labels = cost_labels.clone().detach()
                cost_labels_record = th.mean(cost_labels)
                cost_labels_filtered = self.cost_filter(cost_labels)
                cost_loss = th.mean(cost_labels_filtered)
                score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
                loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
                cost_losses.append(self.loss_weight[0] * cost_labels_record.item())
                score_losses.append(self.loss_weight[1] * score_loss.item())

                # Optimize the policy
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip gradient norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

                if log_interval is not None and n_updates % log_interval[0] == 0:
                    self.logger.record("time/epoch", epoch_, exclude="tensorboard")
                    self.logger.record("time/steps", n_updates, exclude="tensorboard")
                    self.logger.record("time/batch_fps", log_interval[0] / (time.time() - start_time), exclude="tensorboard")
                    self.logger.record("train/trajectory_cost", np.mean(cost_losses))
                    self.logger.record("train/score_loss", np.mean(score_losses))
                    self.logger.dump(step=n_updates)
                    cost_losses = []
                    score_losses = []
                    start_time = time.time()

                if log_interval is not None and n_updates % log_interval[1] == 0:
                    policy_path = self.logger.get_dir() + "/Policy"
                    os.makedirs(policy_path, exist_ok=True)
                    path = policy_path + "/epoch{}_iter{}.pth".format(epoch_, step)
                    th.save({"state_dict": self.policy.state_dict(),
                             "data": self.policy.get_constructor_parameters()}, path)

    def imitation_learning(
            self,
            total_timesteps,
            log_interval,
            reset_num_timesteps=True,
    ):
        # 0. setup learn and init the first observation
        self._setup_learn(total_timesteps, reset_num_timesteps)

        while self.num_timesteps < total_timesteps:
            # 1. Rollout and Collect Data into Buffer
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                replay_buffer=self.replay_buffer
            )
            if rollout.continue_training is False:
                break

            # 2. Train the Policy
            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                # If no `gradient_steps` is specified, do as many gradient steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
                if gradient_steps > 0:  # Special case when the user passes `gradient_steps=0`
                    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
                    self.reset_state()

            iteration = int(self.num_timesteps / (self.train_freq.frequency * self.env.num_envs))

            # 3. reset the environment
            if self.change_env_freq > 0 and iteration % self.change_env_freq == 0:
                self.env.spawnTreesAndSavePointcloud()
                self._map_id = self._map_id + 1
                self.reset_state()
            # 4. print the log and save the weights
            if log_interval is not None and iteration % log_interval[0] == 0:
                self._dump_logs()
            if log_interval is not None and iteration % log_interval[1] == 0:
                policy_path = self.logger.get_dir() + "/Policy"
                os.makedirs(policy_path, exist_ok=True)
                path = policy_path + "/epoch0_iter{}.pth".format(iteration)
                th.save({"state_dict": self.policy.state_dict(),
                         "data": self.policy.get_constructor_parameters()}, path)

    def test_policy(self, num_rollouts: int = 10):
        max_ep_length = 400
        self.policy.set_training_mode(False)
        for n_roll in range(num_rollouts):
            obs, done, ep_len = self.env.reset(), False, 0
            costs = []
            # Randomly initialize the position and goal on the map.
            random_y_goal = 20 * random.uniform(-1, 1) + 20
            random_y = 20 * random.uniform(-1, 1) + 20
            goal_w = np.array([[20, random_y_goal, 2]])
            obs = np.array([[-20, random_y, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
            self.env.setGoal(goal_w)
            self.env.setState(np.array([[-20, random_y, 2]]), np.array([[0, 0, 0]]),
                              np.array([[0, 0, 0]]), np.array([[1, 0, 0, 0]]))
            self.env.render()
            while not (done or (ep_len >= max_ep_length)):
                depth = self.env.getDepthImage()
                depth_vis = cv2.resize(depth[0][0], (320, 180))
                cv2.imshow("depth", depth_vis)
                cv2.waitKey(10)
                depth = th.from_numpy(depth).to(self.device)

                # transform observation to body frame
                quat_bw = -obs[:, 9:13]  # inverse of quat: [w, -x, -y, -z]
                quat_bw[:, 0] = -quat_bw[:, 0]
                goal_dir_w = (goal_w - obs[:, 0:3]) / np.linalg.norm(goal_w - obs[:, 0:3])
                goal_dir_b = rotate(quat_bw, goal_dir_w)
                vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])
                obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))
                obs_norm_input = self.prapare_input_observation(obs_norm_b)
                obs_norm_input = obs_norm_input.to(self.device)

                endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input)
                endstate_pred = endstate_pred.cpu().numpy()
                # obs: p_wb, v_b, a_b, q_wb; endstate_pred: pva in body frame
                obs, rew, done = self.env.step(endstate_pred)
                costs.append(rew)
                ep_len += 1
            print("round ", n_roll, ", total steps:", len(costs), ", avg cost:", sum(costs) / len(costs))
        self.env.disconnectUnity()

    def train(self, gradient_steps: int, batch_size: int) -> None:
        """
        Imitation learning: sample data from the replay buffer and train the policy
        """
        self.policy.set_training_mode(True)  # Switch to train mode (this affects batch norm / dropout)
        update_learning_rate(self.policy.optimizer, self.lr_schedule(self._current_progress_remaining))
        cost_losses = []
        score_losses = []
        # dy, dz, r, p, vx, vy, vz
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size)
            depth = th.from_numpy(replay_data.depths).to(self.device)
            pos = replay_data.observations[:, 0:3]
            vel_acc_b = replay_data.observations[:, 3:9]
            quat_wb = replay_data.observations[:, 9:13]
            goal_w = replay_data.goals
            map_id = replay_data.map_id

            goal_dir_w = (goal_w - pos) / np.linalg.norm(goal_w - pos, axis=1)[:, np.newaxis]
            goal_dir_b = rotate_inv(quat_wb, goal_dir_w)
            vel_w = rotate(quat_wb, vel_acc_b[:, 0:3])
            acc_w = rotate(quat_wb, vel_acc_b[:, 3:6])
            self.env.setState(pos, vel_w, acc_w, quat_wb)
            self.env.setGoal(goal_w)
            self.env.setMapID(map_id)

            vel_acc_norm_b = self.normalize_obs(vel_acc_b)
            obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))
            obs_norm_input = self.prapare_input_observation(obs_norm_b)
            obs_norm_input = obs_norm_input.to(self.device)

            endstate_score_predictions, cost_labels = self.policy.inference(depth, obs_norm_input)
            score_labels = cost_labels.clone().detach()
            cost_labels_record = th.mean(cost_labels)
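            # The detached trajectory cost serves as the regression target for the
            # score head, while the filtered cost itself is minimized directly.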
            cost_labels_filtered = self.cost_filter(cost_labels)
            cost_loss = th.mean(cost_labels_filtered)
            score_loss = F.smooth_l1_loss(endstate_score_predictions[:, 9, :], score_labels)
            loss = self.loss_weight[0] * cost_loss + self.loss_weight[1] * score_loss
            cost_losses.append(self.loss_weight[0] * cost_labels_record.item())
            score_losses.append(self.loss_weight[1] * score_loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/trajectory_cost", np.mean(cost_losses))
        self.logger.record("train/score_loss", np.mean(score_losses))

    def collect_rollouts(
            self,
            env,
            train_freq,
            replay_buffer,
    ) -> RolloutReturn:
        self.policy.set_training_mode(False)
        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."
        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

        continue_training = True
        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            # 1. predict the endstate using the latest policy
            sampled_endstate = self._sample_action()

            # 2. perform the action and get the new observation
            new_obs, rewards, dones = env.step(sampled_endstate)
            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # 3. store the last obs, depth, and goal
            self._store_transition(replay_buffer)
            self._current_progress_remaining = 1.0 - float(self.num_timesteps) / float(self._total_timesteps)

            # 4. update the obs and depth, and reset the goal for any done env
            self._last_obs = new_obs
            self._last_depth = env.getDepthImage()
            for idx, done in enumerate(dones):
                if done:
                    num_collected_episodes += 1
                    # reset goal for the 'done' env
                    self._last_goal[idx] = self.get_random_goal(self._last_obs[idx])

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

    def prapare_input_observation(self, obs):
        """
        Convert the observation from the body frame to the primitive frame, then concatenate it
        with the depth features (to ensure translational invariance).
        """
        obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num),
                             dtype=np.float32)
        id = 0
        v_b, a_b, g_b = obs[:, 0:3], obs[:, 3:6], obs[:, 6:9]
        for i in range(self.lattice_space.vertical_num - 1, -1, -1):
            for j in range(self.lattice_space.horizon_num - 1, -1, -1):
                Rbp = self.lattice_primitive.getRotation(id)
                obs_return[:, 0:3, i, j] = np.dot(v_b, Rbp)  # v_p
                obs_return[:, 3:6, i, j] = np.dot(a_b, Rbp)  # a_p
                obs_return[:, 6:9, i, j] = np.dot(g_b, Rbp)  # g_p
                # obs_return[:, 0:6, i, j] = self.normalize_obs(obs_return[:, 0:6, i, j])
                id = id + 1
        return th.from_numpy(obs_return)

    def unnormalize_obs(self, vel_acc_norm):
        vel = vel_acc_norm[:, 0:3] * self.lattice_space.vel_max
        acc = vel_acc_norm[:, 3:6] * self.lattice_space.acc_max
        return np.hstack((vel, acc))

    def normalize_obs(self, vel_acc):
        vel_norm = vel_acc[:, 0:3] / self.lattice_space.vel_max
        acc_norm = vel_acc[:, 3:6] / self.lattice_space.acc_max
        return np.hstack((vel_norm, acc_norm))

    def cost_filter(self, costs_):
        # costs_ = costs.clone()  # NOTE: numpy.ndarray is passed by reference!
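        # Zero out the `unselect` fraction of the highest-cost primitives in each row,
        # so a few very bad candidates do not dominate the mean cost loss.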
        if self.unselect <= 0 or self.unselect >= 1:
            return costs_
        # filter the negative samples
        rows, cols = costs_.size()
        unselect = int(cols * self.unselect)
        for i in range(rows):
            row = costs_[i]
            _, indices = th.topk(row, unselect)
            costs_[i][indices] = 0.0
        return costs_

    def _setup_learn(self, total_timesteps, reset_num_timesteps=True):
        # reset the time info
        self.start_time = time.time()
        if reset_num_timesteps:
            self.num_timesteps = 0  # steps of sampling
            self._n_updates = 0     # steps of policy updating
        self._total_timesteps = total_timesteps
        self._num_timesteps_at_start = self.num_timesteps

        # ----------------- Init the First Observation -----------------
        self._last_obs = self.env.reset()
        self._last_depth = self.env.getDepthImage()
        self._last_goal = np.zeros([self.env.num_envs, 3], dtype=np.float32)
        for i in range(self.env.num_envs):
            self._last_goal[i] = self.get_random_goal(self._last_obs[i])
        self._map_id = np.zeros((self.env.num_envs, 1), dtype=np.float32)

    def _sample_action(self) -> np.ndarray:
        """
        Use the pretrained or current model to sample the actions (endstate).
            self._last_obs: last state obs [p, v, a, q]
            self._last_depth: last depth image
        """
        obs = self._last_obs.copy()
        goal_w = self._last_goal.copy()
        depth = th.from_numpy(self._last_depth).to(self.device)
        # quat [w, x, y, z]; inverse of a unit quat: [w, -x, -y, -z]
        quat_bw = -obs[:, 9:13]
        quat_bw[:, 0] = -quat_bw[:, 0]
        vel_acc_norm_b = self.normalize_obs(obs[:, 3:9])
        goal_dir_w = (goal_w - obs[:, 0:3]) / np.linalg.norm(goal_w - obs[:, 0:3], axis=1)[:, np.newaxis]
        goal_dir_b = rotate(quat_bw, goal_dir_w)
        obs_norm_b = np.hstack((vel_acc_norm_b, goal_dir_b))
        obs_norm_input = self.prapare_input_observation(obs_norm_b)
        obs_norm_input = obs_norm_input.to(self.device)
        endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input)
        endstate_pred = endstate_pred.cpu().numpy()
        return endstate_pred

    def _dump_logs(self) -> None:
        """
        Write log.
        """
        time_elapsed = time.time() - self.start_time
        fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
        self.logger.record("time/fps", fps, exclude="tensorboard")
        self.logger.record("time/minute_elapsed", int(time_elapsed / 60), exclude="tensorboard")
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.record("train/map_id", self._map_id[0][0], exclude="tensorboard")
        # Pass the number of timesteps for tensorboard
        self.logger.dump(step=self.num_timesteps)

    def _store_transition(self, replay_buffer):
        # Avoid modification by reference
        obs = deepcopy(self._last_obs)
        goal = deepcopy(self._last_goal)
        depth = deepcopy(self._last_depth)
        map_id = deepcopy(self._map_id)
        replay_buffer.add(obs, goal, depth, map_id)

    def get_random_goal(self, uav_state=None):
        world = self.env.world_box
        # 1. Use a random goal in the map
        if uav_state is None:
            world_center = np.array([world[3] + world[0], world[4] + world[1], world[5] + world[2]]) / 2
            world_scale = np.array([world[3] - world[0], world[4] - world[1], 1.0])
            # The goal may lie outside the world; to keep it strictly inside, use np.random.uniform(-0.5, 0.5, 3)
            random_numbers = np.random.uniform(-1, 1, 3)
            random_goal = random_numbers * world_scale + world_center
        # 2. Use a goal in front of the UAV (for better imitation learning)
        else:
            q_wb = uav_state[9:].copy()
            p_wb = uav_state[0:3].copy()
            goal = np.random.randn(3) + np.array([2, 0, 0])
            goal_dir = goal / np.linalg.norm(goal)
            random_goal_b = 50 * goal_dir
            random_goal_w = transform(q_wb, p_wb, random_goal_b)
            random_goal_w[2] = np.random.uniform(-1, 1) * 1 + (world[5] + world[2]) / 2
            random_goal = random_goal_w
        return random_goal

    def reset_state(self):
        """
        Reset the state and map_id after every train step. Both are set manually during training,
        which would otherwise affect the cost, controller, image rendering, and other parts of the next rollout.
        """
        self.env.setMapID(-np.ones((self.env.num_envs, 1)))
        self._last_obs = self.env.reset()
        self._last_depth = self.env.getDepthImage()
        for i in range(self.env.num_envs):
            self._last_goal[i] = self.get_random_goal(self._last_obs[i])

    def _convert_train_freq(self) -> None:
        """
        Convert the `train_freq` parameter (int or tuple) to a TrainFreq object.
        """
        if not isinstance(self.train_freq, TrainFreq):
            train_freq = self.train_freq
            # The value of the train frequency will be checked later
            if not isinstance(train_freq, tuple):
                train_freq = (train_freq, "step")
            try:
                train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
            except ValueError:
                raise ValueError(
                    f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
            if not isinstance(train_freq[0], int):
                raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
            self.train_freq = TrainFreq(*train_freq)
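
# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# `make_training_env()` is a hypothetical placeholder for whatever VecEnv
# factory your setup provides; the environment must expose the interface used
# above (observation_dim, action_dim, num_envs, network_width/network_height,
# world_box, setState/setGoal/setMapID, getDepthImage, step, reset, ...).
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     env = make_training_env()                      # hypothetical factory
#     algo = YopoAlgorithm(
#         env=env,
#         learning_rate=1e-3,
#         batch_size=64,
#         loss_weight=[1.0, 1.0],                    # [cost weight, score weight]
#         is_imitation=False,                        # supervised learning from YopoDataset
#     )
#     # log_interval = (log every N updates, save every M updates)
#     algo.supervised_learning(epoch=10, log_interval=(50, 1000))
#     # For fine-tuning on rollouts collected by the current policy:
#     # algo.imitation_learning(total_timesteps=100_000, log_interval=(10, 100))
#     # algo.test_policy(num_rollouts=5)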