# diffusion_policy/tests/test_ring_buffer.py
# 2023-03-07 16:07:15 -05:00
#
# 189 lines
# 5.1 KiB
# Python

import sys
import os
# Parent directory of the directory containing this file. It is appended to
# sys.path (presumably so the `diffusion_policy` package imported below
# resolves when this script is run directly) and made the working
# directory, so relative paths resolve against it. Note: os.chdir is a
# process-wide side effect executed at import time.
ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
sys.path.append(ROOT_DIR)
os.chdir(ROOT_DIR)
import time
import numpy as np
import multiprocessing as mp
from multiprocessing.managers import SharedMemoryManager
from diffusion_policy.shared_memory.shared_memory_ring_buffer import (
SharedMemoryRingBuffer,
SharedAtomicCounter)
def test():
    """Smoke-test single-process put/get on a timestamp ring buffer.

    Puts 30 ``time.perf_counter()`` readings into a ring buffer created with
    ``buffer_size=128`` and prints the result of ``ring_buffer.get()``.
    """
    # Using the manager as a context manager guarantees shutdown() on exit,
    # releasing its server process and shared-memory blocks (previously the
    # manager was started and never shut down).
    with SharedMemoryManager() as shm_manager:
        ring_buffer = SharedMemoryRingBuffer.create_from_examples(
            shm_manager,
            {'timestamp': np.array(0, dtype=np.float64)},
            buffer_size=128
        )
        for _ in range(30):
            ring_buffer.put({
                'timestamp': np.array(time.perf_counter(), dtype=np.float64)
            })
        # Read while the manager (and its shared memory) is still alive.
        print(ring_buffer.get())
def _timestamp_worker(ring_buffer, start_event, stop_event):
    """Producer loop: put ``time.time()`` stamps until *stop_event* is set.

    Args:
        ring_buffer: object with a ``put(dict)`` method; receives
            ``{'timestamp': float64 0-d array}`` on every iteration.
        start_event: set exactly once, right after the first ``put``, so a
            consumer blocked on it does not start reading before any data
            has been written. (Setting it before the first put let the
            consumer race ahead and read an empty buffer.)
        stop_event: checked before each put; the loop exits once it is set.
    """
    announced = False
    while not stop_event.is_set():
        ring_buffer.put({
            'timestamp': np.array(time.time(), dtype=np.float64)
        })
        if not announced:
            start_event.set()
            announced = True
def test_mp():
    """Print cross-process put->get latency for 1000 reads.

    A worker process (``_timestamp_worker``) continuously puts ``time.time()``
    stamps; the main process waits for ``start_event``, then reads the latest
    entry 1000 times and prints ``now - stamp`` for each read.
    """
    # Context manager shuts the manager down on exit (it was leaked before).
    with SharedMemoryManager() as shm_manager:
        ring_buffer = SharedMemoryRingBuffer.create_from_examples(
            shm_manager,
            {'timestamp': np.array(0, dtype=np.float64)},
            get_max_k=1,
            get_time_budget=0.01,
            put_desired_frequency=1000
        )
        start_event = mp.Event()
        stop_event = mp.Event()
        worker = mp.Process(target=_timestamp_worker, args=(
            ring_buffer, start_event, stop_event))
        worker.start()
        # Always stop and join the (non-daemon) worker, even if reading fails;
        # otherwise it loops forever and blocks interpreter exit. Joining
        # inside the `with` also ensures the worker is gone before the
        # shared memory is released.
        try:
            start_event.wait()
            for _ in range(1000):
                t = float(ring_buffer.get()['timestamp'])
                curr_t = time.time()
                print('latency', curr_t - t)
        finally:
            stop_event.set()
            worker.join()
def test_get_last_k():
    """Check that ``get_last_k(k)`` returns the k most recent puts, in order.

    Puts counters 0..99 into a buffer created with ``buffer_size=8``, so
    writes presumably wrap around many times. Once more than k items have
    been written, each ``get_last_k(k)['counter']`` is compared against a
    deque holding the last k values. Finally prints the raw ``'counter'``
    shared array and one last ``get_last_k(k)`` result.
    """
    from collections import deque

    k = 4
    # Context manager shuts the manager down on exit (it was leaked before).
    with SharedMemoryManager() as shm_manager:
        ring_buffer = SharedMemoryRingBuffer.create_from_examples(
            shm_manager,
            {'counter': np.array(0, dtype=np.int64)},
            buffer_size=8
        )
        last_k = deque(maxlen=k)  # reference: the k most recent counters
        for i in range(100):
            ring_buffer.put({
                'counter': np.array(i, dtype=np.int64)
            })
            last_k.append(i)
            if i > k:
                result = ring_buffer.get_last_k(k)['counter']
                assert np.allclose(result, last_k)
        print(ring_buffer.shared_arrays['counter'].get())
        print(ring_buffer.get_last_k(k))
def test_timing():
    """Print the per-iteration rate of non-blocking puts paced at ~150 Hz.

    Each of 1000 iterations does ``put(..., wait=False)``. Every 10th
    iteration after the first also calls ``get_last_k(8)``. Each iteration
    then sleeps until slot ``i + 1`` of a fixed 150 Hz schedule anchored at
    the loop start, and prints the achieved rate. The buffer is configured
    with ``put_desired_frequency=100``, so the loop runs faster than that
    (presumably on purpose, to exercise ``wait=False``).
    """
    dt = 1 / 150
    # Context manager shuts the manager down on exit (it was leaked before).
    with SharedMemoryManager() as shm_manager:
        ring_buffer = SharedMemoryRingBuffer.create_from_examples(
            shm_manager,
            {'counter': np.array(0, dtype=np.int64)},
            get_max_k=8,
            get_time_budget=0.1,
            put_desired_frequency=100
        )
        print('buffer_size', ring_buffer.buffer_size)
        t_init = time.monotonic()
        for i in range(1000):
            t_start = time.monotonic()
            ring_buffer.put({
                'counter': np.array(i, dtype=np.int64)
            }, wait=False)
            if i % 10 == 0 and i > 0:
                # Result unused: only the cost of the read is being timed.
                ring_buffer.get_last_k(8)
            t_end = time.monotonic()
            # Absolute schedule: a late iteration does not push back later ones.
            desired_t = (i + 1) * dt + t_init
            if desired_t > t_end:
                time.sleep(desired_t - t_end)
            hz = 1 / (time.monotonic() - t_start)
            print(f'{hz}Hz')
def _timestamp_image_worker(ring_buffer, img_shape, dt, start_event, stop_event):
    """Producer loop: put an image, a wall-clock stamp and a counter every *dt* s.

    Args:
        ring_buffer: object with a ``put(dict)`` method; receives
            ``{'img': ones(img_shape, uint8), 'timestamp': time.time(),
            'counter': i}`` with ``i`` = 0, 1, 2, ...
        img_shape: shape of the (constant, all-ones) uint8 image.
        dt: target period in seconds. Put ``i`` is followed by a sleep until
            ``t_init + (i + 1) * dt``. Because this is an absolute schedule,
            a slow iteration is made up for rather than delaying every later
            one.
        start_event: set exactly once, right after the first ``put``, so
            consumers never wait on a signal that precedes any data.
        stop_event: checked before each put; the loop exits once it is set.
    """
    t_init = time.monotonic()
    image = np.ones(img_shape, dtype=np.uint8)
    i = 0
    while not stop_event.is_set():
        ring_buffer.put({
            'img': image,
            'timestamp': time.time(),
            'counter': i
        })
        if i == 0:
            start_event.set()
        remaining = t_init + (i + 1) * dt - time.monotonic()
        if remaining > 0:
            time.sleep(remaining)
        i += 1
def test_timing_mp():
    """Benchmark read latency and read rate against a multiprocess image producer.

    A worker process (``_timestamp_image_worker``) puts a 1920x1080x3 uint8
    image, a ``time.time()`` stamp and a counter, targeting ``hz`` puts per
    second. The main process polls ``n_polls`` times. Whenever at least ``k``
    entries exist, it reads them with ``get_last_k`` and prints the latency
    of the newest stamp. At the end it prints the rate of *successful*
    reads.
    """
    hz = 200
    img_shape = (1920, 1080, 3)
    k = 1
    n_polls = 1000
    # Context manager shuts the manager down on exit (it was leaked before).
    with SharedMemoryManager() as shm_manager:
        ring_buffer = SharedMemoryRingBuffer.create_from_examples(
            shm_manager,
            examples={
                'img': np.zeros(img_shape, dtype=np.uint8),
                'timestamp': time.time(),
                'counter': 0
            },
            get_max_k=60,
            get_time_budget=0.02,
            put_desired_frequency=hz
        )
        start_event = mp.Event()
        stop_event = mp.Event()
        worker = mp.Process(target=_timestamp_image_worker, args=(
            ring_buffer, img_shape, 1 / hz, start_event, stop_event))
        worker.start()
        # Always stop and join the non-daemon worker, even if reading fails.
        try:
            start_event.wait()
            out = None
            n_gets = 0
            t_start = time.monotonic()
            for _ in range(n_polls):
                if ring_buffer.count < k:
                    # Not enough entries yet: yield the CPU and poll again.
                    time.sleep(0)
                    continue
                # The previous result is passed back as `out`, presumably so
                # get_last_k can fill it in place instead of allocating.
                out = ring_buffer.get_last_k(k=k, out=out)
                n_gets += 1
                t = float(out['timestamp'][-1])
                curr_t = time.time()
                print('latency', curr_t - t)
            t_end = time.monotonic()
            # Only polls that actually read data count toward the rate;
            # previously skipped polls were counted as gets too.
            print('Get Hz', n_gets / (t_end - t_start))
        finally:
            stop_event.set()
            worker.join()
if __name__ == '__main__':
    # Script entry point: run only the multiprocess image benchmark. Call a
    # different test_* function here (e.g. test_mp) to run it by hand.
    test_timing_mp()