# Source path: Equidiff/equidiff/equi_diffpo/shared_memory/shared_memory_queue.py
# Commit c1f1d32 ("mimicgen") by Lillianwei
from typing import Dict, List, Union
import numbers
from queue import (Empty, Full)
from multiprocessing.managers import SharedMemoryManager
import numpy as np
from equi_diffpo.shared_memory.shared_memory_util import ArraySpec, SharedAtomicCounter
from equi_diffpo.shared_memory.shared_ndarray import SharedNDArray
class SharedMemoryQueue:
    """
    A Lock-Free FIFO Shared Memory Data Structure.
    Stores a sequence of dict of numpy arrays.

    Each key described by ``array_specs`` is backed by one shared-memory array
    of shape ``(buffer_size, *spec.shape)``, used as a circular buffer. Two
    shared counters record how many items have ever been written and read;
    their difference is the number of items currently queued, and
    ``count % buffer_size`` is the slot index.

    NOTE(review): ``put`` and the ``get*`` methods load the counters, work on
    the slot, and then ``add`` to a counter without any lock or
    compare-and-swap. Nothing here coordinates several concurrent producers
    (or several concurrent consumers). Presumably the intended use is a single
    producer and a single consumer; confirm with callers.
    """

    def __init__(self,
                 shm_manager: SharedMemoryManager,
                 array_specs: List['ArraySpec'],
                 buffer_size: int):
        """
        Allocate the shared counters and one circular buffer per spec.

        Args:
            shm_manager: manager used to allocate counters and arrays.
            array_specs: one spec per key. Each provides ``name``, ``shape``
                (any sequence of ints) and ``dtype``.
            buffer_size: maximum number of items the queue can hold.

        Raises:
            ValueError: if two specs share the same ``name``.
        """
        # Total number of items ever written / ever read.
        write_counter = SharedAtomicCounter(shm_manager)
        read_counter = SharedAtomicCounter(shm_manager)

        # Allocate shared memory: one (buffer_size, *shape) array per key.
        shared_arrays = dict()
        for spec in array_specs:
            key = spec.name
            if key in shared_arrays:
                # Explicit check instead of ``assert`` so it also runs under -O.
                raise ValueError(f'Duplicate array spec name {key!r}')
            array = SharedNDArray.create_from_shape(
                mem_mgr=shm_manager,
                shape=(buffer_size,) + tuple(spec.shape),
                dtype=spec.dtype)
            shared_arrays[key] = array

        self.buffer_size = buffer_size
        self.array_specs = array_specs
        self.write_counter = write_counter
        self.read_counter = read_counter
        self.shared_arrays = shared_arrays

    @classmethod
    def create_from_examples(cls,
                             shm_manager: SharedMemoryManager,
                             examples: Dict[str, Union[np.ndarray, numbers.Number]],
                             buffer_size: int):
        """
        Build a queue whose per-key shape and dtype are taken from examples.

        Args:
            shm_manager: manager used to allocate shared memory.
            examples: maps each key to an example value. An ndarray contributes
                its shape and dtype. A scalar (``numbers.Number`` or
                ``np.bool_``) contributes shape ``()`` and ``np.dtype(type(v))``.
            buffer_size: maximum number of items the queue can hold.

        Raises:
            TypeError: if a value has an unsupported type, or its dtype holds
                Python objects (these cannot live in shared memory).
        """
        specs = list()
        for key, value in examples.items():
            if isinstance(value, np.ndarray):
                shape = value.shape
                dtype = value.dtype
            elif isinstance(value, (numbers.Number, np.bool_)):
                # np.bool_ is not registered as a numbers.Number, so list it
                # explicitly.
                shape = tuple()
                dtype = np.dtype(type(value))
            else:
                raise TypeError(f'Unsupported type {type(value)}')
            # Catches object arrays, structured dtypes with object fields,
            # and scalars such as Decimal that map to the object dtype.
            if dtype.hasobject:
                raise TypeError(
                    f'Unsupported object dtype {dtype} for key {key!r}')
            spec = ArraySpec(
                name=key,
                shape=shape,
                dtype=dtype
            )
            specs.append(spec)

        obj = cls(
            shm_manager=shm_manager,
            array_specs=specs,
            buffer_size=buffer_size
        )
        return obj

    def qsize(self):
        """
        Return the number of queued items (writes minus reads).

        The two counters are loaded separately, so the value is only a
        snapshot.
        """
        read_count = self.read_counter.load()
        write_count = self.write_counter.load()
        n_data = write_count - read_count
        return n_data

    def empty(self):
        """Return True when no items are queued."""
        n_data = self.qsize()
        return n_data <= 0

    def clear(self):
        """Discard all queued items by moving the read counter to the write counter."""
        self.read_counter.store(self.write_counter.load())

    def put(self, data: Dict[str, Union[np.ndarray, numbers.Number]]):
        """
        Append one item to the queue.

        Args:
            data: maps keys (which must exist in the queue) to values. Arrays
                are assigned into the slot directly; other values are converted
                with ``np.array(value, dtype=<slot dtype>)``. Keys absent from
                ``data`` are not written, so that slot keeps whatever it held
                before.

        Raises:
            queue.Full: if ``buffer_size`` items are already queued.
            KeyError: if ``data`` contains a key the queue does not have.
        """
        read_count = self.read_counter.load()
        write_count = self.write_counter.load()
        n_data = write_count - read_count
        if n_data >= self.buffer_size:
            raise Full()

        next_idx = write_count % self.buffer_size

        # write to shared memory
        for key, value in data.items():
            arr: np.ndarray
            arr = self.shared_arrays[key].get()
            if isinstance(value, np.ndarray):
                arr[next_idx] = value
            else:
                arr[next_idx] = np.array(value, dtype=arr.dtype)

        # Publish the item only after all of its fields have been written.
        self.write_counter.add(1)

    def get(self, out=None) -> Dict[str, np.ndarray]:
        """
        Remove and return the oldest item.

        Args:
            out: optional dict of preallocated arrays, one per key, each shaped
                like a single item. It is filled in place with ``np.copyto``.

        Returns:
            ``out`` (or a freshly allocated dict) holding a copy of the item.

        Raises:
            queue.Empty: if no items are queued.
        """
        write_count = self.write_counter.load()
        read_count = self.read_counter.load()
        n_data = write_count - read_count
        if n_data <= 0:
            raise Empty()

        if out is None:
            out = self._allocate_empty()

        next_idx = read_count % self.buffer_size
        for key, value in self.shared_arrays.items():
            arr = value.get()
            np.copyto(out[key], arr[next_idx])

        # Release the slot only after its contents have been copied out.
        self.read_counter.add(1)
        return out

    def get_k(self, k, out=None) -> Dict[str, np.ndarray]:
        """
        Remove and return the ``k`` oldest items, stacked along a new axis 0.

        Args:
            k: number of items to pop, with ``0 <= k <= qsize()``.
            out: optional dict of preallocated arrays with at least ``k`` rows
                per key.

        Raises:
            queue.Empty: if no items are queued.
            ValueError: if ``k`` is negative or exceeds the number of queued
                items. This replaces an ``assert`` that vanished under -O and
                would then let the read counter overtake the write counter.
        """
        write_count = self.write_counter.load()
        read_count = self.read_counter.load()
        n_data = write_count - read_count
        if n_data <= 0:
            raise Empty()
        if k < 0 or k > n_data:
            raise ValueError(
                f'Cannot get {k} items: {n_data} item(s) are queued')

        out = self._get_k_impl(k, read_count, out=out)
        self.read_counter.add(k)
        return out

    def get_all(self, out=None) -> Dict[str, np.ndarray]:
        """
        Remove and return every item queued at call time, stacked along axis 0.

        Raises:
            queue.Empty: if no items are queued.
        """
        write_count = self.write_counter.load()
        read_count = self.read_counter.load()
        n_data = write_count - read_count
        if n_data <= 0:
            raise Empty()

        out = self._get_k_impl(n_data, read_count, out=out)
        self.read_counter.add(n_data)
        return out

    def _get_k_impl(self, k, read_count, out=None) -> Dict[str, np.ndarray]:
        """
        Copy ``k`` consecutive slots starting at ``read_count``.

        Indexing wraps around the end of the circular buffer. The counters are
        not modified; callers advance the read counter themselves.
        """
        if out is None:
            out = self._allocate_empty(k)

        curr_idx = read_count % self.buffer_size
        for key, value in self.shared_arrays.items():
            arr = value.get()
            target = out[key]

            # Contiguous part: from the current slot up to the buffer end.
            start = curr_idx
            end = min(start + k, self.buffer_size)
            target_start = 0
            target_end = (end - start)
            target[target_start: target_end] = arr[start:end]

            remainder = k - (end - start)
            if remainder > 0:
                # Wrap around: the rest comes from the buffer start.
                start = 0
                end = start + remainder
                target_start = target_end
                target_end = k
                target[target_start: target_end] = arr[start:end]

        return out

    def _allocate_empty(self, k=None):
        """
        Allocate uninitialised output arrays, one per spec.

        Each array has shape ``spec.shape``, or ``(k, *spec.shape)`` when ``k``
        is given.
        """
        result = dict()
        for spec in self.array_specs:
            # tuple() mirrors __init__, so list-valued spec shapes also work
            # when a leading k axis is prepended.
            shape = tuple(spec.shape)
            if k is not None:
                shape = (k,) + shape
            result[spec.name] = np.empty(
                shape=shape, dtype=spec.dtype)
        return result