Skip to content

Commit

Permalink
core: Add ruff rules PERF
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 23, 2025
1 parent f2ea62f commit f1ad740
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 110 deletions.
8 changes: 1 addition & 7 deletions libs/core/langchain_core/indexing/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,7 @@ def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteRespon

def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids."""
found_documents = []

for id_ in ids:
if id_ in self.store:
found_documents.append(self.store[id_])

return found_documents
return [self.store[id_] for id_ in ids if id_ in self.store]

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
Expand Down
34 changes: 12 additions & 22 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,31 +222,21 @@ def _backwards_compat_tool_calls(cls, values: dict) -> Any:

# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: list = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
)
values["tool_calls"] = updated
values["tool_calls"] = [
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
for tc in tool_calls
]
if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = []
for tc in invalid_tool_calls:
updated.append(
create_invalid_tool_call(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated
values["invalid_tool_calls"] = [
create_invalid_tool_call(**{k: v for k, v in tc.items() if k != "type"})
for tc in invalid_tool_calls
]

if tool_call_chunks := values.get("tool_call_chunks"):
updated = []
for tc in tool_call_chunks:
updated.append(
create_tool_call_chunk(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated
values["tool_call_chunks"] = [
create_tool_call_chunk(**{k: v for k, v in tc.items() if k != "type"})
for tc in tool_call_chunks
]

return values

Expand Down
30 changes: 19 additions & 11 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,11 @@ def get_prompts(
"""Return a list of prompts used by this Runnable."""
from langchain_core.prompts.base import BasePromptTemplate

prompts = []
for _, node in self.get_graph(config=config).nodes.items():
if isinstance(node.data, BasePromptTemplate):
prompts.append(node.data)
return prompts
return [
node.data
for node in self.get_graph(config=config).nodes.values()
if isinstance(node.data, BasePromptTemplate)
]

def __or__(
self,
Expand Down Expand Up @@ -3141,9 +3141,13 @@ def batch(
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
failed_inputs_map.update(
{
i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
Expand Down Expand Up @@ -3271,9 +3275,13 @@ async def abatch(
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
failed_inputs_map.update(
{
i: inp
for i, inp in zip(remaining_idxs, inputs)
if isinstance(inp, Exception)
}
)
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
Expand Down
18 changes: 10 additions & 8 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,11 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the origin.
"""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets:
found.append(node)
found: list[Node] = [
node
for node in graph.nodes.values()
if node.id not in exclude and node.id not in targets
]
return found[0] if len(found) == 1 else None


Expand All @@ -657,8 +658,9 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources:
found.append(node)
found: list[Node] = [
node
for node in graph.nodes.values()
if node.id not in exclude and node.id not in sources
]
return found[0] if len(found) == 1 else None
10 changes: 6 additions & 4 deletions libs/core/langchain_core/tracers/log_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,12 @@ async def consume_astream() -> None:
"value": copy.deepcopy(chunk),
}
)
for op in jsonpatch.JsonPatch.from_diff(
prev_final_output, final_output, dumps=dumps
):
patches.append({**op, "path": f"/final_output{op['path']}"})
patches.extend(
{**op, "path": f"/final_output{op['path']}"}
for op in jsonpatch.JsonPatch.from_diff(
prev_final_output, final_output, dumps=dumps
)
)
await stream.send_stream.send(RunLogPatch(*patches))
finally:
await stream.send_stream.aclose()
Expand Down
30 changes: 15 additions & 15 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,21 +576,21 @@ class Person(BaseModel):
)
"""
messages: list[BaseMessage] = [HumanMessage(content=input)]
openai_tool_calls = []
for tool_call in tool_calls:
openai_tool_calls.append(
{
"id": str(uuid.uuid4()),
"type": "function",
"function": {
# The name of the function right now corresponds to the name
# of the pydantic model. This is implicit in the API right now,
# and will be improved over time.
"name": tool_call.__class__.__name__,
"arguments": tool_call.model_dump_json(),
},
}
)
openai_tool_calls = [
{
"id": str(uuid.uuid4()),
"type": "function",
"function": {
# The name of the function right now corresponds to the name
# of the pydantic model. This is implicit in the API right now,
# and will be improved over time.
"name": tool_call.__class__.__name__,
"arguments": tool_call.model_dump_json(),
},
}
for tool_call in tool_calls
]

messages.append(
AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
)
Expand Down
4 changes: 2 additions & 2 deletions libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PERF", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
ignore = [ "COM812", "PERF203", "UP007", "S110", "S112",]

[tool.coverage.run]
omit = [ "tests/*",]
Expand Down
13 changes: 6 additions & 7 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,13 +955,12 @@ def test_chat_prompt_template_variable_names() -> None:
prompt.get_input_schema()

if record:
error_msg = []
for warning in record:
error_msg.append(
f"Warning type: {warning.category.__name__}, "
f"Warning message: {warning.message}, "
f"Warning location: {warning.filename}:{warning.lineno}"
)
error_msg = [
f"Warning type: {warning.category.__name__}, "
f"Warning message: {warning.message}, "
f"Warning location: {warning.filename}:{warning.lineno}"
for warning in record
]
msg = "\n".join(error_msg)
else:
msg = ""
Expand Down
24 changes: 6 additions & 18 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3409,9 +3409,7 @@ def test_deep_stream() -> None:

stream = chain.stream({"question": "What up"})

chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
Expand All @@ -3435,9 +3433,7 @@ def test_deep_stream_assign() -> None:

stream = chain.stream({"question": "What up"})

chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)

assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
Expand Down Expand Up @@ -3535,9 +3531,7 @@ async def test_deep_astream() -> None:

stream = chain.astream({"question": "What up"})

chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
Expand All @@ -3561,9 +3555,7 @@ async def test_deep_astream_assign() -> None:

stream = chain.astream({"question": "What up"})

chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]

assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"}
Expand Down Expand Up @@ -3659,9 +3651,7 @@ def test_runnable_sequence_transform() -> None:

stream = chain.transform(llm.stream("Hi there!"))

chunks = []
for chunk in stream:
chunks.append(chunk)
chunks = list(stream)

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
Expand All @@ -3674,9 +3664,7 @@ async def test_runnable_sequence_atransform() -> None:

stream = chain.atransform(llm.astream("Hi there!"))

chunks = []
async for chunk in stream:
chunks.append(chunk)
chunks = [chunk async for chunk in stream]

assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
Expand Down
16 changes: 8 additions & 8 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,12 +1492,12 @@ def bar(a: str) -> str:

events = []

for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass

events = _with_nulled_run_id(events)
for event in events:
Expand Down Expand Up @@ -1609,12 +1609,12 @@ def fail(inputs: str) -> None:

events = []

for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass

events = _with_nulled_run_id(events)
for event in events:
Expand Down
16 changes: 8 additions & 8 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,12 +1448,12 @@ def bar(a: str) -> str:

events = []

for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass

events = _with_nulled_run_id(events)
for event in events:
Expand Down Expand Up @@ -1565,12 +1565,12 @@ def fail(inputs: str) -> None:

events = []

for _ in range(10):
try:
try:
for _ in range(10):
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
except Exception:
pass

events = _with_nulled_run_id(events)
for event in events:
Expand Down

0 comments on commit f1ad740

Please sign in to comment.