diffusion_policy/tests/test_timestamp_accumulator.py
2023-03-07 16:07:15 -05:00

152 lines
4.2 KiB
Python

import sys
import os
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
sys.path.append(ROOT_DIR)
os.chdir(ROOT_DIR)
import numpy as np
import time
from diffusion_policy.common.timestamp_accumulator import (
get_accumulate_timestamp_idxs,
TimestampObsAccumulator,
TimestampActionAccumulator
)
def test_index():
    """Check the index bookkeeping of ``get_accumulate_timestamp_idxs``.

    Four scenarios are run against a grid with ``dt = 1/10``:

    1. 100 timestamps spread over [0, 1], starting at global index 0.
    2. 100 timestamps over [0.5, 1.5], resuming from the ``next_global_idx``
       returned by scenario 1.
    3. Only 3 timestamps over [0, 1], starting again from global index 0.
    4. 100000 timestamps spaced exactly ``dt`` apart and offset by the
       current wall-clock time, to probe floating-point error.
    """
    step = 1 / 10

    def accumulate(stamps, origin, cursor):
        # Single place that spells out the keyword arguments under test.
        return get_accumulate_timestamp_idxs(
            stamps, start_time=origin, dt=step, next_global_idx=cursor
        )

    origin = 0.0
    slots = np.zeros(16)
    visited = []

    # Scenario 1: the first sample must map to the first global index.
    stamps = np.linspace(0, 1, 100)
    src, dst, cursor = accumulate(stamps, origin, 0)
    assert src[0] == 0
    assert dst[0] == 0
    slots[dst] = stamps[src]
    visited.extend(dst)

    # Scenario 2: a later, overlapping batch continues where the first ended.
    stamps = np.linspace(0.5, 1.5, 100)
    src, dst, cursor = accumulate(stamps, origin, cursor)
    slots[dst] = stamps[src]
    visited.extend(dst)
    # The stored timestamps must be strictly increasing across the buffer.
    assert np.all(slots[:-1] < slots[1:])
    # Global indices must have been handed out as 0, 1, 2, ... in order.
    assert np.all(np.array(visited) == np.arange(len(visited)))

    # Scenario 3: start over with far fewer samples than grid steps; the
    # first and last returned local indices must be the first and last sample.
    stamps = np.linspace(0, 1, 3)
    src, dst, cursor = accumulate(stamps, origin, 0)
    assert src[0] == 0
    assert src[-1] == 2

    # Scenario 4: numerical-error check (original author's note: this
    # becomes a problem when eps <= 1e-7). Samples sit exactly on the grid,
    # so local and global indices are expected to coincide.
    origin = time.time()
    stamps = np.arange(100000) * step + origin
    src, dst, cursor = accumulate(stamps, origin, 0)
    assert src == dst
def test_obs_accumulator():
    """Check ``TimestampObsAccumulator`` with dense and sparse input.

    First part: 100 samples spaced 1/100 apart are put into an accumulator
    built with ``dt = 1/10``; 10 entries are expected, and putting the same
    data a second time must not change the result.

    Second part: 10 samples spaced 1/5 apart (coarser than ``dt``) are put
    into a fresh accumulator to exercise fill-in, followed by a second batch
    shifted two samples later.
    """
    step = 1 / 10
    width = 6

    def make_poses(count):
        # (count, width) array whose every column is 0, 1, ..., count - 1.
        column = np.arange(count).reshape((count, 1))
        return np.repeat(column, width, axis=1)

    # --- input denser than the accumulator's step ---
    fine_step = 1 / 100
    count = 100
    t0 = time.time()
    acc = TimestampObsAccumulator(t0, step)
    poses = make_poses(count)
    stamps = np.arange(count) * fine_step + t0
    every_tenth = np.arange(10) * 10

    acc.put({'pose': poses, 'timestamp': stamps}, stamps)
    assert np.all(acc.data['pose'][:, 0] == every_tenth)
    assert len(acc) == 10

    # add the same thing, result shouldn't change
    acc.put({'pose': poses, 'timestamp': stamps}, stamps)
    assert np.all(acc.data['pose'][:, 0] == every_tenth)
    assert len(acc) == 10

    # --- add lower than desired frequency to test fill_in ---
    coarse_step = 1 / 5
    count = 10
    t0 = time.time()
    acc = TimestampObsAccumulator(t0, step)
    poses = make_poses(count)
    stamps = np.arange(count) * coarse_step + t0

    acc.put({'pose': poses, 'timestamp': stamps}, stamps)
    filled = 1 + (count - 1) * 2
    assert len(acc) == filled

    # Same poses, timestamps shifted two samples later: four more entries.
    stamps = (np.arange(count) + 2) * coarse_step + t0
    acc.put({'pose': poses, 'timestamp': stamps}, stamps)
    assert len(acc) == filled + 4
def test_action_accumulator():
    """Check ``TimestampActionAccumulator`` across three successive puts.

    Ten actions on a ``dt = 1/10`` grid are stored first and must be read
    back unchanged. A second batch shifted 0.5 earlier must leave the stored
    timestamps (approximately) as they were. A third batch shifted 0.5 later
    must extend the accumulator to 15 entries whose first column is 0..14.
    """
    step = 1 / 10
    count = 10
    width = 6
    t0 = time.time()
    acc = TimestampActionAccumulator(t0, step)

    # (count, width) array whose every column is 0, 1, ..., count - 1.
    actions = np.repeat(np.arange(count).reshape((count, 1)), width, axis=1)
    stamps = np.arange(count) * step + t0

    # Round 1: data on the grid is stored verbatim.
    acc.put(actions, stamps)
    assert np.all(acc.actions == actions)
    assert np.all(acc.timestamps == stamps)

    # Round 2: batch shifted earlier; timestamps compared with a tolerance.
    acc.put(actions - 5, stamps - 0.5)
    assert np.allclose(acc.timestamps, stamps)

    # Round 3: batch shifted later; the accumulator grows to 15 entries.
    acc.put(actions + 5, stamps + 0.5)
    assert len(acc) == 15
    assert np.all(acc.actions[:, 0] == np.arange(15))
if __name__ == '__main__':
    # Run every test defined in this module when executed as a script;
    # previously only the action-accumulator test was invoked here.
    test_index()
    test_obs_accumulator()
    test_action_accumulator()