Skip to content

Commit

Permalink
parse + compile constraint.to and constraint.to_columns on foreig…
Browse files Browse the repository at this point in the history
…n key constraints (dbt-labs#10414)
  • Loading branch information
MichelleArk authored Jul 25, 2024
1 parent e1621eb commit cab6dab
Show file tree
Hide file tree
Showing 12 changed files with 655 additions and 35 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240719-161841.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support ref and source in foreign key constraint expressions
time: 2024-07-19T16:18:41.434278-04:00
custom:
Author: michelleark
Issue: "8062"
42 changes: 40 additions & 2 deletions core/dbt/clients/jinja_static.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union

import jinja2

from dbt.exceptions import MacroNamespaceNotStringError
from dbt.artifacts.resources import RefArgs
from dbt.exceptions import MacroNamespaceNotStringError, ParsingError
from dbt_common.clients.jinja import get_environment
from dbt_common.exceptions.macros import MacroNameNotStringError
from dbt_common.tests import test_caching_enabled
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore

_TESTING_MACRO_CACHE: Optional[Dict[str, Any]] = {}

Expand Down Expand Up @@ -153,3 +155,39 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper):
possible_macro_calls.append(f"{package_name}.{func_name}")

return possible_macro_calls


def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]:
"""
Returns a RefArgs or List[str] object, corresponding to ref or source respectively, given an input jinja expression.
input: str representing how input node is referenced in tested model sql
* examples:
- "ref('my_model_a')"
- "ref('my_model_a', version=3)"
- "ref('package', 'my_model_a', version=3)"
- "source('my_source_schema', 'my_source_name')"
If input is not a well-formed jinja ref or source expression, a ParsingError is raised.
"""
ref_or_source: Union[RefArgs, List[str]]

try:
statically_parsed = py_extract_from_source(f"{{{{ {expression} }}}}")
except ExtractionError:
raise ParsingError(f"Invalid jinja expression: {expression}")

if statically_parsed.get("refs"):
raw_ref = list(statically_parsed["refs"])[0]
ref_or_source = RefArgs(
package=raw_ref.get("package"),
name=raw_ref.get("name"),
version=raw_ref.get("version"),
)
elif statically_parsed.get("sources"):
source_name, source_table_name = list(statically_parsed["sources"])[0]
ref_or_source = [source_name, source_table_name]
else:
raise ParsingError(f"Invalid ref or source expression: {expression}")

return ref_or_source
26 changes: 26 additions & 0 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
from dbt.exceptions import (
DbtInternalError,
DbtRuntimeError,
ForeignKeyConstraintToSyntaxError,
GraphDependencyNotFoundError,
ParsingError,
)
from dbt.flags import get_flags
from dbt.graph import Graph
from dbt.node_types import ModelLanguage, NodeType
from dbt_common.clients.system import make_directory
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.format import pluralize
from dbt_common.events.functions import fire_event
Expand Down Expand Up @@ -437,8 +440,31 @@ def _compile_code(
relation_name = str(relation_cls.create_from(self.config, node))
node.relation_name = relation_name

# Compile 'ref' and 'source' expressions in foreign key constraints
if node.resource_type == NodeType.Model:
for constraint in node.all_constraints:
if constraint.type == ConstraintType.foreign_key and constraint.to:
constraint.to = self._compile_relation_for_foreign_key_constraint_to(
manifest, node, constraint.to
)

return node

def _compile_relation_for_foreign_key_constraint_to(
self, manifest: Manifest, node: ManifestSQLNode, to_expression: str
) -> str:
try:
foreign_key_node = manifest.find_node_from_ref_or_source(to_expression)
except ParsingError:
raise ForeignKeyConstraintToSyntaxError(node, to_expression)

if not foreign_key_node:
raise GraphDependencyNotFoundError(node, to_expression)

adapter = get_adapter(self.config)
relation_name = str(adapter.Relation.create_from(self.config, foreign_key_node))
return relation_name

# This method doesn't actually "compile" any of the nodes. That is done by the
# "compile_node" method. This creates a Linker and builds the networkx graph,
# writes out the graph.gpickle file, and prints the stats, returning a Graph object.
Expand Down
19 changes: 18 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@
from dbt.adapters.factory import get_adapter_package_names

# to preserve import paths
from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion
from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion, RefArgs
from dbt.artifacts.resources.v1.config import NodeConfig
from dbt.artifacts.schemas.manifest import ManifestMetadata, UniqueID, WritableManifest
from dbt.clients.jinja_static import statically_parse_ref_or_source
from dbt.contracts.files import (
AnySourceFile,
FileHash,
Expand Down Expand Up @@ -1635,6 +1636,22 @@ def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery

# end of methods formerly in ParseResult

def find_node_from_ref_or_source(
self, expression: str
) -> Optional[Union[ModelNode, SourceDefinition]]:
ref_or_source = statically_parse_ref_or_source(expression)

node = None
if isinstance(ref_or_source, RefArgs):
node = self.ref_lookup.find(
ref_or_source.name, ref_or_source.package, ref_or_source.version, self
)
else:
source_name, source_table_name = ref_or_source[0], ref_or_source[1]
node = self.source_lookup.find(f"{source_name}.{source_table_name}", None, self)

return node

# Provide support for copy.deepcopy() - we just need to avoid the lock!
# pickle and deepcopy use this. It returns a callable object used to
# create the initial version of the object and a tuple of arguments
Expand Down
18 changes: 17 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@
NodeType,
)
from dbt_common.clients.system import write_file
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.contracts.constraints import (
ColumnLevelConstraint,
ConstraintType,
ModelLevelConstraint,
)
from dbt_common.events.contextvars import set_log_contextvars
from dbt_common.events.functions import warn_or_error

Expand Down Expand Up @@ -489,6 +493,18 @@ def search_name(self):
def materialization_enforces_constraints(self) -> bool:
return self.config.materialized in ["table", "incremental"]

@property
def all_constraints(self) -> List[Union[ModelLevelConstraint, ColumnLevelConstraint]]:
constraints: List[Union[ModelLevelConstraint, ColumnLevelConstraint]] = []
for model_level_constraint in self.constraints:
constraints.append(model_level_constraint)

for column in self.columns.values():
for column_level_constraint in column.constraints:
constraints.append(column_level_constraint)

return constraints

def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]:
"""
Infers the columns that can be used as primary key of a model in the following order:
Expand Down
12 changes: 12 additions & 0 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def get_message(self) -> str:
return msg


class ForeignKeyConstraintToSyntaxError(CompilationError):
def __init__(self, node, expression: str) -> None:
self.expression = expression
self.node = node
super().__init__(msg=self.get_message())

def get_message(self) -> str:
msg = f"'{self.node.unique_id}' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax: {self.expression}."

return msg


# client level exceptions


Expand Down
24 changes: 23 additions & 1 deletion core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Type, TypeVar

from dbt import deprecations
from dbt.artifacts.resources import RefArgs
from dbt.clients.jinja_static import statically_parse_ref_or_source
from dbt.clients.yaml_helper import load_yaml_text
from dbt.config import RuntimeConfig
from dbt.context.configured import SchemaYamlVars, generate_schema_yml_context
Expand Down Expand Up @@ -915,7 +917,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None:
self.patch_constraints(node, patch.constraints)
node.build_contract_checksum()

def patch_constraints(self, node, constraints) -> None:
def patch_constraints(self, node, constraints: List[Dict[str, Any]]) -> None:
contract_config = node.config.get("contract")
if contract_config.enforced is True:
self._validate_constraint_prerequisites(node)
Expand All @@ -930,6 +932,26 @@ def patch_constraints(self, node, constraints) -> None:

self._validate_pk_constraints(node, constraints)
node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints]
self._process_constraints_refs_and_sources(node)

def _process_constraints_refs_and_sources(self, model_node: ModelNode) -> None:
"""
Populate model_node.refs and model_node.sources based on foreign-key constraint references,
whether defined at the model-level or column-level.
"""
for constraint in model_node.all_constraints:
if constraint.type == ConstraintType.foreign_key and constraint.to:
try:
ref_or_source = statically_parse_ref_or_source(constraint.to)
except ParsingError:
raise ParsingError(
f"Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model {model_node.name}: {constraint.to}."
)

if isinstance(ref_or_source, RefArgs):
model_node.refs.append(ref_or_source)
else:
model_node.sources.append(ref_or_source)

def _validate_pk_constraints(
self, model_node: ModelNode, constraints: List[Dict[str, Any]]
Expand Down
115 changes: 115 additions & 0 deletions tests/functional/constraints/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
model_foreign_key_model_schema_yml = """
models:
- name: my_model
constraints:
- type: foreign_key
columns: [id]
to: ref('my_model_to')
to_columns: [id]
columns:
- name: id
data_type: integer
"""


model_foreign_key_source_schema_yml = """
sources:
- name: test_source
tables:
- name: test_table
models:
- name: my_model
constraints:
- type: foreign_key
columns: [id]
to: source('test_source', 'test_table')
to_columns: [id]
columns:
- name: id
data_type: integer
"""


model_foreign_key_model_node_not_found_schema_yml = """
models:
- name: my_model
constraints:
- type: foreign_key
columns: [id]
to: ref('doesnt_exist')
to_columns: [id]
columns:
- name: id
data_type: integer
"""


model_foreign_key_model_invalid_syntax_schema_yml = """
models:
- name: my_model
constraints:
- type: foreign_key
columns: [id]
to: invalid
to_columns: [id]
columns:
- name: id
data_type: integer
"""


model_foreign_key_model_column_schema_yml = """
models:
- name: my_model
columns:
- name: id
data_type: integer
constraints:
- type: foreign_key
to: ref('my_model_to')
to_columns: [id]
"""


model_foreign_key_column_invalid_syntax_schema_yml = """
models:
- name: my_model
columns:
- name: id
data_type: integer
constraints:
- type: foreign_key
to: invalid
to_columns: [id]
"""


model_foreign_key_column_node_not_found_schema_yml = """
models:
- name: my_model
columns:
- name: id
data_type: integer
constraints:
- type: foreign_key
to: ref('doesnt_exist')
to_columns: [id]
"""

model_column_level_foreign_key_source_schema_yml = """
sources:
- name: test_source
tables:
- name: test_table
models:
- name: my_model
columns:
- name: id
data_type: integer
constraints:
- type: foreign_key
to: source('test_source', 'test_table')
to_columns: [id]
"""
Loading

0 comments on commit cab6dab

Please sign in to comment.