Skip to content

Commit

Permalink
core: Add ruff rules TRY (tryceratops) (#29388)
Browse files Browse the repository at this point in the history
TRY004 ("use TypeError rather than ValueError"): existing violations are
marked with ignore comments to preserve backward compatibility.
Let me know if you would prefer some of them fixed instead.

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
cbornet and efriis authored Jan 24, 2025
1 parent 723b603 commit dbb6b7b
Show file tree
Hide file tree
Showing 26 changed files with 136 additions and 124 deletions.
8 changes: 4 additions & 4 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def trace_as_chain_group(
except Exception as e:
if not group_cm.ended:
run_manager.on_chain_error(e)
raise e
raise
else:
if not group_cm.ended:
run_manager.on_chain_end({})
Expand Down Expand Up @@ -207,7 +207,7 @@ async def atrace_as_chain_group(
except Exception as e:
if not group_cm.ended:
await run_manager.on_chain_error(e)
raise e
raise
else:
if not group_cm.ended:
await run_manager.on_chain_end({})
Expand Down Expand Up @@ -289,7 +289,7 @@ def handle_event(
f" {repr(e)}"
)
if handler.raise_error:
raise e
raise
finally:
if coros:
try:
Expand Down Expand Up @@ -388,7 +388,7 @@ async def _ahandle_event_for_handler(
f"Error in {handler.__class__.__name__}.{event_name} callback: {repr(e)}"
)
if handler.raise_error:
raise e
raise


async def ahandle_event(
Expand Down
44 changes: 24 additions & 20 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

def invoke(
self,
Expand Down Expand Up @@ -407,19 +407,21 @@ def stream(
generation = chunk
else:
generation += chunk
if generation is None:
msg = "No generation chunks were returned"
raise ValueError(msg)
except BaseException as e:
run_manager.on_llm_error(
e,
response=LLMResult(
generations=[[generation]] if generation else []
),
)
raise e
else:
run_manager.on_llm_end(LLMResult(generations=[[generation]]))
raise

if generation is None:
err = ValueError("No generation chunks were returned")
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

run_manager.on_llm_end(LLMResult(generations=[[generation]]))

async def astream(
self,
Expand Down Expand Up @@ -485,19 +487,21 @@ async def astream(
generation = chunk
else:
generation += chunk
if generation is None:
msg = "No generation chunks were returned"
raise ValueError(msg)
except BaseException as e:
await run_manager.on_llm_error(
e,
response=LLMResult(generations=[[generation]] if generation else []),
)
raise e
else:
await run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)
raise

if generation is None:
err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

await run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)

# --- Custom methods ---

Expand Down Expand Up @@ -641,7 +645,7 @@ def generate(
except BaseException as e:
if run_managers:
run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
raise e
raise
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item]
for res in results
Expand Down Expand Up @@ -1022,7 +1026,7 @@ def __call__(
return generation.message
else:
msg = "Unexpected generation type"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

async def _call_async(
self,
Expand All @@ -1039,7 +1043,7 @@ async def _call_async(
return generation.message
else:
msg = "Unexpected generation type"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
Expand All @@ -1057,7 +1061,7 @@ def predict(
return result.content
else:
msg = "Cannot use predict when output is not a string."
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
Expand All @@ -1082,7 +1086,7 @@ async def apredict(
return result.content
else:
msg = "Cannot use predict when output is not a string."
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/fake_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _stream(
f"Expected generate to return a ChatResult, "
f"but got {type(chat_result)} instead."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

message = chat_result.generations[0].message

Expand All @@ -251,7 +251,7 @@ def _stream(
f"Expected invoke to return an AIMessage, "
f"but got {type(message)} instead."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

content = message.content

Expand Down
42 changes: 23 additions & 19 deletions libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004

def _get_ls_params(
self,
Expand Down Expand Up @@ -448,7 +448,7 @@ def batch(
if return_exceptions:
return cast(list[str], [e for _ in inputs])
else:
raise e
raise
else:
batches = [
inputs[i : i + max_concurrency]
Expand Down Expand Up @@ -494,7 +494,7 @@ async def abatch(
if return_exceptions:
return cast(list[str], [e for _ in inputs])
else:
raise e
raise
else:
batches = [
inputs[i : i + max_concurrency]
Expand Down Expand Up @@ -562,19 +562,21 @@ def stream(
generation = chunk
else:
generation += chunk
if generation is None:
msg = "No generation chunks were returned"
raise ValueError(msg)
except BaseException as e:
run_manager.on_llm_error(
e,
response=LLMResult(
generations=[[generation]] if generation else []
),
)
raise e
else:
run_manager.on_llm_end(LLMResult(generations=[[generation]]))
raise

if generation is None:
err = ValueError("No generation chunks were returned")
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

run_manager.on_llm_end(LLMResult(generations=[[generation]]))

async def astream(
self,
Expand Down Expand Up @@ -632,17 +634,19 @@ async def astream(
generation = chunk
else:
generation += chunk
if generation is None:
msg = "No generation chunks were returned"
raise ValueError(msg)
except BaseException as e:
await run_manager.on_llm_error(
e,
response=LLMResult(generations=[[generation]] if generation else []),
)
raise e
else:
await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
raise

if generation is None:
err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

await run_manager.on_llm_end(LLMResult(generations=[[generation]]))

# --- Custom methods ---

Expand Down Expand Up @@ -790,7 +794,7 @@ def _generate_helper(
except BaseException as e:
for run_manager in run_managers:
run_manager.on_llm_error(e, response=LLMResult(generations=[]))
raise e
raise
flattened_outputs = output.flatten()
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
Expand Down Expand Up @@ -850,7 +854,7 @@ def generate(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
# Create callback managers
if isinstance(metadata, list):
metadata = [
Expand Down Expand Up @@ -1036,7 +1040,7 @@ async def _agenerate_helper(
for run_manager in run_managers
]
)
raise e
raise
flattened_outputs = output.flatten()
await asyncio.gather(
*[
Expand Down Expand Up @@ -1289,7 +1293,7 @@ def __call__(
f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
"`generate` instead."
)
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
return (
self.generate(
[prompt],
Expand Down
23 changes: 13 additions & 10 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,17 @@ def init_tool_calls(self) -> Self:
return self
tool_calls = []
invalid_tool_calls = []

def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
invalid_tool_calls.append(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error=None,
)
)

for chunk in self.tool_call_chunks:
try:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
Expand All @@ -375,17 +386,9 @@ def init_tool_calls(self) -> Self:
)
)
else:
msg = "Malformed args."
raise ValueError(msg)
add_chunk_to_invalid_tool_calls(chunk)
except Exception:
invalid_tool_calls.append(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error=None,
)
)
add_chunk_to_invalid_tool_calls(chunk)
self.tool_calls = tool_calls
self.invalid_tool_calls = invalid_tool_calls
return self
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_buffer_string(
role = m.role
else:
msg = f"Got unsupported message type: {m}"
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
Expand Down Expand Up @@ -1400,7 +1400,7 @@ def _get_message_openai_role(message: BaseMessage) -> str:
return message.role
else:
msg = f"Unknown BaseMessage type {message.__class__}."
raise ValueError(msg)
raise ValueError(msg) # noqa: TRY004


def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
Expand Down
19 changes: 10 additions & 9 deletions libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,20 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An
name_dict = {tool.__name__: tool for tool in self.tools}
pydantic_objects = []
for res in json_results:
if not isinstance(res["args"], dict):
if partial:
continue
msg = (
f"Tool arguments must be specified as a dict, received: "
f"{res['args']}"
)
raise ValueError(msg)
try:
if not isinstance(res["args"], dict):
msg = (
f"Tool arguments must be specified as a dict, received: "
f"{res['args']}"
)
raise ValueError(msg)
pydantic_objects.append(name_dict[res["type"]](**res["args"]))
except (ValidationError, ValueError) as e:
except (ValidationError, ValueError):
if partial:
continue
else:
raise e
raise
if self.first_tool_only:
return pydantic_objects[0] if pydantic_objects else None
else:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def parse_result(
try:
json_object = super().parse_result(result)
return self._parse_obj(json_object)
except OutputParserException as e:
except OutputParserException:
if partial:
return None
raise e
raise

def parse(self, text: str) -> TBaseModel:
"""Parse the output of an LLM call to a pydantic object.
Expand Down
Loading

0 comments on commit dbb6b7b

Please sign in to comment.