Skip to content

Commit

Permalink
feat: generate m2m connects and disconnects in the compiler (#5153)
Browse files Browse the repository at this point in the history
* fix: correct projected dependency generation

* feat: generate m2m connects in the compiler

* feat: generate m2m disconnects in the compiler

* fix: fix disconnect and add a test case

* test: connect/disconnect wrong way around

* test: remove connect test since it needs parameterised multi-value insert

* chore: extract var in in_selection

* chore: avoid calling left_scalars twice

* doc: clarify comment
  • Loading branch information
jacek-prisma authored Feb 7, 2025
1 parent 744af5d commit 7eac4c9
Show file tree
Hide file tree
Showing 17 changed files with 398 additions and 137 deletions.
37 changes: 36 additions & 1 deletion quaint/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ impl<'a> Comparable<'a> for Expression<'a> {
where
T: Into<Expression<'a>>,
{
Compare::In(Box::new(self), Box::new(selection.into()))
let expr = extract_single_var_row(selection.into());
Compare::In(Box::new(self), Box::new(expr))
}

fn not_in_selection<T>(self, selection: T) -> Compare<'a>
Expand Down Expand Up @@ -521,3 +522,37 @@ impl<'a> Comparable<'a> for Expression<'a> {
Compare::All(Box::new(self))
}
}

/// Converts a row consisting of a single var into the var itself.
/// Any other expression is returned as is.
fn extract_single_var_row(expr: Expression) -> Expression {
let Expression {
kind: ExpressionKind::Row(values),
..
} = &expr
else {
return expr;
};

let Some((
val @ Expression {
kind:
ExpressionKind::Parameterized(Value {
typed: ValueType::Var(_, _),
..
}),
..
},
[],
)) = values.values.split_first()
else {
return expr;
};

val.clone()
.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
)
.into()
}
13 changes: 13 additions & 0 deletions quaint/src/ast/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ where
}
}

impl<'a, A> FromIterator<A> for Row<'a>
where
A: Into<Expression<'a>>,
{
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = A>,
{
let inner = iter.into_iter().map(Into::into).collect::<Vec<_>>();
Self { values: inner }
}
}

impl<'a> Comparable<'a> for Row<'a> {
fn equals<T>(self, comparison: T) -> Compare<'a>
where
Expand Down
68 changes: 41 additions & 27 deletions query-compiler/query-compiler/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use itertools::Itertools;
use query::translate_query;
use query_builder::QueryBuilder;
use query_core::{EdgeRef, Node, NodeRef, Query, QueryGraph, QueryGraphBuilderError, QueryGraphDependency};
use query_structure::{PlaceholderType, PrismaValue, SelectionResult};
use query_structure::{PlaceholderType, PrismaValue, SelectedField, SelectionResult};
use thiserror::Error;

use super::expression::{Binding, Expression};
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
(
field.clone(),
PrismaValue::Placeholder {
name: self.graph.edge_source(edge).id(),
name: generate_projected_dependency_name(self.graph.edge_source(edge), field),
r#type: PlaceholderType::Any,
},
)
Expand Down Expand Up @@ -152,7 +152,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
// doesn't belong into results, and is executed before all result scopes.
let mut expressions: Vec<Expression> = child_pairs
.into_iter()
.map(|(edge, node)| self.process_child_with_dependency(edge, node))
.map(|(_, node)| self.process_child_with_dependencies(node))
.collect::<Result<Vec<_>, _>>()?;

// Fold result scopes into one expression.
Expand All @@ -169,9 +169,9 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
// if not, we can separate them with a getfirstnonempty
let bindings = result_subgraphs
.into_iter()
.map(|(edge, node)| {
.map(|(_, node)| {
let name = node.id();
let expr = self.process_child_with_dependency(edge, node)?;
let expr = self.process_child_with_dependencies(node)?;
Ok(Binding { name, expr })
})
.collect::<TranslateResult<Vec<_>>>()?;
Expand Down Expand Up @@ -199,39 +199,53 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
}
}

fn process_child_with_dependency(&mut self, edge: EdgeRef, node: NodeRef) -> TranslateResult<Expression> {
let edge_content = self.graph.edge_content(&edge);
let field = if let Some(QueryGraphDependency::ProjectedDataDependency(selection, _)) = edge_content {
let mut fields = selection.selections();
if let Some(first) = fields.next().filter(|_| fields.len() == 0) {
Some(first.db_name().to_string())
} else {
// we need to handle MapField with multiple fields?
todo!()
}
} else {
None
};
fn process_child_with_dependencies(&mut self, node: NodeRef) -> TranslateResult<Expression> {
let bindings = self
.graph
.incoming_edges(&node)
.into_iter()
.filter_map(|edge| {
let field = if let Some(QueryGraphDependency::ProjectedDataDependency(selection, _)) =
self.graph.edge_content(&edge)
{
let mut fields = selection.selections();
if let Some(first) = fields.next().filter(|_| fields.len() == 0) {
first
} else {
// we need to handle MapField with multiple fields?
todo!()
}
} else {
return None;
};

let source = self.graph.edge_source(&edge);
Some(Binding::new(
generate_projected_dependency_name(source, field),
Expression::MapField {
field: field.prisma_name().into_owned(),
records: Box::new(Expression::Get { name: source.id() }),
},
))
})
.collect::<Vec<_>>();

// translate plucks the edges coming into node, we need to avoid accessing it afterwards
let edges = self.graph.incoming_edges(&node);
let source = self.graph.edge_source(&edge);
let expr = NodeTranslator::new(self.graph, node, &edges, self.query_builder).translate()?;

// we insert a MapField expression if the edge was a projected data dependency
if let Some(field) = field {
if !bindings.is_empty() {
Ok(Expression::Let {
bindings: vec![Binding::new(
source.id(),
Expression::MapField {
field,
records: Box::new(Expression::Get { name: source.id() }),
},
)],
bindings,
expr: Box::new(expr),
})
} else {
Ok(expr)
}
}
}

fn generate_projected_dependency_name(source: NodeRef, field: &SelectedField) -> String {
format!("{}${}", source.id(), field.prisma_name())
}
140 changes: 80 additions & 60 deletions query-compiler/query-compiler/src/translate/query/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};
use itertools::Itertools;
use query_builder::{QueryArgumentsExt, QueryBuilder, RelationLink};
use query_core::{FilteredQuery, ReadQuery};
use query_core::{FilteredQuery, ReadQuery, RelatedRecordsQuery};
use query_structure::{
ConditionValue, FieldSelection, Filter, PrismaValue, QueryArguments, QueryMode, RelationField, ScalarCondition,
ScalarFilter, ScalarProjection,
Expand Down Expand Up @@ -61,7 +61,10 @@ pub(crate) fn translate_read_query(query: ReadQuery, builder: &dyn QueryBuilder)
}
}

ReadQuery::RelatedRecordsQuery(_) => unreachable!("related records query should not be at the top-level"),
ReadQuery::RelatedRecordsQuery(rrq) => {
let (expr, _) = build_read_related_records(rrq, None, builder)?;
expr
}

_ => todo!(),
})
Expand Down Expand Up @@ -99,7 +102,7 @@ fn add_inmemory_join(
})
.map(|rrq| -> TranslateResult<JoinExpression> {
let parent_field_name = rrq.parent_field.name().to_owned();

let left_scalars = rrq.parent_field.left_scalars();
let conditions = rrq
.parent_field
.left_scalars()
Expand All @@ -116,27 +119,15 @@ fn add_inmemory_join(
}
})
.collect();

let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last();
let needs_reversed_order = rrq.args.needs_reversed_order();

let (mut child_query, join_on) = if rrq.parent_field.relation().is_many_to_many() {
build_read_m2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
} else {
build_read_one2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
};

if needs_reversed_order {
child_query = Expression::Reverse(Box::new(child_query));
}

if !rrq.nested.is_empty() {
child_query = add_inmemory_join(child_query, rrq.nested, builder)?;
};
let (child, join_fields) = build_read_related_records(rrq, Some(conditions), builder)?;

Ok(JoinExpression {
child: child_query,
on: join_on,
child,
on: left_scalars
.into_iter()
.map(|sf| sf.name().to_owned())
.zip(join_fields)
.collect(),
parent_field: parent_field_name,
})
})
Expand All @@ -157,64 +148,82 @@ fn add_inmemory_join(
})
}

fn build_read_related_records(
rrq: RelatedRecordsQuery,
conditions: Option<Vec<ScalarCondition>>,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, JoinFields)> {
let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last();
let needs_reversed_order = rrq.args.needs_reversed_order();

let (mut child_query, join_on) = if rrq.parent_field.relation().is_many_to_many() {
build_read_m2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
} else {
build_read_one2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
};

if needs_reversed_order {
child_query = Expression::Reverse(Box::new(child_query));
}

if !rrq.nested.is_empty() {
child_query = add_inmemory_join(child_query, rrq.nested, builder)?;
};
Ok((child_query, join_on))
}

fn build_read_m2m_query(
field: RelationField,
mut conditions: Vec<ScalarCondition>,
conditions: Option<Vec<ScalarCondition>>,
args: QueryArguments,
selected_fields: &FieldSelection,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, Vec<(String, String)>)> {
let condition = conditions
.pop()
.expect("should have at least one condition in m2m relation");
assert!(
conditions.is_empty(),
"should have at most one condition in m2m relation"
);
) -> TranslateResult<(Expression, JoinFields)> {
let condition = conditions.map(|mut conditions| {
let condition = conditions
.pop()
.expect("should have at least one condition in m2m relation");
assert!(
conditions.is_empty(),
"should have at most one condition in m2m relation"
);
condition
});

let link = RelationLink::new(field, condition);
let join_expr = link
.field()
.linking_fields()
.scalars()
.map(|left| (left.name().to_owned(), link.to_string()))
.collect_vec();
let link_name = link.to_string();

let query = builder
.build_get_related_records(link, args, selected_fields)
.map_err(TranslateError::QueryBuildFailure)?;

Ok((Expression::Query(query), join_expr))
Ok((Expression::Query(query), JoinFields(vec![link_name])))
}

fn build_read_one2m_query(
field: RelationField,
conditions: Vec<ScalarCondition>,
conditions: Option<Vec<ScalarCondition>>,
mut args: QueryArguments,
selected_fields: &FieldSelection,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, Vec<(String, String)>)> {
let join_expr = field
.linking_fields()
.scalars()
.zip(field.related_field().left_scalars())
.map(|(left, right)| (left.name().to_owned(), right.name().to_owned()))
.collect_vec();
) -> TranslateResult<(Expression, JoinFields)> {
let related_scalars = field.related_field().left_scalars();
let join_fields = related_scalars.iter().map(|sf| sf.name().to_owned()).collect();

// TODO: we ignore chunking for now
let linking_scalars = field.related_field().left_scalars();

assert_eq!(
linking_scalars.len(),
conditions.len(),
"linking fields should match conditions"
);
for (condition, child_field) in conditions.into_iter().zip(linking_scalars) {
args.add_filter(Filter::Scalar(ScalarFilter {
condition,
projection: ScalarProjection::Single(child_field.clone()),
mode: QueryMode::Default,
}));
if let Some(conditions) = conditions {
assert_eq!(
related_scalars.len(),
conditions.len(),
"linking fields should match conditions"
);
for (condition, child_field) in conditions.into_iter().zip(related_scalars) {
args.add_filter(Filter::Scalar(ScalarFilter {
condition,
projection: ScalarProjection::Single(child_field.clone()),
mode: QueryMode::Default,
}));
}
}

let to_one_relation = !field.arity().is_list();
Expand All @@ -227,5 +236,16 @@ fn build_read_one2m_query(
if to_one_relation {
expr = Expression::Unique(Box::new(expr));
}
Ok((expr, join_expr))
Ok((expr, JoinFields(join_fields)))
}

struct JoinFields(Vec<String>);

impl IntoIterator for JoinFields {
type Item = String;
type IntoIter = std::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
Loading

0 comments on commit 7eac4c9

Please sign in to comment.