# Tianshou/tianshou/data/adv_estimate.py
#
# 37 lines
# 833 B
# Python

import numpy as np
def full_return(raw_data):
    """
    Naively compute the full (undiscounted) return for every timestep.

    The return at step ``t`` is the sum of the rewards from ``t`` up to the
    last step of the episode that contains ``t``. The final episode in the
    batch is summed up to the last available timestep, whether or not it
    actually ended there.

    :param raw_data: dict with keys ``'obs'``, ``'acs'``, ``'rews'`` and
        ``'news'``. ``rews`` must be a numpy array indexed by timestep along
        axis 0. A truthy ``news[i]`` marks step ``i`` as the first step of a
        new episode (``news[0]`` is never read).
    :return: dict with ``'obs'`` and ``'acs'`` passed through unchanged and
        ``'Gts'``, a copy of ``rews`` holding the per-step returns.
    """
    obs = raw_data['obs']
    acs = raw_data['acs']
    rews = raw_data['rews']
    news = raw_data['news']
    num_timesteps = rews.shape[0]

    # Work on a copy so the caller's reward array is left untouched.
    Gts = rews.copy()

    # Sweep backwards: a step inherits the return of the following step only
    # when that following step belongs to the same episode. This also keeps a
    # trailing one-step episode (news[-1] truthy) from leaking its reward into
    # the episode before it.
    for t in range(num_timesteps - 2, -1, -1):
        if not news[t + 1]:
            Gts[t] += Gts[t + 1]

    return {'obs': obs, 'acs': acs, 'Gts': Gts}