37 lines
833 B
Python
37 lines
833 B
Python
import numpy as np
|
|
|
|
|
|
def full_return(raw_data):
    """
    Naively compute the full (undiscounted) return-to-go for every timestep.

    For each timestep ``t`` the result holds the sum of the rewards from
    ``t`` up to the last timestep of the episode that ``t`` belongs to.

    :param raw_data: dict with the keys ``'obs'``, ``'acs'``, ``'rews'`` and
        ``'news'``. ``rews`` must expose ``.shape`` and ``.copy()`` and be
        indexable along its first axis (presumably a numpy array -- confirm
        against the caller). A truthy ``news[i]`` marks timestep ``i`` as the
        first step of a new episode; ``news[0]`` is not consulted because the
        first episode always starts at index 0.
    :return: a new dict with ``'obs'`` and ``'acs'`` passed through unchanged
        and ``'Gts'``, a copy of ``rews`` whose entries are replaced by the
        per-timestep returns. ``rews`` itself is not modified.
    """
    rews = raw_data['rews']
    news = raw_data['news']
    num_timesteps = rews.shape[0]

    data = {
        'obs': raw_data['obs'],
        'acs': raw_data['acs'],
    }

    # Start from a copy so the caller's reward buffer is left untouched and
    # the result keeps the same shape/dtype as ``rews``.
    Gts = rews.copy()

    # Walk backwards, accumulating rewards within the current episode.
    running_return = 0
    for t in range(num_timesteps - 1, -1, -1):
        # Rebind rather than add in place so we never alias a row of ``rews``.
        running_return = running_return + rews[t]
        Gts[t] = running_return
        if news[t]:
            # ``t`` is the first step of its episode, so the step before it
            # belongs to the previous episode: restart the accumulator. This
            # also covers an episode that starts on the very last timestep,
            # whose reward must not leak into the preceding episode.
            running_return = 0

    data['Gts'] = Gts

    return data