Skip to content

Commit

Permalink
Make specialized 'calculate_node_config' methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Nov 30, 2024
1 parent df23c7d commit 42a1afb
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 34 deletions.
120 changes: 88 additions & 32 deletions core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, active_project: RuntimeConfig):
def get_config_source(self, project: Project) -> ConfigSource:
return RenderedConfig(project)

def get_node_project(self, project_name: str):
def get_node_project_config(self, project_name: str):
if project_name == self._active_project.project_name:
return self._active_project
dependencies = self._active_project.load_dependencies()
Expand Down Expand Up @@ -131,6 +131,17 @@ def _active_project_configs(
) -> Iterator[Dict[str, Any]]:
return self._project_configs(self._active_project, fqn, resource_type)

@abstractmethod
def calculate_node_config(
self,
config_call_dict: Dict[str, Any],
fqn: List[str],
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> T: ...

@abstractmethod
def _update_from_config(
self, result: T, partial: Dict[str, Any], validate: bool = False
Expand All @@ -140,6 +151,25 @@ def _update_from_config(
def initial_result(self, resource_type: NodeType, base: bool) -> T: ...

# BaseContextConfigGenerator
@abstractmethod
def calculate_node_config_dict(
self,
config_call_dict: Dict[str, Any],
fqn: List[str],
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ...


class ContextConfigGenerator(BaseContextConfigGenerator[C]):
def __init__(self, active_project: RuntimeConfig):
self._active_project = active_project

def get_config_source(self, project: Project) -> ConfigSource:
return RenderedConfig(project)

def calculate_node_config(
self,
config_call_dict: Dict[str, Any],
Expand All @@ -148,13 +178,18 @@ def calculate_node_config(
project_name: str,
base: bool,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> BaseConfig:
own_config = self.get_node_project(project_name)
) -> C:
# Note: This method returns a BaseConfig object. This is a duplicate of
# of UnrenderedConfigGenerator.calculate_node_config, but calls methods
# that deal with config objects instead of dictionaries.
# Additions to one method, should probably also go in the other.

project_config = self.get_node_project_config(project_name)

# creates "default" config object ("cls.from_dict({})")
config_obj = self.initial_result(resource_type=resource_type, base=base)

project_configs = self._project_configs(own_config, fqn, resource_type)
project_configs = self._project_configs(project_config, fqn, resource_type)
for fqn_config in project_configs:
config_obj = self._update_from_config(config_obj, fqn_config)

Expand All @@ -168,33 +203,14 @@ def calculate_node_config(
# the ParseConfigObject (via add_config_call)
config_obj = self._update_from_config(config_obj, config_call_dict)

if own_config.project_name != self._active_project.project_name:
if project_config.project_name != self._active_project.project_name:
for fqn_config in self._active_project_configs(fqn, resource_type):
config_obj = self._update_from_config(config_obj, fqn_config)

return config_obj # type: ignore[return-value]

@abstractmethod
def calculate_node_config_dict(
self,
config_call_dict: Dict[str, Any],
fqn: List[str],
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ...


class ContextConfigGenerator(BaseContextConfigGenerator[C]):
def __init__(self, active_project: RuntimeConfig):
self._active_project = active_project

def get_config_source(self, project: Project) -> ConfigSource:
return RenderedConfig(project)
return config_obj

def initial_result(self, resource_type: NodeType, base: bool) -> C:
# defaults, own_config, config calls, active_config (if != own_config)
# defaults, project_config, config calls, active_config (if != project_config)
config_cls = get_config_for(resource_type, base=base)
# Calculate the defaults. We don't want to validate the defaults,
# because it might be invalid in the case of required config members
Expand Down Expand Up @@ -234,7 +250,7 @@ def calculate_node_config_dict(
patch_config_dict: Optional[dict] = None,
) -> Dict[str, Any]:

# calls BaseContextConfigGenerator.calculate_node_config
# returns a config object
config_obj = self.calculate_node_config(
config_call_dict=config_call_dict,
fqn=fqn,
Expand All @@ -257,6 +273,45 @@ class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
def get_config_source(self, project: Project) -> ConfigSource:
return UnrenderedConfig(project)

def calculate_node_config(
self,
config_call_dict: Dict[str, Any],
fqn: List[str],
resource_type: NodeType,
project_name: str,
base: bool,
patch_config_dict: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# Note: This method returns a Dict[str, Any]. This is a duplicate of
# of ContextConfigGenerator.calculate_node_config, but calls methods
# that deal with dictionaries instead of config object.
# Additions to one method, should probably also go in the other.

project_config = self.get_node_project_config(project_name)

# creates "default" config object ({})
config_dict = self.initial_result(resource_type=resource_type, base=base)

project_configs = self._project_configs(project_config, fqn, resource_type)
for fqn_config in project_configs:
config_dict = self._update_from_config(config_dict, fqn_config)

# When schema files patch config, it has lower precedence than
# config in the models (config_call_dict), so we add the patch_config_dict
# before the config_call_dict
if patch_config_dict:
config_dict = self._update_from_config(config_dict, patch_config_dict)

# config_calls are created in the 'experimental' model parser and
# the ParseConfigObject (via add_config_call)
config_dict = self._update_from_config(config_dict, config_call_dict)

if project_config.project_name != self._active_project.project_name:
for fqn_config in self._active_project_configs(fqn, resource_type):
config_dict = self._update_from_config(config_dict, fqn_config)

return config_dict

# UnrenderedConfigGenerator
def calculate_node_config_dict(
self,
Expand All @@ -267,17 +322,18 @@ def calculate_node_config_dict(
base: bool,
patch_config_dict: Optional[dict] = None,
) -> Dict[str, Any]:

# calls BaseContextConfigGenerator.calculate_node_config
return self.calculate_node_config(
# Just call UnrenderedConfigGenerator.calculate_node_config, which
# will return a config dictionary
result = self.calculate_node_config(
config_call_dict=config_call_dict,
fqn=fqn,
resource_type=resource_type,
project_name=project_name,
base=base,
patch_config_dict=patch_config_dict,
) # type: ignore[return-value]
# Note: this returns a config_obj, NOT a dictionary
)
# Note: this returns a dictionary
return result

def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
return {}
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/graph_selection/test_version_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def seeds(self, test_data_dir):
def selectors(self):
return selectors_yml

@pytest.mark.skip('broken until mash 3.15')
@pytest.mark.skip("broken until mash 3.15")
def test_select_none_versions(self, project):
manifest = run_dbt(["parse"])
print(f"--- nodes.keys(): {manifest.nodes.keys()}")
Expand All @@ -77,7 +77,7 @@ def test_select_old_versions(self, project):
results = run_dbt(["ls", "--select", "version:old"])
assert sorted(results) == ["test.versioned.v1"]

@pytest.mark.skip('broken until mash 3.15')
@pytest.mark.skip("broken until mash 3.15")
def test_select_prerelease_versions(self, project):
results = run_dbt(["ls", "--select", "version:prerelease"])
assert sorted(results) == [
Expand Down

0 comments on commit 42a1afb

Please sign in to comment.