diff --git a/Cargo.toml b/Cargo.toml index 91c1d9b..4ccb57e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scyllapy" -version = "1.0.7" +version = "1.1.0" edition = "2021" [lib] @@ -14,7 +14,13 @@ chrono = "0.4.26" eq-float = "0.1.0" log = "0.4.20" openssl = { version = "0.10.56", features = ["vendored"] } -pyo3 = { version = "0.19.2", features = ["auto-initialize", "anyhow", "abi3-py38", "extension-module", "chrono"] } +pyo3 = { version = "0.19.2", features = [ + "auto-initialize", + "anyhow", + "abi3-py38", + "extension-module", + "chrono", +] } pyo3-asyncio = { version = "0.19.0", features = ["tokio-runtime"] } pyo3-log = "0.8.3" scylla = { version = "0.9.0", features = ["ssl"] } diff --git a/README.md b/README.md index f8103de..9a5cbda 100644 --- a/README.md +++ b/README.md @@ -235,4 +235,71 @@ async def execute(scylla: Scylla) -> None: "INSERT INTO table(id, name) VALUES (?, ?)", [extra_types.BigInt(1), "memelord"], ) +``` + + +# Query building + +ScyllaPy gives you ability to build queries, +instead of working with raw cql. The main advantage that it's harder to make syntax error, +while creating queries. + +Base classes for Query building can be found in `scyllapy.query_builder`. + +Usage example: + +```python +from scyllapy import Scylla +from scyllapy.query_builder import Insert, Select, Update, Delete + + +async def main(scylla: Scylla): + await scylla.execute("CREATE TABLE users(id INT PRIMARY KEY, name TEXT)") + + user_id = 1 + + # We create a user with id and name. + await Insert("users").set("id", user_id).set( + "name", "user" + ).if_not_exists().execute(scylla) + + # We update it's name to be user2 + await Update("users").set("name", "user2").where("id = ?", [user_id]).execute( + scylla + ) + + # We select all users with id = user_id; + res = await Select("users").where("id = ?", [user_id]).execute(scylla) + # Verify that it's correct. + assert res.first() == {"id": 1, "name": "user2"} + + # We delete our user. + await Delete("users").where("id = ?", [user_id]).if_exists().execute(scylla) + + res = await Select("users").where("id = ?", [user_id]).execute(scylla) + + # Verify that user is deleted. + assert not res.all() + + await scylla.execute("DROP TABLE users") + +``` + +Also, you can pass built queries into InlineBatches. You cannot use queries built with query_builder module with default batches. This constraint is exists, because we +need to use values from within your queries and should ignore all parameters passed in +`batch` method of scylla. + +Here's batch usage example. + +```python +from scyllapy import Scylla, InlineBatch +from scyllapy.query_builder import Insert + + +async def execute_batch(scylla: Scylla) -> None: + batch = InlineBatch() + for i in range(10): + Insert("users").set("id", i).set("name", "test").add_to_batch(batch) + await scylla.batch(batch) + ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a75db0b..e94027c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ [tool.maturin] python-source = "python" module-name = "scyllapy._internal" +features = ["pyo3/extension-module"] [build-system] requires = ["maturin>=1.0,<2.0"] diff --git a/python/scyllapy/__init__.py b/python/scyllapy/__init__.py index d675d92..0e3dec8 100644 --- a/python/scyllapy/__init__.py +++ b/python/scyllapy/__init__.py @@ -6,8 +6,10 @@ PreparedQuery, Batch, BatchType, - extra_types, + QueryResult, + InlineBatch, ) + from importlib.metadata import version __version__ = version("scyllapy") @@ -21,5 +23,8 @@ "PreparedQuery", "Batch", "BatchType", + "QueryResult", "extra_types", + "InlineBatch", + "query_builder", ] diff --git a/python/scyllapy/_internal/__init__.pyi b/python/scyllapy/_internal/__init__.pyi index ef2e7f3..e14722d 100644 --- a/python/scyllapy/_internal/__init__.pyi +++ b/python/scyllapy/_internal/__init__.pyi @@ -86,7 +86,7 @@ class Scylla: """ async def batch( self, - batch: Batch, + batch: Batch | InlineBatch, params: Optional[Iterable[Iterable[Any] | dict[str, Any]]] = None, ) -> QueryResult: """ @@ -166,12 +166,6 @@ class BatchType: class Batch: """Class for batching queries together.""" - consistency: Consistency | None - serial_consistency: SerialConsistency | None - request_timeout: int | None - is_idempotent: bool | None - tracing: bool | None - def __init__( self, batch_type: BatchType = BatchType.UNLOGGED, @@ -184,6 +178,23 @@ class Batch: ) -> None: ... def add_query(self, query: Query | PreparedQuery | str) -> None: ... +class InlineBatch: + def __init__( + self, + batch_type: BatchType = BatchType.UNLOGGED, + consistency: Consistency | None = None, + serial_consistency: SerialConsistency | None = None, + request_timeout: int | None = None, + timestamp: int | None = None, + is_idempotent: bool | None = None, + tracing: bool | None = None, + ) -> None: ... + def add_query( + self, + query: Query | PreparedQuery | str, + values: list[Any] | None = None, + ) -> None: ... + class Consistency: """Consistency for query.""" diff --git a/python/scyllapy/_internal/query_builder.pyi b/python/scyllapy/_internal/query_builder.pyi new file mode 100644 index 0000000..948a655 --- /dev/null +++ b/python/scyllapy/_internal/query_builder.pyi @@ -0,0 +1,94 @@ +from typing import Any + +from scyllapy._internal import ( + Consistency, + SerialConsistency, + Scylla, + QueryResult, + InlineBatch, +) + +class Select: + def __init__(self, table: str) -> None: ... + def only(self, *columns: str) -> Select: ... + def where(self, clause: str, params: list[Any] | None = None) -> Select: ... + def group_by(self, group: str) -> Select: ... + def order_by(self, order: str, desc: bool = False) -> Select: ... + def per_partition_limit(self, per_partition_limit: int) -> Select: ... + def limit(self, limit: int) -> Select: ... + def allow_filtering(self) -> Select: ... + def distinct(self) -> Select: ... + def timeout(self, timeout: int | str) -> Select: ... + def request_params( + self, + consistency: Consistency | None = None, + serial_consistency: SerialConsistency | None = None, + request_timeout: int | None = None, + timestamp: int | None = None, + is_idempotent: bool | None = None, + tracing: bool | None = None, + ) -> Select: ... + def add_to_batch(self, batch: InlineBatch) -> None: ... + async def execute(self, scylla: Scylla) -> QueryResult: ... + +class Insert: + def __init__(self, table: str) -> None: ... + def if_not_exists(self) -> Insert: ... + def set(self, name: str, value: Any) -> Insert: ... + def timeout(self, timeout: int | str) -> Insert: ... + def timestamp(self, timestamp: int) -> Insert: ... + def ttl(self, ttl: int) -> Insert: ... + def request_params( + self, + consistency: Consistency | None = None, + serial_consistency: SerialConsistency | None = None, + request_timeout: int | None = None, + timestamp: int | None = None, + is_idempotent: bool | None = None, + tracing: bool | None = None, + ) -> Insert: ... + def add_to_batch(self, batch: InlineBatch) -> None: ... + async def execute(self, scylla: Scylla) -> QueryResult: ... + +class Delete: + def __init__(self, table: str) -> None: ... + def cols(self, *cols: str) -> Delete: ... + def where(self, clause: str, values: list[Any] | None = None) -> Delete: ... + def timeout(self, timeout: int | str) -> Delete: ... + def timestamp(self, timestamp: int) -> Delete: ... + def if_exists(self) -> Delete: ... + def if_(self, clause: str, values: list[Any] | None = None) -> Delete: ... + def request_params( + self, + consistency: Consistency | None = None, + serial_consistency: SerialConsistency | None = None, + request_timeout: int | None = None, + timestamp: int | None = None, + is_idempotent: bool | None = None, + tracing: bool | None = None, + ) -> Delete: ... + def add_to_batch(self, batch: InlineBatch) -> None: ... + async def execute(self, scylla: Scylla) -> QueryResult: ... + +class Update: + def __init__(self, table: str) -> None: ... + def set(self, name: str, value: Any) -> Update: ... + def inc(self, column: str, value: Any) -> Update: ... + def dec(self, column: str, value: Any) -> Update: ... + def where(self, clause: str, values: list[Any] | None = None) -> Update: ... + def timeout(self, timeout: int | str) -> Update: ... + def timestamp(self, timestamp: int) -> Update: ... + def ttl(self, ttl: int) -> Update: ... + def request_params( + self, + consistency: Consistency | None = None, + serial_consistency: SerialConsistency | None = None, + request_timeout: int | None = None, + timestamp: int | None = None, + is_idempotent: bool | None = None, + tracing: bool | None = None, + ) -> Update: ... + def if_exists(self) -> Update: ... + def if_(self, clause: str, values: list[Any] | None = None) -> Update: ... + def add_to_batch(self, batch: InlineBatch) -> None: ... + async def execute(self, scylla: Scylla) -> QueryResult: ... diff --git a/python/scyllapy/extra_types.py b/python/scyllapy/extra_types.py new file mode 100644 index 0000000..3c8af81 --- /dev/null +++ b/python/scyllapy/extra_types.py @@ -0,0 +1,3 @@ +from ._internal.extra_types import BigInt, Counter, Double, SmallInt, TinyInt, Unset + +__all__ = ["BigInt", "Counter", "Double", "SmallInt", "TinyInt", "Unset"] diff --git a/python/scyllapy/query_builder.py b/python/scyllapy/query_builder.py new file mode 100644 index 0000000..54606a9 --- /dev/null +++ b/python/scyllapy/query_builder.py @@ -0,0 +1,3 @@ +from ._internal.query_builder import Select, Delete, Insert, Update + +__all__ = ["Select", "Delete", "Insert", "Update"] diff --git a/python/scyllapy/types.py b/python/tests/query_builders/__init__.py similarity index 100% rename from python/scyllapy/types.py rename to python/tests/query_builders/__init__.py diff --git a/python/tests/query_builders/test_delete.py b/python/tests/query_builders/test_delete.py new file mode 100644 index 0000000..c532a0f --- /dev/null +++ b/python/tests/query_builders/test_delete.py @@ -0,0 +1,56 @@ +import pytest +from scyllapy import Scylla +from scyllapy.query_builder import Delete +from tests.utils import random_string + + +@pytest.mark.anyio +async def test_success(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Delete(table_name).where("id = ?", [1]).execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert not res.all() + + +@pytest.mark.anyio +async def test_if_exists(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Delete(table_name).where("id = ?", [1]).if_exists().execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert not res.all() + + +@pytest.mark.anyio +async def test_custom_if(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Delete(table_name).where("id = ?", [1]).if_("name != ?", [None]).execute( + scylla + ) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert not res.all() + + +@pytest.mark.anyio +async def test_custom_custom_if(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Delete(table_name).where("id = ?", [1]).if_("name != ?", [None]).execute( + scylla + ) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert not res.all() diff --git a/python/tests/query_builders/test_inserts.py b/python/tests/query_builders/test_inserts.py new file mode 100644 index 0000000..040d347 --- /dev/null +++ b/python/tests/query_builders/test_inserts.py @@ -0,0 +1,43 @@ +import pytest +from scyllapy import Scylla +from scyllapy.query_builder import Insert +from tests.utils import random_string + + +@pytest.mark.anyio +async def test_insert_success(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await Insert(table_name).set("id", 1).set("name", "random").execute(scylla) + result = await scylla.execute(f"SELECT * FROM {table_name}") + assert result.all() == [{"id": 1, "name": "random"}] + + +@pytest.mark.anyio +async def test_insert_if_not_exists(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await Insert(table_name).set("id", 1).set("name", "random").execute(scylla) + await Insert(table_name).set("id", 1).set( + "name", + "random2", + ).if_not_exists().execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "name": "random"}] + + +@pytest.mark.anyio +async def test_insert_request_params(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await Insert(table_name).set("id", 1).set("name", "random").execute(scylla) + res = ( + await Insert(table_name) + .set("id", 1) + .set("name", "random2") + .request_params( + tracing=True, + ) + .execute(scylla) + ) + assert res.trace_id diff --git a/python/tests/query_builders/test_select.py b/python/tests/query_builders/test_select.py new file mode 100644 index 0000000..6c14cdb --- /dev/null +++ b/python/tests/query_builders/test_select.py @@ -0,0 +1,115 @@ +import uuid +import pytest +from scyllapy import Scylla +from scyllapy.query_builder import Select, Insert +from tests.utils import random_string + + +@pytest.mark.anyio +async def test_select_success(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + res = await Select(table_name).execute(scylla) + assert res.all() == [{"id": 1, "name": "meme"}] + + +@pytest.mark.anyio +async def test_select_aliases(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + name = uuid.uuid4().hex + await scylla.execute(f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, name]) + res = await Select(table_name).only("name as testname").execute(scylla) + assert res.all() == [{"testname": name}] + + +@pytest.mark.anyio +async def test_select_simple_where(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + name = uuid.uuid4().hex + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, uuid.uuid4().hex] + ) + await scylla.execute(f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [2, name]) + + res = await Select(table_name).where("id = ?", [2]).execute(scylla) + assert res.all() == [{"id": 2, "name": name}] + + +@pytest.mark.anyio +async def test_select_multiple_filters(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute( + f"CREATE TABLE {table_name} (id INT, name TEXT, PRIMARY KEY (id, name))" + ) + name = uuid.uuid4().hex + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, uuid.uuid4().hex] + ) + await scylla.execute(f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [2, name]) + + res = ( + await Select(table_name) + .where("id = ?", [2]) + .where("name = ?", [name]) + .execute(scylla) + ) + assert res.all() == [{"id": 2, "name": name}] + + +@pytest.mark.anyio +async def test_allow_filtering(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + name = uuid.uuid4().hex + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, uuid.uuid4().hex] + ) + await scylla.execute(f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [2, name]) + + res = ( + await Select(table_name) + .where("id = ?", [2]) + .where("name = ?", [name]) + .allow_filtering() + .execute(scylla) + ) + assert res.all() == [{"id": 2, "name": name}] + + +@pytest.mark.anyio +async def test_limit(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + for i in range(10): + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [i, uuid.uuid4().hex] + ) + res = await Select(table_name).limit(3).execute(scylla) + assert len(res.all()) == 3 + + +@pytest.mark.anyio +async def test_order_by(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute( + f"CREATE TABLE {table_name} (id INT, iid INT, PRIMARY KEY(id, iid))" + ) + for i in range(10): + await scylla.execute( + f"INSERT INTO {table_name}(id, iid) VALUES (?, ?)", + [0, i], + ) + res = ( + await Select(table_name) + .only("iid") + .where("id = ?", [0]) + .order_by("iid") + .execute(scylla) + ) + ids = [row["iid"] for row in res.all()] + assert ids == list(range(10)) diff --git a/python/tests/query_builders/test_update.py b/python/tests/query_builders/test_update.py new file mode 100644 index 0000000..4e161a0 --- /dev/null +++ b/python/tests/query_builders/test_update.py @@ -0,0 +1,41 @@ +import pytest +from scyllapy import Scylla +from scyllapy.query_builder import Update +from tests.utils import random_string + + +@pytest.mark.anyio +async def test_success(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Update(table_name).set("name", "meme2").where("id = ?", [1]).execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "name": "meme2"}] + + +@pytest.mark.anyio +async def test_ifs(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, "meme"] + ) + await Update(table_name).set("name", "meme2").if_("name = ?", ["meme"]).where( + "id = ?", [1] + ).execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "name": "meme2"}] + + +@pytest.mark.anyio +async def test_if_exists(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name TEXT)") + await Update(table_name).set("name", "meme2").if_exists().where( + "id = ?", [1] + ).execute(scylla) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [] diff --git a/src/batches.rs b/src/batches.rs index 78cfc35..1187635 100644 --- a/src/batches.rs +++ b/src/batches.rs @@ -1,10 +1,10 @@ -use pyo3::{pyclass, pymethods}; -use scylla::batch::{Batch, BatchType}; +use pyo3::{pyclass, pymethods, types::PyDict, PyAny}; +use scylla::batch::{Batch, BatchStatement, BatchType}; use crate::{ - consistencies::{ScyllaPyConsistency, ScyllaPySerialConsistency}, - inputs::BatchQueryInput, + inputs::BatchQueryInput, queries::ScyllaPyRequestParams, utils::parse_python_query_params, }; +use scylla::frame::value::SerializedValues; #[pyclass(name = "BatchType")] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -18,58 +18,50 @@ pub enum ScyllaPyBatchType { #[derive(Clone)] pub struct ScyllaPyBatch { inner: Batch, - #[pyo3(get)] - pub consistency: Option, - #[pyo3(get)] - pub serial_consistency: Option, - #[pyo3(get)] - pub request_timeout: Option, - #[pyo3(get)] - pub timestamp: Option, - #[pyo3(get)] - pub is_idempotent: Option, - #[pyo3(get)] - pub tracing: Option, + request_params: ScyllaPyRequestParams, +} + +#[pyclass(name = "InlineBatch")] +#[derive(Clone)] +pub struct ScyllaPyInlineBatch { + inner: Batch, + request_params: ScyllaPyRequestParams, + values: Vec, } impl From for Batch { fn from(value: ScyllaPyBatch) -> Self { - value.inner + let mut inner = value.inner; + value.request_params.apply_to_batch(&mut inner); + inner + } +} + +impl From for (Batch, Vec) { + fn from(mut value: ScyllaPyInlineBatch) -> Self { + value.request_params.apply_to_batch(&mut value.inner); + (value.inner, value.values) } } #[pymethods] impl ScyllaPyBatch { + /// Create new batch. + /// + /// # Errors + /// + /// Can return an error in case if + /// wrong type for parameters were passed. #[new] #[pyo3(signature = ( batch_type = ScyllaPyBatchType::UNLOGGED, - consistency = None, - serial_consistency = None, - request_timeout = None, - timestamp = None, - is_idempotent = None, - tracing = None, + **params ))] - #[allow(clippy::too_many_arguments)] - #[must_use] - pub fn py_new( - batch_type: ScyllaPyBatchType, - consistency: Option, - serial_consistency: Option, - request_timeout: Option, - timestamp: Option, - is_idempotent: Option, - tracing: Option, - ) -> Self { - Self { + pub fn py_new(batch_type: ScyllaPyBatchType, params: Option<&PyDict>) -> anyhow::Result { + Ok(Self { inner: Batch::new(batch_type.into()), - consistency, - serial_consistency, - request_timeout, - timestamp, - is_idempotent, - tracing, - } + request_params: ScyllaPyRequestParams::from_dict(params)?, + }) } pub fn add_query(&mut self, query: BatchQueryInput) { @@ -77,6 +69,65 @@ impl ScyllaPyBatch { } } +impl ScyllaPyInlineBatch { + pub fn add_query_inner( + &mut self, + query: impl Into, + values: impl Into, + ) { + self.inner.append_statement(query); + self.values.push(values.into()); + } +} + +#[pymethods] +impl ScyllaPyInlineBatch { + /// Create new batch. + /// + /// # Errors + /// + /// Can return an error in case if + /// wrong type for parameters were passed. + #[new] + #[pyo3(signature = ( + batch_type = ScyllaPyBatchType::UNLOGGED, + **params + ))] + pub fn py_new(batch_type: ScyllaPyBatchType, params: Option<&PyDict>) -> anyhow::Result { + Ok(Self { + inner: Batch::new(batch_type.into()), + request_params: ScyllaPyRequestParams::from_dict(params)?, + values: vec![], + }) + } + + /// Add query to batch. + /// + /// This function appends query to batch. + /// along with values, so you don't need to + /// pass values in execute. + /// + /// # Errors + /// + /// Will result in an error, if + /// values are incorrect. + #[pyo3(signature = (query, values = None))] + pub fn add_query( + &mut self, + query: BatchQueryInput, + values: Option<&PyAny>, + ) -> anyhow::Result<()> { + self.inner.append_statement(query); + if let Some(passed_params) = values { + self.values + .push(parse_python_query_params(Some(passed_params), false)?); + } else { + self.values.push(SerializedValues::new()); + } + Ok(()) + } +} + impl From for BatchType { fn from(value: ScyllaPyBatchType) -> Self { match value { diff --git a/src/extra_types.rs b/src/extra_types.rs index 03a7cee..4bc0990 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -56,13 +56,12 @@ impl ScyllaPyUnset { /// /// May return error if module cannot be created, /// or any of classes cannot be added. -pub fn add_module<'a>(py: Python<'a>, name: &'static str) -> PyResult<&'a PyModule> { - let module = PyModule::new(py, name)?; +pub fn module_constructor(_py: Python<'_>, module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_class::()?; module.add_class::()?; module.add_class::()?; module.add_class::()?; - Ok(module) + Ok(()) } diff --git a/src/inputs.rs b/src/inputs.rs index 0906959..c6dd271 100644 --- a/src/inputs.rs +++ b/src/inputs.rs @@ -1,6 +1,10 @@ use pyo3::FromPyObject; -use crate::{prepared_queries::ScyllaPyPreparedQuery, queries::ScyllaPyQuery}; +use crate::{ + batches::{ScyllaPyBatch, ScyllaPyInlineBatch}, + prepared_queries::ScyllaPyPreparedQuery, + queries::ScyllaPyQuery, +}; use scylla::{batch::BatchStatement, query::Query}; #[derive(Clone, FromPyObject)] @@ -49,3 +53,11 @@ impl From for Query { } } } + +#[derive(Clone, FromPyObject)] +pub enum BatchInput { + #[pyo3(transparent, annotation = "Batch")] + Batch(ScyllaPyBatch), + #[pyo3(transparent, annotation = "InlineBatch")] + InlineBatch(ScyllaPyInlineBatch), +} diff --git a/src/lib.rs b/src/lib.rs index 2d43aa2..7e733fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,15 @@ pub mod extra_types; pub mod inputs; pub mod prepared_queries; pub mod queries; +pub mod query_builder; pub mod query_results; pub mod scylla_cls; pub mod utils; use pyo3::{pymodule, types::PyModule, PyResult, Python}; +use crate::utils::add_submodule; + #[pymodule] #[pyo3(name = "_internal")] fn _internal(py: Python<'_>, pymod: &PyModule) -> PyResult<()> { @@ -21,7 +24,9 @@ fn _internal(py: Python<'_>, pymod: &PyModule) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; - pymod.add_submodule(extra_types::add_module(py, "extra_types")?)?; + add_submodule(py, pymod, "extra_types", extra_types::module_constructor)?; + add_submodule(py, pymod, "query_builder", query_builder::add_module)?; Ok(()) } diff --git a/src/queries.rs b/src/queries.rs index b9a2902..96e1837 100644 --- a/src/queries.rs +++ b/src/queries.rs @@ -1,38 +1,105 @@ use std::time::Duration; use crate::consistencies::{ScyllaPyConsistency, ScyllaPySerialConsistency}; -use pyo3::{pyclass, pymethods, Python}; -use scylla::statement::query::Query; +use pyo3::{pyclass, pymethods, types::PyDict, FromPyObject, Python}; +use scylla::{batch::Batch, statement::query::Query}; -#[pyclass(name = "Query")] -#[derive(Clone, Debug)] -pub struct ScyllaPyQuery { - #[pyo3(get)] - pub query: String, - #[pyo3(get)] +#[derive(Clone, Debug, Default, FromPyObject)] +pub struct ScyllaPyRequestParams { pub consistency: Option, - #[pyo3(get)] pub serial_consistency: Option, - #[pyo3(get)] pub request_timeout: Option, - #[pyo3(get)] pub timestamp: Option, - #[pyo3(get)] pub is_idempotent: Option, - #[pyo3(get)] pub tracing: Option, } +impl ScyllaPyRequestParams { + /// Apply parameters to scylla's query. + pub fn apply_to_query(&self, query: &mut Query) { + if let Some(consistency) = self.consistency { + query.set_consistency(consistency.into()); + } + if let Some(is_idempotent) = self.is_idempotent { + query.set_is_idempotent(is_idempotent); + } + if let Some(tracing) = self.tracing { + query.set_tracing(tracing); + } + query.set_timestamp(self.timestamp); + query.set_request_timeout(self.request_timeout.map(Duration::from_secs)); + query.set_serial_consistency(self.serial_consistency.map(Into::into)); + } + + pub fn apply_to_batch(&self, batch: &mut Batch) { + if let Some(consistency) = self.consistency { + batch.set_consistency(consistency.into()); + } + if let Some(is_idempotent) = self.is_idempotent { + batch.set_is_idempotent(is_idempotent); + } + if let Some(tracing) = self.tracing { + batch.set_tracing(tracing); + } + batch.set_timestamp(self.timestamp); + batch.set_serial_consistency(self.serial_consistency.map(Into::into)); + } + + /// Parse dict to query parameters. + /// + /// This function takes dict and + /// tries to construct `ScyllaPyRequestParams`. + /// + /// # Errors + /// + /// May result in an error if + /// incorrect type passed. + pub fn from_dict(params: Option<&PyDict>) -> anyhow::Result { + let Some(params) = params else { + return Ok(Self::default()); + }; + Ok(Self { + consistency: params + .get_item("consistency") + .map(pyo3::FromPyObject::extract) + .transpose()?, + serial_consistency: params + .get_item("serial_consistency") + .map(pyo3::FromPyObject::extract) + .transpose()?, + request_timeout: params + .get_item("request_timeout") + .map(pyo3::FromPyObject::extract) + .transpose()?, + timestamp: params + .get_item("timestamp") + .map(pyo3::FromPyObject::extract) + .transpose()?, + is_idempotent: params + .get_item("is_idempotent") + .map(pyo3::FromPyObject::extract) + .transpose()?, + tracing: params + .get_item("tracing") + .map(pyo3::FromPyObject::extract) + .transpose()?, + }) + } +} + +#[pyclass(name = "Query")] +#[derive(Clone, Debug)] +pub struct ScyllaPyQuery { + #[pyo3(get)] + pub query: String, + pub params: ScyllaPyRequestParams, +} + impl From<&ScyllaPyQuery> for ScyllaPyQuery { fn from(value: &ScyllaPyQuery) -> Self { ScyllaPyQuery { query: value.query.clone(), - consistency: value.consistency, - serial_consistency: value.serial_consistency, - request_timeout: value.request_timeout, - timestamp: value.timestamp, - is_idempotent: value.is_idempotent, - tracing: value.tracing, + params: ScyllaPyRequestParams::default(), } } } @@ -63,12 +130,14 @@ impl ScyllaPyQuery { ) -> Self { Self { query, - consistency, - serial_consistency, - request_timeout, - timestamp, - is_idempotent, - tracing, + params: ScyllaPyRequestParams { + consistency, + serial_consistency, + request_timeout, + timestamp, + is_idempotent, + tracing, + }, } } @@ -80,7 +149,7 @@ impl ScyllaPyQuery { #[must_use] pub fn with_consistency(&self, consistency: Option) -> Self { let mut query = Self::from(self); - query.consistency = consistency; + query.params.consistency = consistency; query } @@ -90,35 +159,35 @@ impl ScyllaPyQuery { serial_consistency: Option, ) -> Self { let mut query = Self::from(self); - query.serial_consistency = serial_consistency; + query.params.serial_consistency = serial_consistency; query } #[must_use] pub fn with_request_timeout(&self, request_timeout: Option) -> Self { let mut query = Self::from(self); - query.request_timeout = request_timeout; + query.params.request_timeout = request_timeout; query } #[must_use] pub fn with_timestamp(&self, timestamp: Option) -> Self { let mut query = Self::from(self); - query.timestamp = timestamp; + query.params.timestamp = timestamp; query } #[must_use] pub fn with_is_idempotent(&self, is_idempotent: Option) -> Self { let mut query = Self::from(self); - query.is_idempotent = is_idempotent; + query.params.is_idempotent = is_idempotent; query } #[must_use] pub fn with_tracing(&self, tracing: Option) -> Self { let mut query = Self::from(self); - query.tracing = tracing; + query.params.tracing = tracing; query } } @@ -126,18 +195,7 @@ impl ScyllaPyQuery { impl From for Query { fn from(value: ScyllaPyQuery) -> Self { let mut query = Self::new(value.query); - if let Some(consistency) = value.consistency { - query.set_consistency(consistency.into()); - } - if let Some(is_idempotent) = value.is_idempotent { - query.set_is_idempotent(is_idempotent); - } - if let Some(tracing) = value.tracing { - query.set_tracing(tracing); - } - query.set_timestamp(value.timestamp); - query.set_request_timeout(value.request_timeout.map(Duration::from_secs)); - query.set_serial_consistency(value.serial_consistency.map(Into::into)); + value.params.apply_to_query(&mut query); query } } diff --git a/src/query_builder/delete.rs b/src/query_builder/delete.rs new file mode 100644 index 0000000..f420e3e --- /dev/null +++ b/src/query_builder/delete.rs @@ -0,0 +1,255 @@ +use pyo3::{pyclass, pymethods, types::PyDict, PyAny, PyRefMut, Python}; +use scylla::query::Query; + +use super::utils::{pretty_build, IfCluase, Timeout}; +use crate::{ + batches::ScyllaPyInlineBatch, + queries::ScyllaPyRequestParams, + scylla_cls::Scylla, + utils::{py_to_value, ScyllaPyCQLDTO}, +}; +use scylla::frame::value::SerializedValues; + +#[pyclass] +#[derive(Clone, Debug, Default)] +pub struct Delete { + table_: String, + columns: Option>, + timeout_: Option, + timestamp_: Option, + if_clause_: Option, + where_clauses_: Vec, + values_: Vec, + request_params_: ScyllaPyRequestParams, +} + +impl Delete { + fn build_query(&self) -> anyhow::Result { + if self.where_clauses_.is_empty() { + return Err(anyhow::anyhow!( + "At least one where clause should be specified." + )); + } + let columns = self + .columns + .as_ref() + .map_or(String::new(), |cols| cols.join(", ")); + let params = vec![ + self.timestamp_ + .map(|timestamp| format!("TIMESTAMP {timestamp}")), + self.timeout_.as_ref().map(|timeout| match timeout { + Timeout::Int(int) => format!("TIMEOUT {int}"), + Timeout::Str(string) => format!("TIMEOUT {string}"), + }), + ]; + let prepared_params = params + .iter() + .map(|item| item.as_ref().map_or("", String::as_str)) + .filter(|item| !item.is_empty()) + .collect::>(); + let usings = if prepared_params.is_empty() { + String::new() + } else { + format!("USING {}", prepared_params.join(" AND ")) + }; + let where_clause = format!("WHERE {}", self.where_clauses_.join(" AND ")); + let if_conditions = self + .if_clause_ + .as_ref() + .map_or(String::default(), |cond| match cond { + IfCluase::Exists => String::from("IF EXISTS"), + IfCluase::Condition { clauses, values: _ } => { + format!("IF {}", clauses.join(" AND ")) + } + }); + Ok(pretty_build([ + "DELETE", + columns.as_str(), + "FROM", + self.table_.as_str(), + usings.as_str(), + where_clause.as_str(), + if_conditions.as_str(), + ])) + } +} + +#[pymethods] +impl Delete { + #[new] + #[must_use] + pub fn py_new(table: String) -> Self { + Self { + table_: table, + ..Default::default() + } + } + + #[must_use] + #[pyo3(signature = (*cols))] + pub fn cols(mut slf: PyRefMut<'_, Self>, cols: Vec) -> PyRefMut<'_, Self> { + slf.columns = Some(cols); + slf + } + + /// Add where clause. + /// + /// This function adds where with values. + /// + /// # Errors + /// + /// Can return an error, if values + /// cannot be parsed. + #[pyo3(signature = (clause, values = None))] + pub fn r#where<'a>( + mut slf: PyRefMut<'a, Self>, + clause: String, + values: Option>, + ) -> anyhow::Result> { + slf.where_clauses_.push(clause); + if let Some(vals) = values { + for value in vals { + slf.values_.push(py_to_value(value)?); + } + } + Ok(slf) + } + + #[must_use] + pub fn timeout(mut slf: PyRefMut<'_, Self>, timeout: Timeout) -> PyRefMut<'_, Self> { + slf.timeout_ = Some(timeout); + slf + } + + #[must_use] + pub fn timestamp(mut slf: PyRefMut<'_, Self>, timestamp: u64) -> PyRefMut<'_, Self> { + slf.timestamp_ = Some(timestamp); + slf + } + + #[must_use] + pub fn if_exists(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { + slf.if_clause_ = Some(IfCluase::Exists); + slf + } + + /// Add if clause. + /// + /// # Errors + /// + /// May return an error, if values + /// cannot be converted to rust types. + #[pyo3(signature = (clause, values = None))] + pub fn if_<'a>( + mut slf: PyRefMut<'a, Self>, + clause: String, + values: Option>, + ) -> anyhow::Result> { + let parsed_values = if let Some(vals) = values { + vals.iter() + .map(|item| py_to_value(item)) + .collect::, _>>()? + } else { + vec![] + }; + match slf.if_clause_.as_mut() { + Some(IfCluase::Condition { clauses, values }) => { + clauses.push(clause); + values.extend(parsed_values); + } + None | Some(IfCluase::Exists) => { + slf.if_clause_ = Some(IfCluase::Condition { + clauses: vec![clause], + values: parsed_values, + }); + } + } + Ok(slf) + } + + /// Add parameters to the request. + /// + /// These parameters are used by scylla. + /// + /// # Errors + /// + /// May return an error, if request parameters + /// cannot be built. + #[pyo3(signature = (**params))] + pub fn request_params<'a>( + mut slf: PyRefMut<'a, Self>, + params: Option<&'a PyDict>, + ) -> anyhow::Result> { + slf.request_params_ = ScyllaPyRequestParams::from_dict(params)?; + Ok(slf) + } + + /// Execute a query. + /// + /// # Errors + /// + /// May return an error, if something goes wrong + /// during query building + /// or during query execution. + pub fn execute<'a>(&'a self, py: Python<'a>, scylla: &'a Scylla) -> anyhow::Result<&'a PyAny> { + let mut query = Query::new(self.build_query()?); + self.request_params_.apply_to_query(&mut query); + + let values = if let Some(if_clause) = &self.if_clause_ { + if_clause.extend_values(self.values_.clone()) + } else { + self.values_.clone() + }; + scylla.native_execute(py, query, values) + } + + /// Add to batch + /// + /// Adds current query to batch. + /// + /// # Error + /// + /// May result into error if query cannot be build. + /// Or values cannot be passed to batch. + pub fn add_to_batch(&self, batch: &mut ScyllaPyInlineBatch) -> anyhow::Result<()> { + let mut query = Query::new(self.build_query()?); + self.request_params_.apply_to_query(&mut query); + + let values = if let Some(if_clause) = &self.if_clause_ { + if_clause.extend_values(self.values_.clone()) + } else { + self.values_.clone() + }; + let mut serialized = SerializedValues::new(); + for val in values { + serialized.add_value(&val)?; + } + batch.add_query_inner(query, serialized); + Ok(()) + } + + #[must_use] + pub fn __repr__(&self) -> String { + format!("{self:?}") + } + + /// Convert query to string. + /// + /// # Errors + /// + /// May return an error if something + /// goes wrong during query building. + pub fn __str__(&self) -> anyhow::Result { + self.build_query() + } + + #[must_use] + pub fn __copy__(&self) -> Self { + self.clone() + } + + #[must_use] + pub fn __deepcopy__(&self, _memo: &PyDict) -> Self { + self.clone() + } +} diff --git a/src/query_builder/insert.rs b/src/query_builder/insert.rs new file mode 100644 index 0000000..d382d99 --- /dev/null +++ b/src/query_builder/insert.rs @@ -0,0 +1,210 @@ +use pyo3::{pyclass, pymethods, types::PyDict, PyAny, PyRefMut, Python}; +use scylla::query::Query; + +use crate::{ + batches::ScyllaPyInlineBatch, + queries::ScyllaPyRequestParams, + scylla_cls::Scylla, + utils::{py_to_value, ScyllaPyCQLDTO}, +}; +use scylla::frame::value::SerializedValues; + +use super::utils::{pretty_build, Timeout}; + +#[pyclass] +#[derive(Clone, Debug, Default)] +pub struct Insert { + table_: String, + if_not_exists_: bool, + names_: Vec, + values_: Vec, + + timeout_: Option, + ttl_: Option, + timestamp_: Option, + + request_params_: ScyllaPyRequestParams, +} + +impl Insert { + /// Build a statement. + /// + /// # Errors + /// If no values was set. + pub fn build_query(&self) -> anyhow::Result { + if self.names_.is_empty() { + return Err(anyhow::anyhow!("Please use at least one set method.")); + } + let names = self.names_.join(","); + let values = self + .names_ + .iter() + .map(|_| "?") + .collect::>() + .join(","); + let names_values = format!("({names}) VALUES ({values})"); + let ifnexist = if self.if_not_exists_ { + "IF NOT EXISTS" + } else { + "" + }; + let params = vec![ + self.timestamp_ + .map(|timestamp| format!("TIMESTAMP {timestamp}")), + self.ttl_.map(|ttl| format!("TTL {ttl}")), + self.timeout_.as_ref().map(|timeout| match timeout { + Timeout::Int(int) => format!("TIMEOUT {int}"), + Timeout::Str(string) => format!("TIMEOUT {string}"), + }), + ]; + let prepared_params = params + .iter() + .map(|item| item.as_ref().map_or("", String::as_str)) + .filter(|item| !item.is_empty()) + .collect::>(); + let usings = if prepared_params.is_empty() { + String::new() + } else { + format!("USING {}", prepared_params.join(" AND ")) + }; + + Ok(pretty_build([ + "INSERT INTO", + self.table_.as_str(), + names_values.as_str(), + ifnexist, + usings.as_str(), + ])) + } +} + +#[pymethods] +impl Insert { + #[new] + #[must_use] + pub fn py_new(table: String) -> Self { + Self { + table_: table, + ..Default::default() + } + } + + #[must_use] + pub fn if_not_exists(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { + slf.if_not_exists_ = true; + slf + } + + /// Set value to column. + /// + /// # Errors + /// + /// If value cannot be translated + /// into `Rust` type. + pub fn set<'a>( + mut slf: PyRefMut<'a, Self>, + name: String, + value: &'a PyAny, + ) -> anyhow::Result> { + slf.names_.push(name); + // Small optimization to speedup inserts. + if value.is_none() { + slf.values_.push(ScyllaPyCQLDTO::Unset); + } else { + slf.values_.push(py_to_value(value)?); + } + Ok(slf) + } + + #[must_use] + pub fn timeout(mut slf: PyRefMut<'_, Self>, timeout: Timeout) -> PyRefMut<'_, Self> { + slf.timeout_ = Some(timeout); + slf + } + + #[must_use] + pub fn timestamp(mut slf: PyRefMut<'_, Self>, timestamp: u64) -> PyRefMut<'_, Self> { + slf.timestamp_ = Some(timestamp); + slf + } + + #[must_use] + pub fn ttl(mut slf: PyRefMut<'_, Self>, ttl: i32) -> PyRefMut<'_, Self> { + slf.ttl_ = Some(ttl); + slf + } + + /// Add parameters to the request. + /// + /// These parameters are used by scylla. + /// + /// # Errors + /// + /// May return an error, if request parameters + /// cannot be built. + #[pyo3(signature = (**params))] + pub fn request_params<'a>( + mut slf: PyRefMut<'a, Self>, + params: Option<&'a PyDict>, + ) -> anyhow::Result> { + slf.request_params_ = ScyllaPyRequestParams::from_dict(params)?; + Ok(slf) + } + + /// Execute a query. + /// + /// This function is used to execute built query. + /// + /// # Errors + /// + /// If query cannot be built. + /// Also proxies errors from `native_execute`. + pub fn execute<'a>(&'a self, py: Python<'a>, scylla: &'a Scylla) -> anyhow::Result<&'a PyAny> { + let mut query = Query::new(self.build_query()?); + self.request_params_.apply_to_query(&mut query); + scylla.native_execute(py, query, self.values_.clone()) + } + + /// Add to batch + /// + /// Adds current query to batch. + /// + /// # Error + /// + /// May result into error if query cannot be build. + /// Or values cannot be passed to batch. + pub fn add_to_batch(&self, batch: &mut ScyllaPyInlineBatch) -> anyhow::Result<()> { + let mut query = Query::new(self.build_query()?); + self.request_params_.apply_to_query(&mut query); + + let mut serialized = SerializedValues::new(); + for val in self.values_.clone() { + serialized.add_value(&val)?; + } + batch.add_query_inner(query, serialized); + Ok(()) + } + + #[must_use] + pub fn __repr__(&self) -> String { + format!("{self:?}") + } + + /// Returns string part of a query. + /// + /// # Errors + /// If cannot construct query. + pub fn __str__(&self) -> anyhow::Result { + self.build_query() + } + + #[must_use] + pub fn __copy__(&self) -> Self { + self.clone() + } + + #[must_use] + pub fn __deepcopy__(&self, _memo: &PyDict) -> Self { + self.clone() + } +} diff --git a/src/query_builder/mod.rs b/src/query_builder/mod.rs new file mode 100644 index 0000000..0a6493b --- /dev/null +++ b/src/query_builder/mod.rs @@ -0,0 +1,26 @@ +use pyo3::{types::PyModule, PyResult, Python}; + +use self::{delete::Delete, insert::Insert, select::Select, update::Update}; + +pub mod delete; +pub mod insert; +pub mod select; +pub mod update; +mod utils; + +/// Create `QueryBuilder` module. +/// +/// This function creates a module with a +/// given name and adds classes to it. +/// +/// # Errors +/// +/// * Cannot create module by any reason. +/// * Cannot add class by some reason. +pub fn add_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> { + module.add_class::