Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generate m2m connects and disconnects in the compiler #5153

Merged
merged 9 commits into from
Feb 7, 2025
37 changes: 36 additions & 1 deletion quaint/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ impl<'a> Comparable<'a> for Expression<'a> {
where
T: Into<Expression<'a>>,
{
Compare::In(Box::new(self), Box::new(selection.into()))
let expr = extract_single_var_row(selection.into());
Compare::In(Box::new(self), Box::new(expr))
}

fn not_in_selection<T>(self, selection: T) -> Compare<'a>
Expand Down Expand Up @@ -521,3 +522,37 @@ impl<'a> Comparable<'a> for Expression<'a> {
Compare::All(Box::new(self))
}
}

/// Converts a row consisting of a single var into the var itself.
/// Any other expression is returned as is.
fn extract_single_var_row(expr: Expression) -> Expression {
let Expression {
kind: ExpressionKind::Row(values),
..
} = &expr
else {
return expr;
};

let Some((
val @ Expression {
kind:
ExpressionKind::Parameterized(Value {
typed: ValueType::Var(_, _),
..
}),
..
},
[],
)) = values.values.split_first()
else {
return expr;
};

val.clone()
.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
)
.into()
}
13 changes: 13 additions & 0 deletions quaint/src/ast/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ where
}
}

impl<'a, A> FromIterator<A> for Row<'a>
where
A: Into<Expression<'a>>,
{
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = A>,
{
let inner = iter.into_iter().map(Into::into).collect::<Vec<_>>();
Self { values: inner }
}
}

impl<'a> Comparable<'a> for Row<'a> {
fn equals<T>(self, comparison: T) -> Compare<'a>
where
Expand Down
14 changes: 4 additions & 10 deletions query-compiler/query-compiler/src/translate/query/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,17 @@ fn build_read_one2m_query(
selected_fields: &FieldSelection,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, JoinFields)> {
let join_fields = field
.related_field()
.left_scalars()
.into_iter()
.map(|sf| sf.name().to_owned())
.collect();
let related_scalars = field.related_field().left_scalars();
let join_fields = related_scalars.iter().map(|sf| sf.name().to_owned()).collect();

// TODO: we ignore chunking for now
let linking_scalars = field.related_field().left_scalars();

if let Some(conditions) = conditions {
assert_eq!(
linking_scalars.len(),
related_scalars.len(),
conditions.len(),
"linking fields should match conditions"
);
for (condition, child_field) in conditions.into_iter().zip(linking_scalars) {
for (condition, child_field) in conditions.into_iter().zip(related_scalars) {
args.add_filter(Filter::Scalar(ScalarFilter {
condition,
projection: ScalarProjection::Single(child_field.clone()),
Expand Down
35 changes: 2 additions & 33 deletions query-engine/query-builders/sql-query-builder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ pub mod write;
use std::{collections::HashMap, marker::PhantomData};

use quaint::{
ast::{Column, Comparable, ConditionTree, Conjunctive, Decoratable, Delete, Query, Row, Values},
ast::{Column, Comparable, ConditionTree, Query, Row, Values},
visitor::Visitor,
Value, ValueType,
};
use query_builder::{DbQuery, QueryBuilder};
use query_structure::{
Expand Down Expand Up @@ -219,37 +218,7 @@ impl<'a, V: Visitor<'a>> QueryBuilder for SqlQueryBuilder<'a, V> {
parent_id: &SelectionResult,
child_ids: &[SelectionResult],
) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>> {
let relation = field.relation();

let parent_column = field.related_field().m2m_column(&self.context);
let child_column = field.m2m_column(&self.context);

let parent_id_values = parent_id.db_values(&self.context);
let parent_id_criteria = parent_column.equals(parent_id_values);

let values = child_ids
.iter()
.map(|id| id.db_values(&self.context))
.collect::<Vec<_>>();

let child_id_criteria = match values.split_first().map(|(fst, rem)| (&fst[..], rem)) {
Some((
[val @ Value {
typed: ValueType::Var(_, _),
..
}],
[],
)) => child_column.in_selection(val.clone().decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
)),
_ => child_column.in_selection(values),
};

let query = Delete::from_table(relation.as_table(&self.context))
.so_that(parent_id_criteria.and(child_id_criteria))
.add_traceparent(self.context.traceparent);

let query = write::delete_relation_table_records(&field, parent_id, child_ids, &self.context);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

self.convert_query(query)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ pub fn delete_relation_table_records(
let parent_id_values = parent_id.db_values(ctx);
let parent_id_criteria = parent_column.equals(parent_id_values);

let child_id_criteria = super::in_conditions(&[child_column], child_ids, ctx);
let child_ids = child_ids.iter().flat_map(|id| id.db_values(ctx)).collect::<Row>();
let child_id_criteria = child_column.in_selection(child_ids);

Delete::from_table(relation.as_table(ctx))
.so_that(parent_id_criteria.and(child_id_criteria))
Expand Down
Loading