Skip to content

Commit

Permalink
Add tests for series extension, finish initial vtable impl
Browse files Browse the repository at this point in the history
  • Loading branch information
PThorpe92 committed Feb 5, 2025
1 parent 82abf3f commit b3cc7b5
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 22 deletions.
6 changes: 3 additions & 3 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ impl Connection {
let sql = sql.as_ref();
trace!("Preparing: {}", sql);
let db = &self.db;
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let syms = &db.syms.borrow();
let cmd = parser.next()?;
if let Some(cmd) = cmd {
match cmd {
Expand Down Expand Up @@ -361,7 +361,7 @@ impl Connection {
pub fn execute(self: &Rc<Connection>, sql: impl AsRef<str>) -> Result<()> {
let sql = sql.as_ref();
let db = &self.db;
let syms: &SymbolTable = &db.syms.borrow_mut();
let syms: &SymbolTable = &db.syms.borrow();
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next()?;
if let Some(cmd) = cmd {
Expand Down Expand Up @@ -552,7 +552,7 @@ impl VirtualTable {
};
match rc {
ResultCode::OK => Ok(()),
_ => Err(LimboError::ExtensionError("Filter failed".to_string())),
_ => Err(LimboError::ExtensionError(rc.to_string())),
}
}

Expand Down
2 changes: 1 addition & 1 deletion extensions/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub trait VTabModule: 'static {
}

pub trait VTabCursor: Sized {
type Error;
type Error: std::fmt::Display;
fn rowid(&self) -> i64;
fn column(&self, idx: u32) -> Value;
fn eof(&self) -> bool;
Expand Down
21 changes: 6 additions & 15 deletions extensions/series/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,16 @@ struct GenerateSeriesCursor {
error: Option<ResultCode>,
}

impl GenerateSeriesCursor {
fn next(&mut self) -> ResultCode {
let current = self.current;

// Check if we've reached the end
if (self.step > 0 && current >= self.stop) || (self.step < 0 && current <= self.stop) {
return ResultCode::EOF;
}

self.current = current.saturating_add(self.step);
ResultCode::OK
}
}

impl VTabCursor for GenerateSeriesCursor {
type Error = ResultCode;

fn next(&mut self) -> ResultCode {
GenerateSeriesCursor::next(self)
let next_val = self.current.saturating_add(self.step);
if (self.step > 0 && next_val > self.stop) || (self.step < 0 && next_val < self.stop) {
return ResultCode::EOF;
}
self.current = next_val;
ResultCode::OK
}

fn eof(&self) -> bool {
Expand Down
3 changes: 3 additions & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ pub fn derive_vtab_module(input: TokenStream) -> TokenStream {
return ::limbo_ext::Value::error(ResultCode::Error);
}
let cursor = unsafe { &mut *(cursor as *mut <#struct_name as ::limbo_ext::VTabModule>::VCursor) };
if let Some(err) = <#struct_name as ::limbo_ext::VTabModule>::VCursor::error(cursor) {
return ::limbo_ext::Value::error_with_message(err.to_string());
}
<#struct_name as ::limbo_ext::VTabModule>::column(cursor, idx)
}

Expand Down
42 changes: 39 additions & 3 deletions testing/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,14 @@ def validate_string_uuid(result):
return len(result) == 36 and result.count("-") == 4


def returns_error(result):
def returns_error_no_func(result):
return "error: no such function: " in result


def returns_vtable_parse_err(result):
return "Parse error: Virtual table" in result


def returns_null(result):
return result == "" or result == "\n"

Expand Down Expand Up @@ -167,7 +171,7 @@ def test_regexp(pipe):
extension_path = "./target/debug/liblimbo_regexp.so"

# before extension loads, assert no function
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error)
run_test(pipe, "SELECT regexp('a.c', 'abc');", returns_error_no_func)
run_test(pipe, f".load {extension_path}", returns_null)
print(f"Extension {extension_path} loaded successfully.")
run_test(pipe, "SELECT regexp('a.c', 'abc');", validate_true)
Expand Down Expand Up @@ -214,7 +218,7 @@ def test_aggregates(pipe):
run_test(
pipe,
"SELECT median(1);",
returns_error,
returns_error_no_func,
"median agg function returns null when ext not loaded",
)
run_test(
Expand Down Expand Up @@ -256,12 +260,44 @@ def test_aggregates(pipe):
)


def test_series(pipe):
ext_path = "./target/debug/liblimbo_series"
run_test(
pipe,
"SELECT * FROM generate_series(1, 10);",
lambda res: "Virtual table generate_series not found" in res,
)
run_test(pipe, f".load {ext_path}", returns_null)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10);",
lambda res: "Invalid Argument" in res,
)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10, 2);",
lambda res: res == "1\n3\n5\n7\n9",
)
run_test(
pipe,
"SELECT * FROM generate_series(1, 10, 2, 3);",
lambda res: "Invalid Argument" in res,
)
run_test(
pipe,
"SELECT * FROM generate_series(10, 1, -2);",
lambda res: res == "10\n8\n6\n4\n2",
)


def main():
pipe = init_limbo()
try:
test_regexp(pipe)
test_uuid(pipe)
test_aggregates(pipe)
test_series(pipe)

except Exception as e:
print(f"Test FAILED: {e}")
pipe.terminate()
Expand Down

0 comments on commit b3cc7b5

Please sign in to comment.