✨ use shared buffers for multi processing #2324

Open · joocer opened this issue Jan 26, 2025 · 0 comments

joocer (Contributor) commented on Jan 26, 2025
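A sketch of how Arrow data could be passed between worker processes without pickling it: build the array once, copy its data buffer into a multiprocessing.shared_memory segment, and have each worker reconstruct a zero-copy slice over that segment with pa.Array.from_buffers. The snippet assumes a primitive fixed-width array with no nulls, so only the data buffer (not the validity bitmap) needs to be shared.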

import pyarrow as pa
import multiprocessing.shared_memory as shm
import numpy as np
import multiprocessing

def create_shared_arrow_array():
    """Creates a large Arrow array and stores it in shared memory."""
    data = np.arange(1_000_000, dtype=np.int32)  # Example large dataset
    array = pa.array(data)

    # Extract the raw data buffer; for a primitive array, buffers() returns
    # [validity bitmap, data] and the bitmap is None when there are no nulls.
    buf = array.buffers()[1]
    shared_mem = shm.SharedMemory(create=True, size=buf.size)
    # Cast to unsigned bytes ("B") so the formats match for the memoryview
    # slice assignment, then copy the data into shared memory.
    shared_mem.buf[:buf.size] = memoryview(buf).cast("B")

    # Return the open segment (so it is not reclaimed early) plus the
    # details the workers need to reconstruct the array.
    return shared_mem, array.type, len(array)

def worker(shared_mem_name, dtype, start, end):
    """Worker function to attach to shared memory and slice data."""
    existing_shm = shm.SharedMemory(name=shared_mem_name)
    buf = pa.py_buffer(existing_shm.buf)  # Zero-copy wrap in an Arrow buffer

    # Reconstruct an Arrow array over the shared buffer; the length and
    # `offset` are in elements, so no byte arithmetic is needed.
    array = pa.Array.from_buffers(dtype, end - start, [None, buf], offset=start)

    result = int(array.to_numpy().sum())  # Example computation

    # Drop every reference into the mapping before closing it; otherwise
    # close() raises BufferError because exported views still exist.
    del array, buf
    existing_shm.close()
    return result

def main():
    shared_mem, dtype, length = create_shared_arrow_array()
    num_workers = 4
    chunk_size = length // num_workers  # 1,000,000 divides evenly by 4

    # Define process pool
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.starmap(worker, [
            (shared_mem.name, dtype, i * chunk_size, (i + 1) * chunk_size)
            for i in range(num_workers)
        ])

    print("Results:", results)
    print("Final Sum:", sum(results))

    # Cleanup: close the local mapping, then remove the segment
    shared_mem.close()
    shared_mem.unlink()

if __name__ == "__main__":
    main()
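One limitation of the sketch above is that it shares only the data buffer, so it quietly assumes the array has no nulls. A minimal extension that carries the validity bitmap through the same segment could look like the following; share_array_with_nulls and rebuild_array are hypothetical helper names for this sketch, not existing API.

import multiprocessing.shared_memory as shm

import pyarrow as pa

def share_array_with_nulls(array):
    """Copy the validity bitmap and the data buffer into one shared segment."""
    validity, data = array.buffers()  # primitive layout: [validity bitmap, data]
    validity_size = validity.size if validity is not None else 0
    segment = shm.SharedMemory(create=True, size=validity_size + data.size)
    if validity is not None:
        segment.buf[:validity_size] = memoryview(validity).cast("B")
    segment.buf[validity_size:validity_size + data.size] = memoryview(data).cast("B")
    return segment, array.type, len(array), validity_size, data.size

def rebuild_array(segment_name, dtype, length, validity_size, data_size):
    """Reattach to the segment and rebuild the array, nulls included."""
    existing = shm.SharedMemory(name=segment_name)
    whole = pa.py_buffer(existing.buf)
    validity = whole.slice(0, validity_size) if validity_size else None
    data = whole.slice(validity_size, data_size)
    array = pa.Array.from_buffers(dtype, length, [validity, data])
    # Return the segment too: the caller must drop the Arrow references
    # before closing it, for the same BufferError reason as in the worker.
    return existing, array

For example, share_array_with_nulls(pa.array([1, None, 3], type=pa.int32())) should round-trip a three-element array with one null through rebuild_array.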