Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pthread] init async gpu -> cpu #49

Merged
merged 8 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions csrc/aio.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#include <stdexcept>
#include <memory>
#include "aio.h"

AIOAsyncIO::AIOAsyncIO(unsigned int n_entries)
Expand Down Expand Up @@ -126,4 +124,21 @@ void AIOAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned l
io_submit(this->io_ctx, 1, &iocbs); /* 提交这个I/O不会堵塞 */

this->n_read_events++;
}
}

void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
if (t.is_cuda()) {
if (pinned.has_value()) {
pinned.value().copy_(t);
t = pinned.value();
} else {
t = t.to(torch::kCPU);
}
}
void *buffer = t.data_ptr();
size_t n_bytes = t.numel() * t.element_size();
this->write(fd, buffer, n_bytes, offset, callback);
}

void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
void AIOAsyncIO::sync_h2d() {}
15 changes: 12 additions & 3 deletions csrc/async_file_io.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
#include "asyncio.h"
#include "async_file_io.h"
#include "backend.h"
#include <stdexcept>

AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}

Expand All @@ -11,6 +8,18 @@ void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long of
this->aio->write(this->fd, ptr, n_bytes, offset, callback);
}

void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
}

void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
this->aio->register_h2d(num_tensors);
}

void AsyncFileWriter::sync_h2d() {
this->aio->sync_h2d();
}

void AsyncFileWriter::synchronize()
{
this->aio->synchronize();
Expand Down
47 changes: 46 additions & 1 deletion csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,49 @@ void PthreadAsyncIO::synchronize() {
this->get_event(WAIT);
}

void PthreadAsyncIO::register_file(int fd) {}
void PthreadAsyncIO::register_file(int fd) {}

void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
}

void PthreadAsyncIO::sync_h2d() {
std::unique_lock<std::mutex> lock(this->mtx);
this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
}

void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
auto stream = c10::cuda::getCurrentCUDAStream();
if (!t.is_cuda()) {
this->h2d_in_progress.fetch_sub(1); // already moved to cpu
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
}
}
auto fut = this->pool.submit_task(
[this, fd, t, offset, pinned, stream] {
torch::Tensor cpu_tensor;
if (t.is_cuda()) {
at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
if (pinned.has_value()) {
pinned.value().copy_(t, /*non_blocking*/ false);
cpu_tensor = pinned.value();
} else {
cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
}
this->h2d_in_progress.fetch_sub(1);
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
}
} else {
cpu_tensor = t;
}
void *buf = cpu_tensor.data_ptr();
size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
return pwrite(fd, buf, n_bytes, offset);
}
);
this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}
5 changes: 4 additions & 1 deletion csrc/py_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
.def("synchronize", &AsyncFileWriter::synchronize);
.def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
.def("synchronize", &AsyncFileWriter::synchronize)
.def("sync_h2d", &AsyncFileWriter::sync_h2d)
.def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors"));
}
19 changes: 18 additions & 1 deletion csrc/uring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,21 @@ void UringAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned
io_uring_sqe_set_data(sqe, data);
io_uring_submit(&this->ring);
this->n_read_events++;
}
}

void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
if (t.is_cuda()) {
if (pinned.has_value()) {
pinned.value().copy_(t);
t = pinned.value();
} else {
t = t.to(torch::kCPU);
}
}
void *buffer = t.data_ptr<float>();
size_t n_bytes = t.numel() * t.element_size();
this->write(fd, buffer, n_bytes, offset, callback);
}

void UringAsyncIO::register_h2d(unsigned int num_tensors) {}
void UringAsyncIO::sync_h2d() {}
6 changes: 6 additions & 0 deletions include/aio.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once

#include <libaio.h>
#include <torch/torch.h>
#include <stdexcept>
#include <memory>
#include "asyncio.h"

class AIOAsyncIO : public AsyncIO
Expand All @@ -24,9 +27,12 @@ class AIOAsyncIO : public AsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
9 changes: 9 additions & 0 deletions include/async_file_io.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#pragma once
#include <string>
#include <torch/torch.h>
#include <optional>

#include "asyncio.h"
#include "backend.h"

#ifndef DISABLE_URING
#include "uring.h"
#endif

#ifndef DISABLE_AIO
#include "aio.h"
#endif
Expand All @@ -13,7 +19,10 @@ class AsyncFileWriter
public:
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
void synchronize();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
~AsyncFileWriter();

private:
Expand Down
4 changes: 4 additions & 0 deletions include/asyncio.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <fcntl.h>
#include <functional>
#include <torch/torch.h>

using callback_t = std::function<void()>;

Expand Down Expand Up @@ -44,7 +45,10 @@ class AsyncIO
virtual void get_event(WaitType wt) = 0;
virtual void sync_write_events() = 0;
virtual void sync_read_events() = 0;
virtual void register_h2d(unsigned int num_tensors) = 0;
virtual void sync_h2d() = 0;
virtual void synchronize() = 0;

virtual void register_file(int fd) = 0;
virtual void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) = 0;
};
15 changes: 14 additions & 1 deletion include/pthread_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
#include <queue>
#include <tuple>
#include <functional>
#include <iostream>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <condition_variable>
#include <mutex>

#include "asyncio.h"
#include "threadpool.hpp"
Expand All @@ -18,12 +24,15 @@ class PthreadAsyncIO : public AsyncIO
{
private:
BS::thread_pool pool;
std::atomic<unsigned int> h2d_in_progress;
std::condition_variable cv;
std::mutex mtx;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;

public:
PthreadAsyncIO(unsigned int n_entries)
: pool(n_entries) {}
: pool(n_entries), h2d_in_progress(0) {}

~PthreadAsyncIO() {}

Expand All @@ -35,7 +44,11 @@ class PthreadAsyncIO : public AsyncIO
void get_event(WaitType wt);
void sync_write_events();
void sync_read_events();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
void synchronize();

void register_file(int fd);

void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
3 changes: 3 additions & 0 deletions include/uring.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ class UringAsyncIO : public AsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
1 change: 1 addition & 0 deletions tensornvme/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ def probe_backend(backend: str) -> bool: ...
class AsyncFileWriter:
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
def write_tensor(self, tensor: Tensor, offset: int, callback: Optional[Callable[[], None]] = None, pinned: Optional[Tensor] = None) -> None: ...
def synchronize(self) -> None: ...
18 changes: 16 additions & 2 deletions tensornvme/async_file_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ctypes
import torch
from functools import partial

from typing import List
from torch import Tensor
from typing import List, Optional
from io import IOBase
from tensornvme._C import AsyncFileWriter as AsyncFileWriterC

Expand All @@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
self.offset = 0
# must ensure the data is not garbage collected
self.buffers = []
self.comm_stream = torch.cuda.Stream()

def write(self, data: bytes) -> int:
ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
Expand All @@ -31,6 +33,18 @@ def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> N
self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
self.offset += n_bytes

def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
with torch.cuda.stream(self.comm_stream):
self.buffers.append(tensor) # append before callback is called
self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
self.offset += tensor.numel() * tensor.element_size()

def register_h2d(self, num_tensors: int) -> None:
self.io.register_h2d(num_tensors)

def sync_before_step(self):
self.io.sync_h2d()

@staticmethod
def gc_callback(listt: List, idx: int) -> None:
listt[idx] = None
Expand Down