Completes Ruff flake8-bugbear (B) rules
Refactors the code to comply with these rules; the main patterns applied are sketched below.
skrawcz committed Aug 18, 2024
1 parent 3571554 commit d011d38
Showing 69 changed files with 150 additions and 143 deletions.
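Most of the changes below fall into a few flake8-bugbear families: B006/B008 (mutable or call-expression default arguments), B904 (re-raise with `from` inside `except`), B007 (unused loop variables renamed to `_`), and the removal of an f-string docstring. As a minimal sketch of the B006 pattern used throughout — the function and parameter names here are illustrative, not taken from this repository:

```python
from typing import List, Optional


def parsed_tags(text: str, tags_to_extract: Optional[List[str]] = None) -> List[str]:
    """Keep only whitespace-separated tokens that match the requested tags."""
    # B006: a literal default such as `tags_to_extract=["p", "li"]` is built once at
    # definition time and shared by every call; use a None sentinel and rebuild it here.
    if tags_to_extract is None:
        tags_to_extract = ["p", "li", "div"]
    return [token for token in text.split() if token in tags_to_extract]
```

Several call sites in the diff instead keep the literal default and silence the check with `# noqa: B006`, which is a reasonable choice when the default is never mutated.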
@@ -262,7 +262,7 @@ def construct_df(
negatives_per_positive: int = 1,
random_seed: int = 123,
) -> pd.DataFrame:
f"""Return dataframe of {base_df} paris with negatives added."""
"""Return dataframe of {base_df} paris with negatives added."""
return pd.concat(
[
base_df,
@@ -61,7 +61,7 @@ def table_ref(
table = client.open_table(table_name)
except FileNotFoundError:
if schema is None:
raise ValueError("`schema` must be provided to create table.")
raise ValueError("`schema` must be provided to create table.") from FileNotFoundError

table = _create_table(
client=client,
@@ -125,7 +125,7 @@ def best_model_per_series(cross_validation_evaluation: pd.DataFrame) -> pd.Serie
def inference_predictions(
forecaster: StatsForecast,
inference_forecast_steps: int = 12,
inference_confidence_percentile: list[float] = [90.0],
inference_confidence_percentile: list[float] = [90.0], # noqa
) -> pd.DataFrame:
"""Infer values using the training harness. Fitted models aren't stored
@@ -141,7 +141,7 @@ def plotting_config(
plot_uids: Optional[list[str]] = None,
plot_models: Optional[list[str]] = None,
plot_anomalies: bool = False,
plot_confidence_percentile: list[float] = [90.0],
plot_confidence_percentile: list[float] = [90.0], # noqa: B006
plot_engine: str = "matplotlib",
) -> dict:
"""Configuration for plotting functions"""
4 changes: 2 additions & 2 deletions contrib/hamilton/contrib/user/zilto/webscraper/__init__.py
@@ -54,8 +54,8 @@ def html_page(url: str) -> str:
def parsed_html(
url: str,
html_page: str,
tags_to_extract: List[str] = ["p", "li", "div"],
tags_to_remove: List[str] = ["script", "style"],
tags_to_extract: List[str] = ["p", "li", "div"], # noqa: B006
tags_to_remove: List[str] = ["script", "style"], # noqa: B006
) -> ParsingResult:
"""Parse an HTML string using BeautifulSoup
@@ -133,7 +133,7 @@ def cross_validation_folds(

def study(
higher_is_better: bool,
pruner: Optional[optuna.pruners.BasePruner] = optuna.pruners.MedianPruner(),
pruner: Optional[optuna.pruners.BasePruner] = None,
sampler: Optional[optuna.samplers.BaseSampler] = None,
study_storage: Optional[str] = None,
study_name: Optional[str] = None,
@@ -142,6 +142,8 @@ def study(
"""Create an optuna study; use the XGBoost + Optuna integration for pruning
ref: https://github.com/optuna/optuna-examples/blob/main/xgboost/xgboost_integration.py
"""
if pruner is None:
pruner = optuna.pruners.MedianPruner()
return optuna.create_study(
direction="maximize" if higher_is_better else "minimize",
pruner=pruner,
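The `study()` change above is the standard B008 remedy: a call expression such as `optuna.pruners.MedianPruner()` used as a default is evaluated once when the function is defined, so every call would share that single pruner instance. A hedged sketch of the same None-sentinel pattern with a stand-in class (the names below are illustrative, not from this codebase):

```python
from typing import Optional


class MedianPruner:
    """Stand-in for a stateful default object such as optuna.pruners.MedianPruner."""


def create_study(pruner: Optional[MedianPruner] = None) -> MedianPruner:
    # B008: construct the default inside the body so each call gets a fresh instance.
    if pruner is None:
        pruner = MedianPruner()
    return pruner


assert create_study() is not create_study()  # a new pruner per call, not a shared one
```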
2 changes: 1 addition & 1 deletion contrib/setup.py
@@ -11,7 +11,7 @@
with open("README.md") as readme_file:
readme = readme_file.read()
except Exception:
warnings.warn("README.md not found")
warnings.warn("README.md not found") # noqa
readme = None

REQUIREMENTS_FILES = ["requirements.txt"]
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/GraphRAG/ingest_fighters.py
@@ -17,7 +17,7 @@ def raw_fighter_details() -> pd.DataFrame:

def fighter(raw_fighter_details: pd.DataFrame) -> Parallelizable[pd.Series]:
"""We then want to do something for each record. That's what this code sets up"""
for idx, row in raw_fighter_details.iterrows():
for _, row in raw_fighter_details.iterrows():
yield row


4 changes: 2 additions & 2 deletions examples/LLM_Workflows/image_telephone/streamlit.py
@@ -403,8 +403,8 @@ def explore_display():
image_urls_to_display = image_urls[0 : len(projection)]
if len(image_urls_to_display) != len(projection):
image_url_length = len(image_urls_to_display)
for i in range(len(projection) - len(image_urls_to_display)):
image_urls_to_display.append(image_urls[image_url_length - 1])
# for i in range(len(projection) - len(image_urls_to_display)):
image_urls_to_display.append(image_urls[image_url_length - 1])
embedding_path_plot(projection, image_urls_to_display, selected_entry, prompt_path)
# highlight_point(projection, selected_entry)

2 changes: 1 addition & 1 deletion examples/LLM_Workflows/knowledge_retrieval/state.py
@@ -137,7 +137,7 @@ def call_arxiv_function(messages, full_message):
return response
except Exception as e:
logger.error(type(e))
raise Exception("Function chat request failed")
raise Exception("Function chat request failed") from e

elif full_message["message"]["function_call"]["name"] == "read_article_and_summarize":
parsed_output = json.loads(full_message["message"]["function_call"]["arguments"])
@@ -56,7 +56,7 @@ def pdf_text(pdf_path: pd.Series) -> pd.Series:
:return: Series of strings of the PDFs' contents
"""
_pdf_text = []
for i, file_path in pdf_path.items():
for _, file_path in pdf_path.items():
# creating a pdf reader object
reader = PdfReader(file_path)
text = ""
@@ -59,7 +59,7 @@ class SummaryResponse(pydantic.BaseModel):


@app.post("/store_arxiv", tags=["Ingestion"])
async def store_arxiv(arxiv_ids: list[str] = fastapi.Form(...)) -> JSONResponse:
async def store_arxiv(arxiv_ids: list[str] = fastapi.Form(...)) -> JSONResponse: # noqa: B008
"""Retrieve PDF files of arxiv articles for arxiv_ids\n
Read the PDF as text, create chunks, and embed them using OpenAI API\n
Store chunks with embeddings in Weaviate.
@@ -27,8 +27,8 @@ def article_text(url: str, article_regex: str) -> str:
"""
try:
html = requests.get(url)
except requests.exceptions.RequestException:
raise Exception(f"Failed to get URL: {url}")
except requests.exceptions.RequestException as e:
raise Exception(f"Failed to get URL: {url}") from e
article = re.findall(article_regex, html.text, re.DOTALL)
if not article:
raise ValueError(f"No article found in {url}")
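The `raise ... from e` edits in this file and the surrounding ones implement B904: re-raising inside an `except` block without `from` buries the original error under an implicit "during handling of the above exception" message, while explicit chaining keeps the root cause attached as `__cause__`. A small self-contained sketch of the behaviour (the `fetch` helper is hypothetical):

```python
def fetch(url: str) -> str:
    raise ConnectionError(f"cannot reach {url}")


def article_text(url: str) -> str:
    try:
        return fetch(url)
    except ConnectionError as e:
        # B904: chain explicitly so the traceback shows the ConnectionError as the direct cause.
        raise RuntimeError(f"Failed to get URL: {url}") from e


try:
    article_text("https://example.com")
except RuntimeError as err:
    assert isinstance(err.__cause__, ConnectionError)
```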
@@ -27,7 +27,9 @@ def sitemap_text(sitemap_url: str = "https://hamilton.dagworks.io/en/latest/site
try:
sitemap = requests.get(sitemap_url)
except Exception as e:
raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}. Original error: {str(e)}")
raise RuntimeError(
f"Failed to fetch sitemap from {sitemap_url}. Original error: {str(e)}"
) from e
return sitemap.text


2 changes: 1 addition & 1 deletion examples/dagster/dagster_code/tutorial/assets.py
@@ -55,7 +55,7 @@ def most_frequent_words() -> MaterializeResult:
for raw_title in topstories["title"]:
title = raw_title.lower()
for word in title.split():
cleaned_word = word.strip(".,-!?:;()[]'\"-")
cleaned_word = word.strip(".,-!?:;()[]'\"-") # noqa
if cleaned_word not in stopwords and len(cleaned_word) > 0:
word_counts[cleaned_word] = word_counts.get(cleaned_word, 0) + 1

@@ -93,7 +93,7 @@ def get_signups_for_date(self, date: datetime) -> Sequence[Signup]:
signups = []
num_signups = self.random.randint(25, 100)

for i in range(num_signups):
for _ in range(num_signups):
signup = self.generate_signup(date)
signups.append(signup.to_dict())

2 changes: 1 addition & 1 deletion examples/dagster/hamilton_code/dataflow.py
@@ -31,7 +31,7 @@ def most_frequent_words(title: pd.Series) -> dict[str, int]:
word_counts = {}
for raw_title in title:
for word in raw_title.lower().split():
word = word.strip(".,-!?:;()[]'\"-")
word = word.strip(".,-!?:;()[]'\"-") # noqa
if len(word) == 0:
continue

2 changes: 1 addition & 1 deletion examples/dagster/hamilton_code/mock_api.py
@@ -94,7 +94,7 @@ def get_signups_for_date(self, date: datetime) -> Sequence[Signup]:
signups = []
num_signups = self.random.randint(25, 100)

for i in range(num_signups):
for _ in range(num_signups):
signup = self.generate_signup(date)
signups.append(signup.to_dict())

4 changes: 2 additions & 2 deletions examples/decoupling_io/adapters.py
@@ -7,8 +7,8 @@
import sklearn.inspection
import sklearn.metrics
import sklearn.model_selection
except ImportError:
raise NotImplementedError("scikit-learn is not installed.")
except ImportError as e:
raise NotImplementedError("scikit-learn is not installed.") from e


from hamilton import registry
28 changes: 16 additions & 12 deletions examples/dlt/slack/__init__.py
@@ -168,12 +168,7 @@ def get_thread_replies(messages: List[Dict[str, Any]]) -> Iterable[TDataItem]:
write_disposition=write_disposition,
)
def messages_resource(
created_at: dlt.sources.incremental[DateTime] = dlt.sources.incremental(
"ts",
initial_value=start_dt,
end_value=end_dt,
allow_external_schedulers=True,
),
created_at: dlt.sources.incremental[DateTime] = None,
) -> Iterable[TDataItem]:
"""
Yield all messages for a set of selected channels as a DLT resource. Keep blocks column without normalization.
@@ -184,19 +179,21 @@ def messages_resource(
Yields:
Iterable[TDataItem]: A list of messages.
"""
if created_at is None:
created_at = dlt.sources.incremental(
"ts",
initial_value=start_dt,
end_value=end_dt,
allow_external_schedulers=True,
)
start_date_ts = ensure_dt_type(created_at.last_value, to_ts=True)
end_date_ts = ensure_dt_type(created_at.end_value, to_ts=True)
for channel_data in fetched_selected_channels:
yield from get_messages(channel_data, start_date_ts, end_date_ts)

def per_table_messages_resource(
channel_data: Dict[str, Any],
created_at: dlt.sources.incremental[DateTime] = dlt.sources.incremental(
"ts",
initial_value=start_dt,
end_value=end_dt,
allow_external_schedulers=True,
),
created_at: dlt.sources.incremental[DateTime] = None,
) -> Iterable[TDataItem]:
"""Yield all messages for a given channel as a DLT resource. Keep blocks column without normalization.
@@ -207,6 +204,13 @@ def per_table_messages_resource(
Yields:
Iterable[TDataItem]: A list of messages.
"""
if created_at is None:
created_at = dlt.sources.incremental(
"ts",
initial_value=start_dt,
end_value=end_dt,
allow_external_schedulers=True,
)
start_date_ts = ensure_dt_type(created_at.last_value, to_ts=True)
end_date_ts = ensure_dt_type(created_at.end_value, to_ts=True)
yield from get_messages(channel_data, start_date_ts, end_date_ts)
4 changes: 2 additions & 2 deletions examples/due_date_probabilities/probability_estimation.py
@@ -125,9 +125,9 @@ def raw_probabilities(raw_data: str) -> pd.DataFrame:

def resampled(raw_probabilities: pd.DataFrame) -> List[int]:
sample_data = []
for index, row in raw_probabilities.iterrows():
for _, row in raw_probabilities.iterrows():
count = row.probability * 1000
for i in range(int(count)):
for _i in range(int(count)):
sample_data.append(row.days)
return sample_data

2 changes: 1 addition & 1 deletion examples/people_data_labs/analysis.py
@@ -115,7 +115,7 @@ def stock_growth_rate_since_last_funding_round(
df = pd.merge(left=stock_data, right=period_start, on="ticker", how="inner")

stock_growth = dict()
for idx, row in df.iterrows():
for _, row in df.iterrows():
history = pd.json_normalize(row["historical_price"]).astype({"date": "datetime64[ns]"})

# skip ticker if history is empty
4 changes: 2 additions & 2 deletions examples/prefect/run.py
@@ -72,15 +72,15 @@ def train_and_evaluate_model_task(
)
def absenteeism_prediction_flow(
raw_data_location: str = "./data/Absenteeism_at_work.csv",
feature_set: list[str] = [
feature_set: list[str] = [ # noqa: B006
"age_zero_mean_unit_variance",
"has_children",
"has_pet",
"is_summer",
"service_time",
],
label: str = "absenteeism_time_in_hours",
validation_user_ids: list[str] = [
validation_user_ids: list[str] = [ # noqa: B006
"1",
"2",
"4",
7 changes: 3 additions & 4 deletions examples/spark/world_of_warcraft/zone_features__spark_v1.py
@@ -12,10 +12,9 @@ def world_of_warcraft(spark_session: ps.SparkSession) -> ps.DataFrame:

def zone_flags(world_of_warcraft: ps.DataFrame) -> ps.DataFrame:
zone_flags = world_of_warcraft
for zone in ["durotar", "darkshore"]:
zone_flags = zone_flags.withColumn(
"darkshore_flag", sf.when(sf.col("zone") == " Darkshore", 1).otherwise(0)
).withColumn("durotar_flag", sf.when(sf.col("zone") == " Durotar", 1).otherwise(0))
zone_flags = zone_flags.withColumn(
"darkshore_flag", sf.when(sf.col("zone") == " Darkshore", 1).otherwise(0)
).withColumn("durotar_flag", sf.when(sf.col("zone") == " Durotar", 1).otherwise(0))
return zone_flags


6 changes: 3 additions & 3 deletions hamilton/cli/__main__.py
@@ -127,7 +127,7 @@ def _try_command(cmd: Callable, **cmd_kwargs) -> Any:
command=cmd_name, success=False, message={"error": str(type(e)), "details": str(e)}
)
logger.error(dataclasses.asdict(response))
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e

return result

@@ -297,12 +297,12 @@ def ui(
"""Runs the Hamilton UI on sqllite in port 8241"""
try:
from hamilton_ui import commands
except ImportError:
except ImportError as e:
logger.error(
"hamilton[ui] not installed -- you have to install this to run the UI. "
'Run `pip install "sf-hamilton[ui]"` to install and get started with the UI!'
)
raise typer.Exit(code=1)
raise typer.Exit(code=1) from e

ctx.invoke(
commands.run,
10 changes: 5 additions & 5 deletions hamilton/cli/logic.py
@@ -27,8 +27,8 @@ def get_git_base_directory() -> str:
else:
print("Error:", result.stderr.strip())
raise OSError(f"{result.stderr.strip()}")
except FileNotFoundError:
raise FileNotFoundError("Git command not found. Please make sure Git is installed.")
except FileNotFoundError as e:
raise FileNotFoundError("Git command not found. Please make sure Git is installed.") from e


def get_git_reference(git_relative_path: Union[str, Path], git_reference: str) -> str:
@@ -51,8 +51,8 @@ def get_git_reference(git_relative_path: Union[str, Path], git_reference: str) -
return
else:
return
except FileNotFoundError:
raise FileNotFoundError("Git command not found. Please make sure Git is installed.")
except FileNotFoundError as e:
raise FileNotFoundError("Git command not found. Please make sure Git is installed.") from e


def version_hamilton_functions(module: ModuleType) -> Dict[str, str]:
@@ -184,7 +184,7 @@ def diff_versions(current_map: Dict[str, str], reference_map: Dict[str, str]) ->
if v1 != v2:
edit.append(node_name)

for node_name, v2 in reference_map.items():
for node_name, _ in reference_map.items():
v1 = current_map.get(node_name)
if v1 is None:
reference_only.append(node_name)
4 changes: 2 additions & 2 deletions hamilton/dataflows/__init__.py
@@ -498,10 +498,10 @@ def are_py_dependencies_satisfied(dataflow, user=None, version="latest"):
else:
package_name = line
required_version = None
required_version # here for now...
required_version # noqa here for now...
try:
installed_version = pkg_version(package_name)
installed_version # here for now..
installed_version # noqa here for now..
except PackageNotFoundError:
logger.info(f"Package '{package_name}' is not installed.")
return False
2 changes: 1 addition & 1 deletion hamilton/execution/executors.py
@@ -99,7 +99,7 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
for node_ in task.nodes:
if not getattr(node_, "callable_modified", False):
node_._callable = _modify_callable(node_.node_role, node_.callable)
setattr(node_, "callable_modified", True)
node_.callable_modified = True
if task.adapter.does_hook("pre_task_execute", is_async=False):
task.adapter.call_all_lifecycle_hooks_sync(
"pre_task_execute",
6 changes: 3 additions & 3 deletions hamilton/execution/state.py
@@ -307,7 +307,7 @@ def realize_parameterized_group(
for dependency in new_task.base_dependencies:
new_dependencies[dependency] = []
if dependency in task_names_in_group:
for group_name, name_map in name_maps.items():
for _, name_map in name_maps.items():
new_dependencies[dependency].append(name_map[dependency])
else:
new_dependencies[dependency].append(dependency)
@@ -403,10 +403,10 @@ def update_task_state(

tasks_to_enqueue = []
# not efficient, TODO -- use a reverse dependency map
for key, task in self.task_pool.items():
for _key, task in self.task_pool.items():
if self.task_states[task.task_id] == TaskState.INITIALIZED:
should_launch = True
for base_dep_name, realized_dep_list in task.realized_dependencies.items():
for _base_dep_name, realized_dep_list in task.realized_dependencies.items():
for dep in realized_dep_list:
if self.task_states[dep] != TaskState.SUCCESSFUL:
should_launch = False