diff --git a/CITATION.md b/CITATION.md index 75839ef8..4eff5440 100644 --- a/CITATION.md +++ b/CITATION.md @@ -14,7 +14,7 @@ If you use NVIDIA Ingest in a publication, please use citations in the following ## Sample Citations: -Using [RAPIDS](rapids.ai) citations for reference. +Using [RAPIDS](https://rapids.ai/) citations for reference. ### Bringing UMAP Closer to the Speed of Light
with GPU Acceleration ```tex diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f7d8cc62..22a47ded 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,10 +11,10 @@ External contributions will be welcome soon, and they are greatly appreciated! E - [Seasoned Developers](#seasoned-developers) - [Workflow](#workflow) - [Common Processing Patterns](#common-processing-patterns) - - [traceable](#traceable) - - [nv_ingest_node_failure_context_manager](#nv_ingest_node_failure_context_manager) - - [filter_by_task](#filter_by_task) - - [cm_skip_processing_if_failed](#cm_skip_processing_if_failed) + - [traceable](#traceable---srcnv_ingestutiltracingtaggingpy) + - [nv_ingest_node_failure_context_manager](#nv_ingest_node_failure_context_manager---srcnv_ingestutilexception_handlersdecoratorspy) + - [filter_by_task](#filter_by_task---srcnv_ingestutilflow_controlfilter_by_taskpy) + - [cm_skip_processing_if_failed](#cm_skip_processing_if_failed---morpheusutilscontrol_message_utilspy) - [Adding a New Stage or Module](#adding-a-new-stage-or-module) - [Common Practices for Writing Unit Tests](#common-practices-for-writing-unit-tests) - [General Guidelines](#general-guidelines) diff --git a/Dockerfile b/Dockerfile index d4719933..4b54d27a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # syntax=docker/dockerfile:1.3 -ARG BASE_IMG=nvcr.io/nvidia/morpheus/morpheus -ARG BASE_IMG_TAG=v24.06.01-runtime +ARG BASE_IMG=nvcr.io/nvidia/cuda +ARG BASE_IMG_TAG=12.2.2-base-ubuntu22.04 -# Use NVIDIA Morpheus as the base image +# Use NVIDIA CUDA as the base image FROM $BASE_IMG:$BASE_IMG_TAG AS base @@ -13,27 +13,63 @@ ARG RELEASE_TYPE="dev" ARG VERSION="" ARG VERSION_REV="0" -# We require Python 3.10.15 but base image currently comes with 3.10.14, update here. -RUN source activate morpheus \ - && conda install python=3.10.15 +# Install necessary dependencies using apt-get +RUN apt-get update && apt-get install -y \ + wget \ + bzip2 \ + ca-certificates \ + curl \ + && apt-get clean + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \ + && bash /tmp/miniconda.sh -b -p /opt/conda \ + && rm /tmp/miniconda.sh + +# Add conda to the PATH +ENV PATH=/opt/conda/bin:$PATH + +# Install Mamba, a faster alternative to conda, within the base environment +RUN conda install -y mamba -n base -c conda-forge + +# Create nv_ingest base environment +RUN conda create -y --name nv_ingest python=3.10.15 + +# Activate the environment (make it default for subsequent commands) +RUN echo "source activate nv_ingest" >> ~/.bashrc + +# Set default shell to bash +SHELL ["/bin/bash", "-c"] + +# Install Tini via conda from the conda-forge channel +RUN source activate nv_ingest \ + && mamba install -y -c conda-forge tini + +# Install Morpheus dependencies +RUN source activate nv_ingest \ + && mamba install -y \ + nvidia/label/dev::morpheus-core \ + nvidia/label/dev::morpheus-llm \ + -c rapidsai -c pytorch -c nvidia -c conda-forge + +# Install additional dependencies using apt-get +RUN apt-get update && apt-get install -y \ + libgl1-mesa-glx \ + && apt-get clean # Set the working directory in the container WORKDIR /workspace -RUN apt-get update \ - && apt-get install --yes \ - libgl1-mesa-glx +# Copy custom entrypoint script +COPY ./docker/scripts/entrypoint.sh /workspace/docker/entrypoint.sh +FROM base AS nv_ingest_install # Copy the module code COPY setup.py setup.py -# Don't copy full source here, pipelines won't be installed via setup anyway, and this allows us to
rebuild more quickly if we're just changing the pipeline - COPY ci ci COPY requirements.txt extra-requirements.txt test-requirements.txt util-requirements.txt ./ -SHELL ["/bin/bash", "-c"] - -# Prevent haystack from ending telemetry data +# Prevent haystack from sending telemetry data ENV HAYSTACK_TELEMETRY_ENABLED=False # Ensure the NV_INGEST_VERSION is PEP 440 compatible @@ -53,8 +89,10 @@ ENV NV_INGEST_RELEASE_TYPE=${RELEASE_TYPE} ENV NV_INGEST_VERSION_OVERRIDE=${NV_INGEST_VERSION_OVERRIDE} ENV NV_INGEST_CLIENT_VERSION_OVERRIDE=${NV_INGEST_VERSION_OVERRIDE} +SHELL ["/bin/bash", "-c"] + # Cache the requirements and install them before uploading source code changes -RUN source activate morpheus \ +RUN source activate nv_ingest \ && pip install -r requirements.txt COPY tests tests @@ -63,8 +101,8 @@ COPY client client COPY src/nv_ingest src/nv_ingest RUN rm -rf ./src/nv_ingest/dist ./client/dist -# Build the client and install it in the conda cache so that the later nv-ingest build can locate it -RUN source activate morpheus \ +# Build the client and install it in the conda cache +RUN source activate nv_ingest \ && pip install -e client \ && pip install -r extra-requirements.txt @@ -73,36 +111,48 @@ RUN chmod +x ./ci/scripts/build_pip_packages.sh \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib client \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib service -RUN source activate morpheus \ +RUN source activate nv_ingest \ && pip install ./dist/*.whl -RUN source activate morpheus \ +RUN source activate nv_ingest \ && rm -rf src requirements.txt test-requirements.txt util-requirements.txt -# Interim pyarrow backport until folded into upstream dependency tree -RUN source activate morpheus \ - && conda install https://anaconda.org/conda-forge/pyarrow/14.0.2/download/linux-64/pyarrow-14.0.2-py310h188ebfb_19_cuda.conda - # Upgrade setuptools to mitigate https://github.com/advisories/GHSA-cx63-2mw6-8hw5 RUN source activate base \ && conda install setuptools==70.0.0 -FROM base AS runtime - -RUN source activate morpheus \ +RUN source activate nv_ingest \ && pip install ./client/dist/*.whl \ + ## Installations below can be removed after the next Morpheus release + && pip install --no-input milvus==2.3.5 \ + && pip install --no-input pymilvus==2.3.6 \ + && pip install --no-input langchain==0.1.16 \ + && pip install --no-input langchain-nvidia-ai-endpoints==0.0.11 \ + && pip install --no-input faiss-gpu==1.7.* \ + && pip install --no-input google-search-results==2.4 \ + && pip install --no-input nemollm==0.3.5 \ && rm -rf client/dist +# Install patched MRC version to circumvent NUMA node issue -- remove after Morpheus 10.24 release +RUN source activate nv_ingest \ + && conda install -y -c nvidia/label/dev mrc=24.10.00a=cuda_12.5_py310_h5ae46af_10 + +FROM nv_ingest_install AS runtime + COPY src/pipeline.py ./ COPY pyproject.toml ./ -COPY ./docker/scripts/entrypoint_source_ext.sh /opt/docker/bin/entrypoint_source + +RUN chmod +x /workspace/docker/entrypoint.sh + +# Set entrypoint to tini with a custom entrypoint script +ENTRYPOINT ["/opt/conda/envs/nv_ingest/bin/tini", "--", "/workspace/docker/entrypoint.sh"] # Start both the core nv-ingest pipeline service and the FastAPI microservice in parallel CMD ["sh", "-c", "python /workspace/pipeline.py & uvicorn nv_ingest.main:app --workers 32 --host 0.0.0.0 --port 7670 & wait"] -FROM base AS development +FROM nv_ingest_install AS development -RUN source activate morpheus && \ +RUN source activate nv_ingest && \ pip 
install -e ./client CMD ["/bin/bash"] diff --git a/README.md b/README.md index b6b69936..35c587ff 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,9 @@ Username: $oauthtoken Password: ``` +> [!NOTE] +> During the early access (EA) phase, your API key must be created as a member of `nemo-microservice / ea-participants`, which you can join by applying for early access here: https://developer.nvidia.com/nemo-microservices-early-access/join. Once approved, switch your profile to this org / team; the key you then generate will have access to the resources outlined below. + 4. Create a .env file containing your NGC API key, and the following paths: ``` # Container images must access resources from NGC. diff --git a/client/src/nv_ingest_client/cli/util/click.py b/client/src/nv_ingest_client/cli/util/click.py index 6412e851..f7fd9c72 100644 --- a/client/src/nv_ingest_client/cli/util/click.py +++ b/client/src/nv_ingest_client/cli/util/click.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 -import glob import json import logging import os @@ -13,6 +12,13 @@ import click from nv_ingest_client.cli.util.processing import check_schema +from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema +from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionSchema +from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask +from nv_ingest_client.primitives.tasks.dedup import DedupTaskSchema +from nv_ingest_client.primitives.tasks.embed import EmbedTaskSchema +from nv_ingest_client.primitives.tasks.extract import ExtractTaskSchema +from nv_ingest_client.primitives.tasks.filter import FilterTaskSchema from nv_ingest_client.primitives.tasks import CaptionTask from nv_ingest_client.primitives.tasks import DedupTask from nv_ingest_client.primitives.tasks import EmbedTask @@ -21,14 +27,12 @@ from nv_ingest_client.primitives.tasks import SplitTask from nv_ingest_client.primitives.tasks import StoreTask from nv_ingest_client.primitives.tasks import VdbUploadTask -from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema -from nv_ingest_client.primitives.tasks.dedup import DedupTaskSchema -from nv_ingest_client.primitives.tasks.embed import EmbedTaskSchema -from nv_ingest_client.primitives.tasks.extract import ExtractTaskSchema -from nv_ingest_client.primitives.tasks.filter import FilterTaskSchema from nv_ingest_client.primitives.tasks.split import SplitTaskSchema from nv_ingest_client.primitives.tasks.store import StoreTaskSchema +from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionSchema +from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask from nv_ingest_client.primitives.tasks.vdb_upload import VdbUploadTaskSchema +from nv_ingest_client.util.util import generate_matching_files logger = logging.getLogger(__name__) @@ -104,48 +108,59 @@ def click_validate_task(ctx, param, value): if task_id == "split": task_options = check_schema(SplitTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = SplitTask(**task_options.dict()) + new_task = [(new_task_id, SplitTask(**task_options.dict()))] elif task_id == "extract": task_options = check_schema(ExtractTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}_{task_options.document_type}" - new_task = ExtractTask(**task_options.dict()) + new_task = [(new_task_id, ExtractTask(**task_options.dict()))] + + if task_options.extract_tables: + subtask_options = check_schema(TableExtractionSchema,
{}, "table_data_extract", "{}") + new_task.append(("table_data_extract", TableExtractionTask(**subtask_options.dict()))) + + if task_options.extract_charts: + subtask_options = check_schema(ChartExtractionSchema, {}, "chart_data_extract", "{}") + new_task.append(("chart_data_extract", ChartExtractionTask(**subtask_options.dict()))) + elif task_id == "store": task_options = check_schema(StoreTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = StoreTask(**task_options.dict()) + new_task = [(new_task_id, StoreTask(**task_options.dict()))] elif task_id == "caption": task_options = check_schema(CaptionTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = CaptionTask(**task_options.dict()) + new_task = [(new_task_id, CaptionTask(**task_options.dict()))] elif task_id == "dedup": task_options = check_schema(DedupTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = DedupTask(**task_options.dict()) + new_task = [(new_task_id, DedupTask(**task_options.dict()))] elif task_id == "filter": task_options = check_schema(FilterTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = FilterTask(**task_options.dict()) + new_task = [(new_task_id, FilterTask(**task_options.dict()))] elif task_id == "embed": task_options = check_schema(EmbedTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = EmbedTask(**task_options.dict()) + new_task = [(new_task_id, EmbedTask(**task_options.dict()))] elif task_id == "vdb_upload": task_options = check_schema(VdbUploadTaskSchema, options, task_id, json_options) new_task_id = f"{task_id}" - new_task = VdbUploadTask(**task_options.dict()) + new_task = [(new_task_id, VdbUploadTask(**task_options.dict()))] else: raise ValueError(f"Unsupported task type: {task_id}") + if new_task_id in validated_tasks: + raise ValueError(f"Duplicate task detected: {new_task_id}") + logger.debug("Adding task: %s", new_task_id) - validated_tasks[new_task_id] = new_task + for task_tuple in new_task: + validated_tasks[task_tuple[0]] = task_tuple[1] except ValueError as e: validation_errors.append(str(e)) if validation_errors: # Aggregate error messages with original values highlighted error_message = "\n".join(validation_errors) - # logger.error(error_message) raise click.BadParameter(error_message) return validated_tasks @@ -190,37 +205,6 @@ def pre_process_dataset(dataset_json: str, shuffle_dataset: bool): return file_source -def _generate_matching_files(file_sources): - """ - Generates a list of file paths that match the given patterns specified in file_sources. - - Parameters - ---------- - file_sources : list of str - A list containing the file source patterns to match against. - - Returns - ------- - generator - A generator yielding paths to files that match the specified patterns. - - Notes - ----- - This function utilizes glob pattern matching to find files that match the specified patterns. - It yields each matching file path, allowing for efficient processing of potentially large - sets of files. - """ - - files = [ - file_path - for pattern in file_sources - for file_path in glob.glob(pattern, recursive=True) - if os.path.isfile(file_path) - ] - for file_path in files: - yield file_path - - def click_match_and_validate_files(ctx, param, value): """ Matches and validates files based on the provided file source patterns.
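A note on the fan-out above: a single `extract` CLI task can now expand into extra subtask entries in the validated task map. A minimal sketch of the resulting structure, assuming the CLI option syntax and keyword arguments implied by the schemas referenced in the hunk (the exact invocation is an assumption, not taken from this patch):

```python
# Hypothetical illustration of how one --task option now validates into
# multiple entries; ExtractTask kwargs are assumed from the schema fields above.
from nv_ingest_client.primitives.tasks import ExtractTask
from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask
from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask

# e.g. --task 'extract:{"document_type": "pdf", "extract_tables": true, "extract_charts": true}'
validated_tasks = {
    "extract_pdf": ExtractTask(document_type="pdf", extract_tables=True, extract_charts=True),
    "table_data_extract": TableExtractionTask(),  # takes no options; its schema forbids extras
    "chart_data_extract": ChartExtractionTask(),  # likewise an empty config
}
```

Duplicate task IDs now raise a `ValueError`, so repeating the same `--task` fails fast instead of silently overwriting the earlier entry.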
@@ -239,7 +223,7 @@ def click_match_and_validate_files(ctx, param, value): if not value: return [] - matching_files = list(_generate_matching_files(value)) + matching_files = list(generate_matching_files(value)) if not matching_files: logger.warning("No files found matching the specified patterns.") return [] diff --git a/client/src/nv_ingest_client/cli/util/processing.py b/client/src/nv_ingest_client/cli/util/processing.py index b58f50c7..f5f40992 100644 --- a/client/src/nv_ingest_client/cli/util/processing.py +++ b/client/src/nv_ingest_client/cli/util/processing.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import concurrent + import json import logging import os @@ -12,7 +12,8 @@ from concurrent.futures import as_completed from statistics import mean from statistics import median -from typing import Dict, Any +from typing import Any +from typing import Dict from typing import List from typing import Tuple from typing import Type @@ -23,9 +24,7 @@ from tqdm import tqdm from nv_ingest_client.client import NvIngestClient -from nv_ingest_client.primitives import JobSpec -from nv_ingest_client.util.file_processing.extract import extract_file_content -from nv_ingest_client.util.util import check_ingest_result +from nv_ingest_client.util.processing import handle_future_result from nv_ingest_client.util.util import estimate_page_count logger = logging.getLogger(__name__) @@ -131,7 +130,7 @@ def check_schema(schema: Type[BaseModel], options: dict, task_id: str, original_ def report_stage_statistics( - stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float + stage_elapsed_times: defaultdict, total_trace_elapsed: float, abs_elapsed: float ) -> None: """ Reports the statistics for each processing stage, including average, median, total time spent, @@ -206,10 +205,10 @@ def report_overall_speed(total_pages_processed: int, start_time_ns: int, total_f def report_statistics( - start_time_ns: int, - stage_elapsed_times: defaultdict, - total_pages_processed: int, - total_files: int, + start_time_ns: int, + stage_elapsed_times: defaultdict, + total_pages_processed: int, + total_files: int, ) -> None: """ Aggregates and reports statistics for the entire processing session. @@ -400,11 +399,13 @@ def save_response_data(response, output_directory): """ if ("data" not in response) or (not response["data"]): + logger.debug("Data is not in the response or response.data is empty") return response_data = response["data"] if not isinstance(response_data, list) or len(response_data) == 0: + logger.debug("Response data is not a list or the list is empty.") return doc_meta_base = response_data[0]["metadata"] @@ -423,136 +424,20 @@ def save_response_data(response, output_directory): f.write(json.dumps(documents, indent=2)) -def create_job_specs_for_batch(files_batch: List[str], tasks: Dict[str, Any], client: NvIngestClient) -> List[str]: - """ - Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs. - - This function takes a batch of files, processes each file to extract its content and type, - creates a job specification (JobSpec) for each file, and adds tasks from the provided task - list. It then submits the jobs to the client and collects their job IDs. - - Parameters - ---------- - files_batch : List[str] - A list of file paths to be processed. 
Each file is assumed to be in a format compatible - with the `extract_file_content` function, which extracts the file's content and type. - - tasks : Dict[str, Any] - A dictionary of tasks to be added to each job. The keys represent task names, and the - values represent task specifications or configurations. Standard tasks include "split", - "extract", "store", "caption", "dedup", "filter", "embed", and "vdb_upload". - - client : NvIngestClient - An instance of NvIngestClient, which handles the job submission. The `add_job` method of - the client is used to submit each job specification. - - Returns - ------- - Tuple[List[JobSpec], List[str]] - A Tuple containing the list of JobSpecs and list of job IDs corresponding to the submitted jobs. - Each job ID is returned by the client's `add_job` method. - - Raises - ------ - ValueError - If there is an error extracting the file content or type from any of the files, a - ValueError will be logged, and the corresponding file will be skipped. - - Notes - ----- - - The function assumes that a utility function `extract_file_content` is defined elsewhere, - which extracts the content and type from the provided file paths. - - For each file, a `JobSpec` is created with relevant metadata, including document type and - file content. Various tasks are conditionally added based on the provided `tasks` dictionary. - - The job specification includes tracing options with a timestamp (in nanoseconds) for - diagnostic purposes. - - Examples - -------- - Suppose you have a batch of files and tasks to process: - - >>> files_batch = ["file1.txt", "file2.pdf"] - >>> tasks = {"split": ..., "extract_txt": ..., "store": ...} - >>> client = NvIngestClient() - >>> job_ids = create_job_specs_for_batch(files_batch, tasks, client) - >>> print(job_ids) - ['job_12345', 'job_67890'] - - In this example, jobs are created and submitted for the files in `files_batch`, with the - tasks in `tasks` being added to each job specification. The returned job IDs are then - printed. - - See Also - -------- - extract_file_content : Function that extracts the content and type of a file. - JobSpec : The class representing a job specification. - NvIngestClient : Client class used to submit jobs to a job processing system. - """ - - job_ids = [] - for file_name in files_batch: - try: - file_content, file_type = extract_file_content(file_name) # Assume these are defined - file_type = file_type.value - except ValueError as ve: - logger.error(f"Error extracting content from {file_name}: {ve}") - continue - - job_spec = JobSpec( - document_type=file_type, - payload=file_content, - source_id=file_name, - source_name=file_name, - extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}}, - ) - - logger.debug(f"Tasks: {tasks.keys()}") - for task in tasks: - logger.debug(f"Task: {task}") - - # TODO(Devin): Formalize this later, don't have time right now. 
- if "split" in tasks: - job_spec.add_task(tasks["split"]) - - if f"extract_{file_type}" in tasks: - job_spec.add_task(tasks[f"extract_{file_type}"]) - - if "store" in tasks: - job_spec.add_task(tasks["store"]) - - if "caption" in tasks: - job_spec.add_task(tasks["caption"]) - - if "dedup" in tasks: - job_spec.add_task(tasks["dedup"]) - - if "filter" in tasks: - job_spec.add_task(tasks["filter"]) - - if "embed" in tasks: - job_spec.add_task(tasks["embed"]) - - if "vdb_upload" in tasks: - job_spec.add_task(tasks["vdb_upload"]) - - job_id = client.add_job(job_spec) - job_ids.append(job_id) - - return job_ids - - def generate_job_batch_for_iteration( - client: Any, - pbar: Any, - files: List[str], - tasks: Dict, - processed: int, - batch_size: int, - retry_job_ids: List[str], - fail_on_error: bool = False + client: Any, + pbar: Any, + files: List[str], + tasks: Dict, + processed: int, + batch_size: int, + retry_job_ids: List[str], + fail_on_error: bool = False, ) -> Tuple[List[str], Dict[str, str], int]: """ - Generates a batch of job specifications for the current iteration of file processing. This function handles retrying failed jobs and creating new jobs for unprocessed files. The job specifications are then submitted for processing. + Generates a batch of job specifications for the current iteration of file processing. + This function handles retrying failed jobs and creating new jobs for unprocessed files. + The job specifications are then submitted for processing. Parameters ---------- @@ -597,9 +482,9 @@ def generate_job_batch_for_iteration( if (cur_job_count < batch_size) and (processed < len(files)): new_job_count = min(batch_size - cur_job_count, len(files) - processed) - batch_files = files[processed: processed + new_job_count] # noqa: E203 + batch_files = files[processed : processed + new_job_count] # noqa: E203 - new_job_indices = create_job_specs_for_batch(batch_files, tasks, client) + new_job_indices = client.create_jobs_for_batch(batch_files, tasks) if len(new_job_indices) != new_job_count: missing_jobs = new_job_count - len(new_job_indices) error_msg = f"Missing {missing_jobs} job specs -- this is likely due to bad reads or file corruption" @@ -618,93 +503,14 @@ def generate_job_batch_for_iteration( return job_indices, job_index_map_updates, processed -def handle_future_result( - future: concurrent.futures.Future, - futures_dict: Dict[concurrent.futures.Future, str], -) -> Dict[str, Any]: - """ - Handle the result of a completed future job, process annotations, and save the result. - - This function processes the result of a future, extracts annotations (if any), logs them, - checks the validity of the ingest result, and optionally saves the result to the provided - output directory. If the result indicates a failure, a retry list of job IDs is prepared. - - Parameters - ---------- - future : concurrent.futures.Future - A future object representing an asynchronous job. The result of this job will be - processed once it completes. - - futures_dict : Dict[concurrent.futures.Future, str] - A dictionary mapping future objects to job IDs. The job ID associated with the - provided future is retrieved from this dictionary. - - Returns - ------- - Dict[str, Any] - - Raises - ------ - RuntimeError - If the job result is invalid, this exception is raised with a description of the failure. - - Notes - ----- - - The `future.result()` is assumed to return a tuple where the first element is the actual - result (as a dictionary), and the second element (if present) can be ignored. 
- - Annotations in the result (if any) are logged for debugging purposes. - - The `check_ingest_result` function (assumed to be defined elsewhere) is used to validate - the result. If the result is invalid, a `RuntimeError` is raised. - - The function handles saving the result data to the specified output directory using the - `save_response_data` function. - - Examples - -------- - Suppose we have a future object representing a job, a dictionary of futures to job IDs, - and a directory for saving results: - - >>> future = concurrent.futures.Future() - >>> futures_dict = {future: "job_12345"} - >>> job_id_map = {"job_12345": {...}} - >>> output_directory = "/path/to/save" - >>> result, retry_job_ids = handle_future_result(future, futures_dict, job_id_map, output_directory) - - In this example, the function processes the completed job and saves the result to the - specified directory. If the job fails, it raises a `RuntimeError` and returns a list of - retry job IDs. - - See Also - -------- - check_ingest_result : Function to validate the result of the job. - save_response_data : Function to save the result to a directory. - """ - - try: - result, _ = future.result()[0] - if ("annotations" in result) and result["annotations"]: - annotations = result["annotations"] - for key, value in annotations.items(): - logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}") - - failed, description = check_ingest_result(result) - - if failed: - raise RuntimeError(f"{description}") - except Exception as e: - logger.debug(f"Error processing future result: {e}") - raise e - - return result - - def create_and_process_jobs( - files: List[str], - client: NvIngestClient, - tasks: Dict[str, Any], - output_directory: str, - batch_size: int, - timeout: int = 10, - fail_on_error: bool = False, + files: List[str], + client: NvIngestClient, + tasks: Dict[str, Any], + output_directory: str, + batch_size: int, + timeout: int = 10, + fail_on_error: bool = False, ) -> Tuple[int, Dict[str, List[float]], int]: """ Process a list of files, creating and submitting jobs for each file, then fetch and handle the results. 
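With `create_job_specs_for_batch` and `handle_future_result` relocated out of the CLI module, the per-batch call path collapses into methods on the client. A minimal sketch of the consolidated flow, assuming a client wired to a reachable REST endpoint and a placeholder queue name:

```python
from concurrent.futures import as_completed

from nv_ingest_client.client import NvIngestClient
from nv_ingest_client.util.processing import handle_future_result

client = NvIngestClient()  # endpoint/port arguments omitted for brevity

# Job-spec creation and task attachment now live on the client:
tasks = {}  # e.g. {"split": SplitTask(...), "extract_pdf": ExtractTask(...)}
job_ids = client.create_jobs_for_batch(["doc1.pdf", "doc2.txt"], tasks)

client.submit_job_async(job_ids, "ingest_task_queue")  # queue name is a placeholder

futures_dict = client.fetch_job_result_async(job_ids, timeout=10, data_only=False)
for future in as_completed(futures_dict.keys()):
    result = handle_future_result(future, futures_dict)  # raises RuntimeError on failed jobs
```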
@@ -805,7 +611,6 @@ def create_and_process_jobs( futures_dict = client.fetch_job_result_async(job_ids, timeout=timeout, data_only=False) for future in as_completed(futures_dict.keys()): - retry = False job_id = futures_dict[future] source_name = job_id_map[job_id] diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py index 868d147b..b23cb473 100644 --- a/client/src/nv_ingest_client/client/client.py +++ b/client/src/nv_ingest_client/client/client.py @@ -5,12 +5,15 @@ # pylint: disable=broad-except +import concurrent.futures import json import logging -import concurrent.futures +import math +import time from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -19,12 +22,16 @@ from typing import Union from nv_ingest_client.message_clients.rest.rest_client import RestClient +from nv_ingest_client.primitives import BatchJobSpec from nv_ingest_client.primitives import JobSpec from nv_ingest_client.primitives.jobs import JobState from nv_ingest_client.primitives.jobs import JobStateEnum from nv_ingest_client.primitives.tasks import Task from nv_ingest_client.primitives.tasks import TaskType +from nv_ingest_client.primitives.tasks import is_valid_task_type from nv_ingest_client.primitives.tasks import task_factory +from nv_ingest_client.util.processing import handle_future_result +from nv_ingest_client.util.util import create_job_specs_for_batch logger = logging.getLogger(__name__) @@ -148,14 +155,14 @@ def _get_and_check_job_state( ) -> JobState: if required_state and not isinstance(required_state, list): required_state = [required_state] - job_state = self._job_states.get(job_index) if not job_state: raise ValueError(f"Job with ID {job_index} does not exist in JobStates: {self._job_states}") if required_state and (job_state.state not in required_state): raise ValueError( - f"Job with ID {job_state.job_spec.job_id} has invalid state {job_state.state}, expected {required_state}" + f"Job with ID {job_state.job_spec.job_id} has invalid state " + f"{job_state.state}, expected {required_state}" ) return job_state @@ -163,13 +170,27 @@ def _get_and_check_job_state( def job_count(self): return len(self._job_states) - def add_job(self, job_spec: JobSpec = None) -> str: + def _add_single_job(self, job_spec: JobSpec) -> str: job_index = self._generate_job_index() self._job_states[job_index] = JobState(job_spec=job_spec) return job_index + def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> Union[str, List[str]]: + if isinstance(job_spec, JobSpec): + job_index = self._add_single_job(job_spec) + return job_index + elif isinstance(job_spec, BatchJobSpec): + job_indexes = [] + for _, job_specs in job_spec.job_specs.items(): + for job in job_specs: + job_index = self._add_single_job(job) + job_indexes.append(job_index) + return job_indexes + else: + raise ValueError(f"Unexpected type: {type(job_spec)}") + def create_job( self, payload: str, @@ -311,8 +332,7 @@ def fetch_job_result_cli(self, job_ids: Union[str, List[str]], timeout: float = job_ids = [job_ids] return [self._fetch_job_result(job_id, timeout, data_only) for job_id in job_ids] - - + # Nv-Ingest jobs are often "long running". Therefore after # submission we intermittently check if the job is completed.
def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bool = True): @@ -321,21 +341,82 @@ def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bo while True: try: return [self._fetch_job_result(job_id, timeout, data_only)] except TimeoutError: logger.debug("Job still processing ... aka HTTP 202 received") - + # This is the direct Python approach function for retrieving jobs which handles the timeouts directly # in the function itself instead of expecting the user to handle it themselves - # Note this method only supports fetching a single job result synchronously - def fetch_job_result(self, job_id: str, timeout: float = 100, data_only: bool = True): - # A thread pool executor is a simple approach to performing an action with a timeout + def fetch_job_result( + self, + job_ids: List[str], + timeout: float = 100, + max_retries: Optional[int] = None, + retry_delay: float = 1, + verbose: bool = False, + ) -> List[Any]: + """ + Fetches job results for multiple job IDs concurrently with individual timeouts and retry logic. + + Args: + job_ids (List[str]): A list of job IDs to fetch results for. + timeout (float): Timeout for each fetch operation, in seconds. + max_retries (Optional[int]): Maximum number of retries for jobs that are not ready yet; None retries indefinitely. + retry_delay (float): Delay between retry attempts, in seconds. + verbose (bool): If True, log each retry attempt for jobs that are not ready yet. + + Returns: + List[Any]: A list of response payloads (the "data" field) for jobs that completed successfully. + Jobs that time out or fail are logged and omitted from the returned list. + + Note: + Exceptions raised while fetching (timeouts, decode errors, runtime errors) are logged + and the affected job is skipped; they are not re-raised to the caller. + """ + results = [] + + def fetch_with_retries(job_id: str): + retries = 0 + while (max_retries is None) or (retries < max_retries): + try: + # Attempt to fetch the job result + result = self._fetch_job_result(job_id, timeout, data_only=False) + return result, job_id + except Exception as e: + # Check if the error is a retryable error + if "Job is not ready yet. Retry later." in str(e): + if verbose: + logger.info( + f"Job {job_id} is not ready. " + f"Retrying {retries + 1}/{max_retries if max_retries else '∞'} " + f"after {retry_delay} seconds."
+ ) + retries += 1 + time.sleep(retry_delay) # Wait before retrying + else: + # For any other error, log and break out of the retry loop + logger.error(f"Error while fetching result for job ID {job_id}: {e}") + return None, job_id + logger.error(f"Max retries exceeded for job {job_id}.") + return None, job_id + + # Use ThreadPoolExecutor to fetch results concurrently with ThreadPoolExecutor() as executor: - future = executor.submit(self._fetch_job_result_wait, job_id, timeout, data_only) - try: - # Wait for the result within the specified timeout - return future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - # Raise a standard Python TimeoutError which the client will be expecting - raise TimeoutError(f"Job processing did not complete within the specified {timeout} seconds") - + futures = {executor.submit(fetch_with_retries, job_id): job_id for job_id in job_ids} + + # Collect results as futures complete + for future in as_completed(futures): + job_id = futures[future] + try: + result = handle_future_result(future, futures, timeout) + results.append(result.get("data")) + except concurrent.futures.TimeoutError: + logger.error(f"Timeout while fetching result for job ID {job_id}") + except json.JSONDecodeError as e: + logger.error(f"Decoding error while processing job ID {job_id}: {e}") + except RuntimeError as e: + logger.error(f"Error while processing job ID {job_id}: {e}") + except Exception as e: + logger.error(f"Error while fetching result for job ID {job_id}: {e}") + + return results def _ensure_submitted(self, job_ids: List[str]): if isinstance(job_ids, str): @@ -425,18 +506,43 @@ def _submit_job( # Free up memory -- payload should never be used again, and we don't want to keep it around. job_state.job_spec.payload = None - + return x_trace_id except Exception as err: logger.error(f"Failed to submit job {job_index} to queue {job_queue_id}: {err}") job_state.state = JobStateEnum.FAILED raise - def submit_job(self, job_indices: Union[str, List[str]], job_queue_id: str) -> List[Union[Dict, None]]: + def submit_job( + self, job_indices: Union[str, List[str]], job_queue_id: str, batch_size: int = 10 + ) -> List[Union[Dict, None]]: if isinstance(job_indices, str): job_indices = [job_indices] - return [self._submit_job(job_id, job_queue_id) for job_id in job_indices] + results = [] + total_batches = math.ceil(len(job_indices) / batch_size) + + submission_errors = [] + for batch_num in range(total_batches): + batch_start = batch_num * batch_size + batch_end = batch_start + batch_size + batch = job_indices[batch_start:batch_end] + + # Submit each batch of jobs + for job_id in batch: + try: + x_trace_id = self._submit_job(job_id, job_queue_id) + except Exception as e: # Even if one fails, we should continue with the rest of the batch. + submission_errors.append(e) + continue + results.append(x_trace_id) + + if submission_errors: + error_msg = str(submission_errors[0]) + if len(submission_errors) > 1: + error_msg += f"... 
[{len(submission_errors) - 1} more messages truncated]" + raise type(submission_errors[0])(error_msg) + return results def submit_job_async(self, job_indices: Union[str, List[str]], job_queue_id: str) -> Dict[Future, str]: """ @@ -475,3 +581,95 @@ def submit_job_async(self, job_indices: Union[str, List[str]], job_queue_id: str future_to_job_index[future] = job_index return future_to_job_index + + def create_jobs_for_batch(self, files_batch: List[str], tasks: Dict[str, Any]) -> List[str]: + """ + Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs. + This function takes a batch of files, processes each file to extract its content and type, + creates a job specification (JobSpec) for each file, and adds tasks from the provided task + list. It then submits the jobs to the client and collects their job IDs. + + Parameters + ---------- + files_batch : List[str] + A list of file paths to be processed. Each file is assumed to be in a format compatible + with the `extract_file_content` function, which extracts the file's content and type. + tasks : Dict[str, Any] + A dictionary of tasks to be added to each job. The keys represent task names, and the + values represent task specifications or configurations. Standard tasks include "split", + "extract", "store", "caption", "dedup", "filter", "embed", and "vdb_upload". + + Returns + ------- + List[str] + A list of job IDs corresponding to the submitted jobs. Each job ID is returned by the + client's `add_job` method. + + Raises + ------ + ValueError + If there is an error extracting the file content or type from any of the files, a + ValueError will be logged, and the corresponding file will be skipped. + + Notes + ----- + - The function assumes that a utility function `extract_file_content` is defined elsewhere, + which extracts the content and type from the provided file paths. + - For each file, a `JobSpec` is created with relevant metadata, including document type and + file content. Various tasks are conditionally added based on the provided `tasks` dictionary. + - The job specification includes tracing options with a timestamp (in nanoseconds) for + diagnostic purposes. + + Examples + -------- + Suppose you have a batch of files and tasks to process: + >>> files_batch = ["file1.txt", "file2.pdf"] + >>> tasks = {"split": ..., "extract_txt": ..., "store": ...} + >>> client = NvIngestClient() + >>> job_ids = client.create_jobs_for_batch(files_batch, tasks) + >>> print(job_ids) + ['job_12345', 'job_67890'] + + In this example, jobs are created and submitted for the files in `files_batch`, with the + tasks in `tasks` being added to each job specification. The returned job IDs are then + printed. + + See Also + -------- + create_job_specs_for_batch: Function that creates job specifications for a batch of files. + JobSpec : The class representing a job specification. + """ + if not isinstance(tasks, dict): + raise ValueError("`tasks` must be a dictionary of task names -> task specifications.") + + job_specs = create_job_specs_for_batch(files_batch) + + job_ids = [] + for job_spec in job_specs: + logger.debug(f"Tasks: {tasks.keys()}") + for task in tasks: + logger.debug(f"Task: {task}") + + file_type = job_spec.document_type + + seen_tasks = set() # For tracking tasks and rejecting duplicate tasks.
+ + for task_name, task_config in tasks.items(): + if task_name.lower().startswith("extract_"): + task_file_type = task_name.split("_", 1)[1] + if file_type.lower() != task_file_type.lower(): + continue + elif not is_valid_task_type(task_name.upper()): + raise ValueError(f"Invalid task type: '{task_name}'") + + if str(task_config) in seen_tasks: + raise ValueError(f"Duplicate task detected: {task_name} with config {task_config}") + + job_spec.add_task(task_config) + + seen_tasks.add(str(task_config)) + + job_id = self.add_job(job_spec) + job_ids.append(job_id) + + return job_ids diff --git a/client/src/nv_ingest_client/nv_ingest_cli.py b/client/src/nv_ingest_client/nv_ingest_cli.py index 14679a23..11986260 100644 --- a/client/src/nv_ingest_client/nv_ingest_cli.py +++ b/client/src/nv_ingest_client/nv_ingest_cli.py @@ -11,20 +11,19 @@ import click import pkg_resources -from nv_ingest_client.cli.util.click import ClientType from nv_ingest_client.cli.util.click import LogLevel from nv_ingest_client.cli.util.click import click_match_and_validate_files from nv_ingest_client.cli.util.click import click_validate_batch_size from nv_ingest_client.cli.util.click import click_validate_file_exists from nv_ingest_client.cli.util.click import click_validate_task -from nv_ingest_client.cli.util.dataset import get_dataset_files -from nv_ingest_client.cli.util.dataset import get_dataset_statistics from nv_ingest_client.cli.util.processing import create_and_process_jobs from nv_ingest_client.cli.util.processing import report_statistics from nv_ingest_client.cli.util.system import configure_logging from nv_ingest_client.cli.util.system import ensure_directory_with_permissions from nv_ingest_client.client import NvIngestClient from nv_ingest_client.message_clients.rest.rest_client import RestClient +from nv_ingest_client.util.dataset import get_dataset_files +from nv_ingest_client.util.dataset import get_dataset_statistics from pkg_resources import DistributionNotFound from pkg_resources import VersionConflict @@ -131,6 +130,10 @@ - extract_images (bool): Enables image extraction. Default: False. - extract_tables (bool): Enables table extraction. Default: False. - extract_charts (bool): Enables chart extraction. Default: False. + - text_depth (str): Text extraction granularity ('document', 'page'). Default: 'document'. + Note: this affects the granularity of text extraction and the associated metadata. I.e., 'page' extracts + text per page and yields page-level metadata, while 'document' extracts text for the entire document, so + elements like page numbers are not associated with individual text elements. \b - store: Stores any images extracted from documents. Options: diff --git a/client/src/nv_ingest_client/primitives/__init__.py b/client/src/nv_ingest_client/primitives/__init__.py index b4834304..e081645f 100644 --- a/client/src/nv_ingest_client/primitives/__init__.py +++ b/client/src/nv_ingest_client/primitives/__init__.py @@ -2,7 +2,8 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from .jobs import BatchJobSpec from .jobs import JobSpec from .tasks import Task -__all__ = ["JobSpec", "Task"] +__all__ = ["BatchJobSpec", "JobSpec", "Task"] diff --git a/client/src/nv_ingest_client/primitives/jobs/__init__.py b/client/src/nv_ingest_client/primitives/jobs/__init__.py index 7d8b481a..ecd714f9 100644 --- a/client/src/nv_ingest_client/primitives/jobs/__init__.py +++ b/client/src/nv_ingest_client/primitives/jobs/__init__.py @@ -2,8 +2,9 @@ # All rights reserved.
# SPDX-License-Identifier: Apache-2.0 -from .job_spec import JobSpec -from .job_state import JobState -from .job_state import JobStateEnum +from nv_ingest_client.primitives.jobs.job_spec import BatchJobSpec +from nv_ingest_client.primitives.jobs.job_spec import JobSpec +from nv_ingest_client.primitives.jobs.job_state import JobState +from nv_ingest_client.primitives.jobs.job_state import JobStateEnum -__all__ = ["JobSpec", "JobState", "JobStateEnum"] +__all__ = ["BatchJobSpec", "JobSpec", "JobState", "JobStateEnum"] diff --git a/client/src/nv_ingest_client/primitives/jobs/job_spec.py b/client/src/nv_ingest_client/primitives/jobs/job_spec.py index 8693ecd1..aba0c759 100644 --- a/client/src/nv_ingest_client/primitives/jobs/job_spec.py +++ b/client/src/nv_ingest_client/primitives/jobs/job_spec.py @@ -4,12 +4,17 @@ import logging +from collections import defaultdict +from io import BytesIO from typing import Dict from typing import List from typing import Optional +from typing import Union from uuid import UUID from nv_ingest_client.primitives.tasks import Task +from nv_ingest_client.util.dataset import get_dataset_files +from nv_ingest_client.util.dataset import get_dataset_statistics logger = logging.getLogger(__name__) @@ -53,13 +58,13 @@ class JobSpec: """ def __init__( - self, - payload: str = None, - tasks: Optional[List] = None, - source_id: Optional[str] = None, - source_name: Optional[str] = None, - document_type: Optional[str] = None, - extended_options: Optional[Dict] = None, + self, + payload: str = None, + tasks: Optional[List] = None, + source_id: Optional[str] = None, + source_name: Optional[str] = None, + document_type: Optional[str] = None, + extended_options: Optional[Dict] = None, ) -> None: self._document_type = document_type or "txt" self._extended_options = extended_options or {} @@ -134,6 +139,10 @@ def source_name(self) -> str: def source_name(self, source_name: str) -> None: self._source_name = source_name + @property + def document_type(self) -> str: + return self._document_type + def add_task(self, task) -> None: """ Adds a task to the job specification. @@ -152,3 +161,210 @@ def add_task(self, task) -> None: raise ValueError("Task must derive from nv_ingest_client.primitives.Task class") self._tasks.append(task) + + +class BatchJobSpec: + """ + A class used to represent a batch of job specifications (JobSpecs). + + This class allows for batch processing of multiple jobs, either from a list of + `JobSpec` instances or from file paths. It provides methods for adding job + specifications, associating tasks with those specifications, and serializing the + batch to a dictionary format. + + Attributes + ---------- + _file_type_to_job_spec : defaultdict + A dictionary that maps document types to a list of `JobSpec` instances. + """ + + def __init__(self, job_specs_or_files: Optional[Union[List[JobSpec], List[str]]] = None) -> None: + """ + Initializes the BatchJobSpec instance. + + Parameters + ---------- + job_specs_or_files : Optional[Union[List[JobSpec], List[str]]], optional + A list of either `JobSpec` instances or file paths. If provided, the + instance will be initialized with the corresponding job specifications. + """ + self._file_type_to_job_spec = defaultdict(list) + + if job_specs_or_files: + if isinstance(job_specs_or_files[0], JobSpec): + self.from_job_specs(job_specs_or_files) + elif isinstance(job_specs_or_files[0], str): + self.from_files(job_specs_or_files) + else: + raise ValueError("Invalid input type for job_specs. 
Must be a list of JobSpec or file paths.") + + def from_job_specs(self, job_specs: Union[JobSpec, List[JobSpec]]) -> None: + """ + Initializes the batch with a list of `JobSpec` instances. + + Parameters + ---------- + job_specs : Union[JobSpec, List[JobSpec]] + A single `JobSpec` or a list of `JobSpec` instances to add to the batch. + """ + if isinstance(job_specs, JobSpec): + job_specs = [job_specs] + + for job_spec in job_specs: + self.add_job_spec(job_spec) + + def from_files(self, files: Union[str, List[str]]) -> None: + """ + Initializes the batch by generating job specifications from file paths. + + Parameters + ---------- + files : Union[str, List[str]] + A single file path or a list of file paths to create job specifications from. + """ + from nv_ingest_client.util.util import create_job_specs_for_batch + from nv_ingest_client.util.util import generate_matching_files + + if isinstance(files, str): + files = [files] + + matching_files = list(generate_matching_files(files)) + if not matching_files: + logger.warning(f"No files found matching {files}.") + return + + job_specs = create_job_specs_for_batch(matching_files) + for job_spec in job_specs: + self.add_job_spec(job_spec) + + def _from_dataset(self, dataset: str, shuffle_dataset: bool = True) -> None: + """ + Internal method to initialize the batch from a dataset. + + Parameters + ---------- + dataset : str + The path to the dataset file. + shuffle_dataset : bool, optional + Whether to shuffle the dataset files before adding them to the batch, by default True. + """ + with open(dataset, "rb") as file: + dataset_bytes = BytesIO(file.read()) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug(get_dataset_statistics(dataset_bytes)) + + dataset_files = get_dataset_files(dataset_bytes, shuffle_dataset) + + self.from_files(dataset_files) + + @classmethod + def from_dataset(cls, dataset: str, shuffle_dataset: bool = True): + """ + Class method to create a `BatchJobSpec` instance from a dataset. + + Parameters + ---------- + dataset : str + The path to the dataset file. + shuffle_dataset : bool, optional + Whether to shuffle the dataset files before adding them to the batch, by default True. + + Returns + ------- + BatchJobSpec + A new instance of `BatchJobSpec` initialized with the dataset files. + """ + batch_job_spec = cls() + batch_job_spec._from_dataset(dataset, shuffle_dataset=shuffle_dataset) + return batch_job_spec + + def add_job_spec(self, job_spec: JobSpec) -> None: + """ + Adds a `JobSpec` to the batch. + + Parameters + ---------- + job_spec : JobSpec + The job specification to add. + """ + self._file_type_to_job_spec[job_spec.document_type].append(job_spec) + + def add_task(self, task, document_type=None): + """ + Adds a task to the relevant job specifications in the batch. + + If a `document_type` is provided, the task will be added to all job specifications + matching that document type. If no `document_type` is provided, the task will be added + to all job specifications in the batch. + + Parameters + ---------- + task : Task + The task to add. Must derive from the `nv_ingest_client.primitives.Task` class. + + document_type : str, optional + The document type used to filter job specifications. If not provided, the + `document_type` is inferred from the task, or the task is applied to all job specifications. + + Raises + ------ + ValueError + If the task does not derive from the `Task` class. 
+ """ + if not isinstance(task, Task): + raise ValueError("Task must derive from nv_ingest_client.primitives.Task class") + + document_type = document_type or task.to_dict().get("document_type") + + if document_type: + target_job_specs = self._file_type_to_job_spec[document_type] + else: + target_job_specs = [] + for job_specs in self._file_type_to_job_spec.values(): + target_job_specs.extend(job_specs) + + for job_spec in target_job_specs: + job_spec.add_task(task) + + def to_dict(self) -> Dict[str, List[Dict]]: + """ + Serializes the batch of job specifications into a dictionary keyed by document type. + + Returns + ------- + Dict[str, List[Dict]] + A dictionary mapping document types to lists of serialized job specifications. + """ + return { + file_type: [j.to_dict() for j in job_specs] for file_type, job_specs in self._file_type_to_job_spec.items() + } + + def __str__(self) -> str: + """ + Returns a string representation of the batch. + + Returns + ------- + str + A string representation of the job specifications in the batch. + """ + result = "" + for file_type, job_specs in self._file_type_to_job_spec.items(): + result += f"{file_type}\n" + for job_spec in job_specs: + result += str(job_spec) + "\n" + + return result + + @property + def job_specs(self) -> Dict[str, List[JobSpec]]: + """ + A property that returns a dictionary of job specs categorized by document type. + + Returns + ------- + Dict[str, List[JobSpec]] + A dictionary mapping document types to job specifications. + """ + return self._file_type_to_job_spec diff --git a/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py b/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py new file mode 100644 index 00000000..37c7612a --- /dev/null +++ b/client/src/nv_ingest_client/primitives/tasks/chart_extraction.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + + +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-arguments + +import logging +from typing import Dict + +from pydantic import BaseModel + +from .task_base import Task + +logger = logging.getLogger(__name__) + + +class ChartExtractionSchema(BaseModel): + class Config: + extra = "forbid" + + +class ChartExtractionTask(Task): + """ + Object for chart extraction task + """ + + def __init__(self) -> None: + """ + Setup Chart Extraction Task Config + """ + super().__init__() + + def __str__(self) -> str: + """ + Returns a string with the object's config and run time state + """ + info = "" + info += "chart extraction task\n" + return info + + def to_dict(self) -> Dict: + """ + Convert to a dict for submission to redis + """ + + task_properties = { + "params": {}, + } + + return {"type": "chart_data_extract", "task_properties": task_properties} diff --git a/client/src/nv_ingest_client/primitives/tasks/extract.py b/client/src/nv_ingest_client/primitives/tasks/extract.py index 6c2205f1..36f43d04 100644 --- a/client/src/nv_ingest_client/primitives/tasks/extract.py +++ b/client/src/nv_ingest_client/primitives/tasks/extract.py @@ -246,3 +246,7 @@ def to_dict(self) -> Dict: } task_properties["params"].update(adobe_properties) return {"type": "extract", "task_properties": task_properties} + + @property + def document_type(self): + return self._document_type diff --git a/client/src/nv_ingest_client/primitives/tasks/table_extraction.py b/client/src/nv_ingest_client/primitives/tasks/table_extraction.py new file mode 100644 index 00000000..5d1d7299 --- /dev/null +++ b/client/src/nv_ingest_client/primitives/tasks/table_extraction.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + + +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-arguments + +import logging +from typing import Dict + +from pydantic import BaseModel + +from .task_base import Task + +logger = logging.getLogger(__name__) + + +class TableExtractionSchema(BaseModel): + class Config: + extra = "forbid" + + +class TableExtractionTask(Task): + """ + Object for table extraction tasks + """ + + def __init__(self) -> None: + """ + Setup Table Extraction Task Config + """ + super().__init__() + + def __str__(self) -> str: + """ + Returns a string with the object's config and run time state + """ + info = "" + info += "table extraction task\n" + return info + + def to_dict(self) -> Dict: + """ + Convert to a dict for submission to redis + """ + + task_properties = { + "params": {}, + } + + return {"type": "table_data_extract", "task_properties": task_properties} diff --git a/client/src/nv_ingest_client/primitives/tasks/task_base.py b/client/src/nv_ingest_client/primitives/tasks/task_base.py index 47cee997..5d4a65cd 100644 --- a/client/src/nv_ingest_client/primitives/tasks/task_base.py +++ b/client/src/nv_ingest_client/primitives/tasks/task_base.py @@ -16,6 +16,7 @@ class TaskType(Enum): CAPTION = auto() + DEDUP = auto() EMBED = auto() EXTRACT = auto() FILTER = auto() @@ -23,6 +24,8 @@ class TaskType(Enum): TRANSFORM = auto() STORE = auto() VDB_UPLOAD = auto() + TABLE_DATA_EXTRACT = auto() + CHART_DATA_EXTRACT = auto() def is_valid_task_type(task_type_str: str) -> bool: @@ -68,7 +71,6 @@ def to_dict(self) -> Dict: return {} - # class ExtractUnstructuredTask(ExtractTask): # """ # Object for document unstructured extraction task diff --git a/client/src/nv_ingest_client/primitives/tasks/task_factory.py b/client/src/nv_ingest_client/primitives/tasks/task_factory.py index faa91e3c..8a4eb64e 100644 --- a/client/src/nv_ingest_client/primitives/tasks/task_factory.py +++ b/client/src/nv_ingest_client/primitives/tasks/task_factory.py @@ -9,6 +9,7 @@ from typing import Union from .caption import CaptionTask +from .dedup import DedupTask from .embed import EmbedTask from .extract import ExtractTask from .filter import FilterTask @@ -33,6 +34,7 @@ def __init__(self, **kwargs) -> None: # Mapping of TaskType to Task classes, arranged alphabetically by task type _TASK_MAP: Dict[TaskType, Callable] = { TaskType.CAPTION: CaptionTask, + TaskType.DEDUP: DedupTask, TaskType.EMBED: EmbedTask, TaskType.EXTRACT: ExtractTask, TaskType.FILTER: FilterTask, diff --git a/client/src/nv_ingest_client/cli/util/dataset.py b/client/src/nv_ingest_client/util/dataset.py similarity index 100% rename from client/src/nv_ingest_client/cli/util/dataset.py rename to client/src/nv_ingest_client/util/dataset.py diff --git a/client/src/nv_ingest_client/util/processing.py b/client/src/nv_ingest_client/util/processing.py new file mode 100644 index 00000000..58e945a8 --- /dev/null +++ b/client/src/nv_ingest_client/util/processing.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +import concurrent +import json +import logging +from typing import Any +from typing import Dict +from typing import Optional + +from nv_ingest_client.util.util import check_ingest_result + +logger = logging.getLogger(__name__) + + +def handle_future_result( + future: concurrent.futures.Future, + futures_dict: Dict[concurrent.futures.Future, str], + timeout: Optional[int] = None, +) -> Dict[str, Any]: + """ + Handle the result of a completed future job and process its annotations. + + This function processes the result of a future, extracts annotations (if any), logs them, + and checks the validity of the ingest result. If the result indicates a failure, a + `RuntimeError` is raised with a description of the failure. + + Parameters + ---------- + future : concurrent.futures.Future + A future object representing an asynchronous job. The result of this job will be + processed once it completes. + + futures_dict : Dict[concurrent.futures.Future, str] + A dictionary mapping future objects to job IDs. The job ID associated with the + provided future is retrieved from this dictionary. + + timeout : Optional[int] + Maximum time, in seconds, to wait for the future's result before timing out. + + Returns + ------- + Dict[str, Any] + The result dictionary of the completed job. + + Raises + ------ + RuntimeError + If the job result is invalid, this exception is raised with a description of the failure. + + Notes + ----- + - The `future.result()` is assumed to return a tuple where the first element is the actual + result (as a dictionary), and the second element (if present) can be ignored. + - Annotations in the result (if any) are logged for debugging purposes. + - The `check_ingest_result` function (assumed to be defined elsewhere) is used to validate + the result. If the result is invalid, a `RuntimeError` is raised. + + Examples + -------- + Suppose we have a future object representing a job and a dictionary of futures to job IDs: + + >>> future = concurrent.futures.Future() + >>> futures_dict = {future: "job_12345"} + >>> result = handle_future_result(future, futures_dict) + + In this example, the function processes the completed job and returns its result. If the + job failed, a `RuntimeError` is raised instead. + + See Also + -------- + check_ingest_result : Function to validate the result of the job. + """ + + try: + result, _ = future.result(timeout=timeout)[0] + if ("annotations" in result) and result["annotations"]: + annotations = result["annotations"] + for key, value in annotations.items(): + logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}") + + failed, description = check_ingest_result(result) + + if failed: + raise RuntimeError(f"{description}") + except Exception as e: + logger.debug(f"Error processing future result: {e}") + raise e + + return result diff --git a/client/src/nv_ingest_client/util/util.py b/client/src/nv_ingest_client/util/util.py index 61ebc5e3..e5832ef1 100644 --- a/client/src/nv_ingest_client/util/util.py +++ b/client/src/nv_ingest_client/util/util.py @@ -2,22 +2,30 @@ # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-
+import glob
 import logging
 import os
+import time
 import traceback
 import typing
 from io import BytesIO
 from typing import Dict
+from typing import List
 
 import pypdfium2 as pdfium
 from docx import Document as DocxDocument
+from nv_ingest_client.primitives.jobs.job_spec import JobSpec
 from nv_ingest_client.util.file_processing.extract import DocumentTypeEnum
 from nv_ingest_client.util.file_processing.extract import detect_encoding_and_read_text_file
 from nv_ingest_client.util.file_processing.extract import extract_file_content
 from nv_ingest_client.util.file_processing.extract import get_or_infer_file_type
 from pptx import Presentation
 
+
+logger = logging.getLogger(__name__)
+
+
 # pylint: disable=invalid-name
 # pylint: disable=missing-class-docstring
 # pylint: disable=logging-fstring-interpolation
@@ -246,17 +254,116 @@ def check_ingest_result(json_payload: Dict) -> typing.Tuple[bool, str]:
 
     logger.debug(
         f"Checking ingest result:\n Status: {json_payload.get('status', None)}"
-        f"\n Description: {json_payload.get('description', None)}")
+        f"\n Description: {json_payload.get('description', None)}"
+    )
 
     is_failed = json_payload.get("status", "") in "failed"
     description = json_payload.get("description", "")
 
     # Look to see if we have any failure annotations to augment the description
-    if (is_failed and 'annotations' in json_payload):
-        for annot_id, value in json_payload['annotations'].items():
-            if ('task_result' in value and value['task_result'] == "FAILURE"):
-                message = value.get('message', "Unknown")
+    if is_failed and "annotations" in json_payload:
+        for annot_id, value in json_payload["annotations"].items():
+            if "task_result" in value and value["task_result"] == "FAILURE":
+                message = value.get("message", "Unknown")
                 description = f"\n↪ Event that caused this failure: {annot_id} -> {message}"
                 break
 
     return is_failed, description
+
+
+def generate_matching_files(file_sources):
+    """
+    Generates a list of file paths that match the given patterns specified in file_sources.
+
+    Parameters
+    ----------
+    file_sources : list of str
+        A list containing the file source patterns to match against.
+
+    Returns
+    -------
+    generator
+        A generator yielding paths to files that match the specified patterns.
+
+    Notes
+    -----
+    This function utilizes glob pattern matching to find files that match the specified patterns.
+    It yields each matching file path, allowing for efficient processing of potentially large
+    sets of files.
+    """
+    files = [
+        file_path
+        for pattern in file_sources
+        for file_path in glob.glob(pattern, recursive=True)
+        if os.path.isfile(file_path)
+    ]
+    for file_path in files:
+        yield file_path
+
+
+def create_job_specs_for_batch(files_batch: List[str]) -> List[JobSpec]:
+    """
+    Create job specifications (JobSpecs) for a batch of files.
+
+    This function takes a batch of files, processes each file to extract its content and type,
+    and creates a job specification (JobSpec) for each file.
+
+    Parameters
+    ----------
+    files_batch : List[str]
+        A list of file paths to be processed. Each file is assumed to be in a format compatible
+        with the `extract_file_content` function, which extracts the file's content and type.
+
+    Returns
+    -------
+    List[JobSpec]
+        A list of JobSpecs.
+
+    Raises
+    ------
+    ValueError
+        If there is an error extracting the file content or type from any of the files, a
+        ValueError will be logged, and the corresponding file will be skipped.
+
+    Notes
+    -----
+    - The function assumes that a utility function `extract_file_content` is defined elsewhere,
+      which extracts the content and type from the provided file paths.
+    - For each file, a `JobSpec` is created with relevant metadata, including document type and
+      file content.
+    - The job specification includes tracing options with a timestamp (in nanoseconds) for
+      diagnostic purposes.
+
+    Examples
+    --------
+    Suppose you have a batch of files to process:
+
+    >>> files_batch = ["file1.txt", "file2.pdf"]
+    >>> job_specs = create_job_specs_for_batch(files_batch)
+    >>> print(job_specs)
+    [<nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, ...]
+
+    See Also
+    --------
+    extract_file_content : Function that extracts the content and type of a file.
+    JobSpec : The class representing a job specification.
+    """
+    job_specs = []
+    for file_name in files_batch:
+        try:
+            file_content, file_type = extract_file_content(file_name)  # imported above
+            file_type = file_type.value
+        except ValueError as ve:
+            logger.error(f"Error extracting content from {file_name}: {ve}")
+            continue
+
+        job_spec = JobSpec(
+            document_type=file_type,
+            payload=file_content,
+            source_id=file_name,
+            source_name=file_name,
+            extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}},
+        )
+        job_specs.append(job_spec)
+
+    return job_specs
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 620b1f33..1b8943d3 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -68,7 +68,28 @@ services:
       reservations:
         devices:
           - driver: nvidia
-            count: all
+            device_ids: ["1"]
+            capabilities: [gpu]
+    runtime: nvidia
+
+  deplot:
+    image: nvcr.io/ohlfw0olaadg/ea-participants/deplot:1.0.0
+    ports:
+      - "8003:8000"
+      - "8004:8001"
+      - "8005:8002"
+    user: root
+    environment:
+      - NIM_HTTP_API_PORT=8000
+      - NIM_TRITON_LOG_VERBOSE=1
+      - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
+      - CUDA_VISIBLE_DEVICES=0
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ["0"]
             capabilities: [gpu]
     runtime: nvidia
 
@@ -79,14 +100,11 @@ services:
       - "8006:8000"
       - "8007:8001"
       - "8008:8002"
-    volumes:
-      - ${HOME}/.cache:/home/nvs/.cache
     user: root
    environment:
       - NIM_HTTP_API_PORT=8000
       - NIM_TRITON_LOG_VERBOSE=1
       - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
-      - CUDA_VISIBLE_DEVICES=1
       # NIM OpenTelemetry Settings
       - NIM_OTEL_SERVICE_NAME=cached
       - NIM_OTEL_TRACES_EXPORTER=otlp
@@ -96,12 +114,13 @@ services:
       # Triton OpenTelemetry Settings
       - TRITON_OTEL_URL=http://otel-collector:4318/v1/traces
       - TRITON_OTEL_RATE=1
+      - CUDA_VISIBLE_DEVICES=0
     deploy:
       resources:
         reservations:
           devices:
             - driver: nvidia
-              count: all
+              device_ids: ["1"]
               capabilities: [gpu]
     runtime: nvidia
 
@@ -112,14 +131,11 @@ services:
       - "8009:8000"
       - "8010:8001"
       - "8011:8002"
-    volumes:
-      - ${HOME}/.cache:/home/nvs/.cache
     user: root
     environment:
       - NIM_HTTP_API_PORT=8000
       - NIM_TRITON_LOG_VERBOSE=1
       - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
-      - CUDA_VISIBLE_DEVICES=1
       # NIM OpenTelemetry Settings
       - NIM_OTEL_SERVICE_NAME=paddle
       - NIM_OTEL_TRACES_EXPORTER=otlp
@@ -134,7 +150,7 @@ services:
       reservations:
         devices:
           - driver: nvidia
-            count: all
+            device_ids: ["1"]
             capabilities: [gpu]
     runtime: nvidia
 
@@ -150,7 +166,6 @@ services:
       - NIM_HTTP_API_PORT=8000
       - NIM_TRITON_LOG_VERBOSE=1
       - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}}
-      - CUDA_VISIBLE_DEVICES=1
       # NIM OpenTelemetry Settings
      - 
NIM_OTEL_SERVICE_NAME=embedding - NIM_OTEL_TRACES_EXPORTER=otlp @@ -165,51 +180,51 @@ services: reservations: devices: - driver: nvidia - count: all + device_ids: ["1"] capabilities: [gpu] runtime: nvidia nv-ingest-ms-runtime: image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.08 build: - context: ${NV_INGEST_ROOT} + context: ${NV_INGEST_ROOT:-.} dockerfile: "./Dockerfile" target: runtime volumes: - - ${DATASET_ROOT}:/workspace/data + - ${DATASET_ROOT:-./data}:/workspace/data ports: - "7670:7670" cap_add: - sys_nice environment: - CACHED_GRPC_ENDPOINT=cached:8001 - - CACHED_HTTP_ENDPOINT="" - - CACHED_HEALTH_ENDPOINT=cached:8000 + - CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer + - CACHED_INFER_PROTOCOL=grpc + - CUDA_VISIBLE_DEVICES=0 - DEPLOT_GRPC_ENDPOINT="" - # build.nvidia.com hosted deplot - #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot # self hosted deplot - - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions - DEPLOT_HEALTH_ENDPOINT=deplot:8000 + - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions + # build.nvidia.com hosted deplot + - DEPLOT_INFER_PROTOCOL=http + #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001 - INGEST_LOG_LEVEL=DEFAULT - MESSAGE_CLIENT_HOST=redis - MESSAGE_CLIENT_PORT=6379 - MINIO_BUCKET=${MINIO_BUCKET:-nv-ingest} + - MRC_IGNORE_NUMA_CHECK=1 - NGC_API_KEY=${NGC_API_KEY:-ngcapikey} - NVIDIA_BUILD_API_KEY=${NVIDIA_BUILD_API_KEY:-${NGC_API_KEY:-ngcapikey}} - OTEL_EXPORTER_OTLP_ENDPOINT=otel-collector:4317 - PADDLE_GRPC_ENDPOINT=paddle:8001 - - PADDLE_HTTP_ENDPOINT="" - - PADDLE_HEALTH_ENDPOINT=paddle:8000 + - PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer + - PADDLE_INFER_PROTOCOL=grpc + - READY_CHECK_ALL_COMPONENTS=True - REDIS_MORPHEUS_TASK_QUEUE=morpheus_task_queue - - TABLE_DETECTION_GRPC_TRITON=yolox:8001 - - TABLE_DETECTION_HTTP_TRITON="" - YOLOX_GRPC_ENDPOINT=yolox:8001 - - YOLOX_HTTP_ENDPOINT="" - - YOLOX_HEALTH_ENDPOINT=yolox:8000 - - CUDA_VISIBLE_DEVICES=1 - - READY_CHECK_ALL_COMPONENTS=True + - YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer + - YOLOX_INFER_PROTOCOL=grpc healthcheck: test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1 interval: 10s @@ -220,10 +235,9 @@ services: reservations: devices: - driver: nvidia - count: all + device_ids: ["1"] capabilities: [gpu] - - + otel-collector: image: otel/opentelemetry-collector-contrib:0.102.1 hostname: otel-collector @@ -313,32 +327,39 @@ services: # timeout: 20s # retries: 3 -# milvus: -# # Turn on to leverage the `vdb_upload` task -# restart: always -# container_name: milvus-standalone -# image: milvusdb/milvus:v2.3.5 -# command: ["milvus", "run", "standalone"] -# hostname: milvus -# security_opt: -# - seccomp:unconfined -# environment: -# ETCD_ENDPOINTS: etcd:2379 -# MINIO_ADDRESS: minio:9000 -# volumes: -# - ./.volumes/milvus:/var/lib/milvus -# healthcheck: -# test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] -# interval: 30s -# start_period: 90s -# timeout: 20s -# retries: 3 -# ports: -# - "19530:19530" -# - "9091:9091" -# depends_on: -# - "etcd" -# - "minio" + # milvus: + # # Turn on to leverage the `vdb_upload` task + # restart: always + # container_name: milvus-standalone + # image: milvusdb/milvus:v2.4.9-gpu + # command: ["milvus", "run", "standalone"] + # hostname: milvus + # security_opt: + # - seccomp:unconfined + # environment: + # ETCD_ENDPOINTS: etcd:2379 + # MINIO_ADDRESS: minio:9000 + # volumes: + # - ./.volumes/milvus:/var/lib/milvus 
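The compose changes standardize on per-model `*_GRPC_ENDPOINT`/`*_HTTP_ENDPOINT` pairs plus a `*_INFER_PROTOCOL` selector. A sketch of how a consumer might resolve those variables; this resolution logic is illustrative, the real defaulting lives in the pydantic schemas later in this diff:

```python
import os


def resolve_endpoint(prefix: str):
    # Read the paired endpoints; empty strings are treated as unset.
    grpc_ep = os.getenv(f"{prefix}_GRPC_ENDPOINT") or None
    http_ep = os.getenv(f"{prefix}_HTTP_ENDPOINT") or None
    # Explicit protocol wins; otherwise prefer http, then grpc, as in the schema validators.
    protocol = os.getenv(f"{prefix}_INFER_PROTOCOL") or ("http" if http_ep else "grpc" if grpc_ep else "")
    return (grpc_ep, http_ep), protocol.lower()


os.environ["YOLOX_GRPC_ENDPOINT"] = "yolox:8001"
os.environ["YOLOX_HTTP_ENDPOINT"] = "http://yolox:8000/v1/infer"
os.environ["YOLOX_INFER_PROTOCOL"] = "grpc"
print(resolve_endpoint("YOLOX"))
# (('yolox:8001', 'http://yolox:8000/v1/infer'), 'grpc')
```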
+  #   healthcheck:
+  #     test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
+  #     interval: 30s
+  #     start_period: 90s
+  #     timeout: 20s
+  #     retries: 3
+  #   ports:
+  #     - "19530:19530"
+  #     - "9091:9091"
+  #   deploy:
+  #     resources:
+  #       reservations:
+  #         devices:
+  #           - driver: nvidia
+  #             device_ids: ["1"]
+  #             capabilities: [gpu]
+  #   depends_on:
+  #     - "etcd"
+  #     - "minio"
 
 #  attu:
 #    # Turn on to leverage the `vdb_upload` task
diff --git a/docker/scripts/entrypoint.sh b/docker/scripts/entrypoint.sh
new file mode 100644
index 00000000..4aa72ba8
--- /dev/null
+++ b/docker/scripts/entrypoint.sh
@@ -0,0 +1,27 @@
+#!/bin/bash --login
+# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Activate the `nv_ingest` conda environment.
+. /opt/conda/etc/profile.d/conda.sh
+conda activate nv_ingest
+
+# Source "source" file if it exists
+SRC_FILE="/opt/docker/bin/entrypoint_source"
+[ -f "${SRC_FILE}" ] && source "${SRC_FILE}"
+
+# Run whatever the user wants.
+exec "$@"
\ No newline at end of file
diff --git a/docs/content-metadata.md b/docs/content-metadata.md
index 012ff82e..49667244 100644
--- a/docs/content-metadata.md
+++ b/docs/content-metadata.md
@@ -30,10 +30,12 @@ Metadata: Descriptive data which can be associated with Sources, Content(Image o
 |  | Caption | Any caption or subheader associated with Image | Extracted |  |
 |  | Text | Extracted text from a structured chart | Extracted | Pending Research |
 |  | Image location | Location (x,y) of chart within an image | Extracted |  |
+|  | Image location max dimensions | Max dimensions (x\_max,y\_max) of location (x,y) | Extracted |  |
 |  | uploaded\_image\_uri | Mirrors source\_metadata.source\_location |  |  |
 | Table Metadata (tables within documents) | Table format | Structured (dataframe / lists of rows and columns), or serialized as markdown, html, latex, simple (cells separated just as spaces) | Extracted |  |
 |  | Table content | Extracted text content, formatted according to table\_metadata.table\_format. Important: Tables should not be chunked | Extracted |  |
 |  | Table location | Bounding box of the table | Extracted |  |
+|  | Table location max dimensions | Max dimensions (x\_max,y\_max) of bounding box of the table | Extracted |  |
 |  | Caption | Detected captions for the table/chart | Extracted |  |
 |  | Title | TODO | Extracted |  |
 |  | Subtitle | TODO | Extracted |  |
diff --git a/docs/telemetry.md b/docs/telemetry.md
index 398435d5..f1bc8b90 100644
--- a/docs/telemetry.md
+++ b/docs/telemetry.md
@@ -14,7 +14,7 @@ To run OpenTelemetry locally, run
 $ docker compose up otel-collector
 ```
 
-Once and OpenTelemetry and Zipkin are running, you can open your browser to explore traces: [http://localhost:9411/zipkin/](http://localhost:9411/zipkin/).
+Once OpenTelemetry and Zipkin are running, you can open your browser to explore traces: http://$YOUR_DOCKER_HOST:9411/zipkin/.
 ![](images/zipkin.png)
 
@@ -24,10 +24,6 @@ To run Prometheus, run
 $ docker compose up prometheus
 ```
 
-Once Promethus is running, you can open your browser to explore metrics: [http://localhost:9090/](http://localhost:9090/)
+Once Prometheus is running, you can open your browser to explore metrics: http://$YOUR_DOCKER_HOST:9090/
 
 ![](images/prometheus.png)
-
-## Helm chart
-
-TODO
diff --git a/helm/README.md b/helm/README.md
index 57caf1b1..c8566306 100644
--- a/helm/README.md
+++ b/helm/README.md
@@ -92,7 +92,7 @@ minikube addons enable storage-provisioner-rancher
 
 Jobs are submitted via the `nv-ingest-cli` command. See installation [here](https://github.com/NVIDIA/nv-ingest/tree/main/client)
 
-### Access To Redis
+### Access To NV Ingest API
 
 It is recommended that the end user provide a mechanism for [`Ingress`](https://kubernetes.io/docs/concepts/services-networking/ingress/) for the NV-Ingest API.
 You can test outside of your Kubernetes cluster by [port-forwarding](https://kubernetes.io/docs/reference/kubectl/generated/kubectl_port-forward/) the NV-Ingest service to your local environment.
 
 Example:
 
 ```bash
-kubectl port-forward -n ${NAMESPACE} nv-ingest-redis-master-0 6379:6379
+kubectl port-forward -n ${NAMESPACE} service/nv-ingest 7670:7670
 ```
 
 ### Executing jobs
diff --git a/helm/templates/deployment.yaml b/helm/templates/deployment.yaml
index 4c617cd8..1e9560ea 100644
--- a/helm/templates/deployment.yaml
+++ b/helm/templates/deployment.yaml
@@ -132,7 +132,7 @@ spec:
 
           ports:
             - name: http
-              containerPort: {{ .Values.service.port }}
+              containerPort: 7670
               protocol: TCP
           {{- if .Values.livenessProbe.enabled }}
           livenessProbe:
diff --git a/helm/values.yaml b/helm/values.yaml
index 3dd65739..2afb1d75 100644
--- a/helm/values.yaml
+++ b/helm/values.yaml
@@ -73,10 +73,12 @@ replicaCount: 1
 ## @param resources.requests.memory [default: 16Gi] Specify request for memory
 resources:
   limits:
-    memory: 32Gi
+    memory: 90Gi
     nvidia.com/gpu: 1
+    cpu: "36000m"
   requests:
-    memory: 16Gi
+    memory: 24Gi
+    cpu: "16000m"
 
 ## @param tmpDirSize [default: 8Gi] Specify the amount of space to reserve for temporary storage
 
@@ -236,8 +238,17 @@ redis:
     enabled: false
   replica:
     replicaCount: 1
+    resources:
+      requests:
+        memory: "6Gi"
+      limits:
+        memory: "12Gi"
   master:
-    resourcesPreset: xlarge
+    resources:
+      requests:
+        memory: "6Gi"
+      limits:
+        memory: "12Gi"
 
 ## @section Environment Variables
 ## @descriptionStart
@@ -251,19 +262,21 @@ redis:
 ## @param envVars.MINIO_PUBLIC_ADDRESS [default: "http://localhost:9000"] Override this to publicly routable minio address, default assumes port-forwarding
 ## @param envVars.MINIO_BUCKET [default: "nv-ingest"] Override this for specific minio bucket to upload extracted images to
 ## @skip envVars.REDIS_MORPHEUS_TASK_QUEUE
-## @skip envVars.TABLE_DETECTION_GRPC_TRITON
-## @skip envVars.TABLE_DETECTION_HTTP_TRITON
 ## @skip envVars.CACHED_GRPC_ENDPOINT
 ## @skip envVars.CACHED_HTTP_ENDPOINT
+## @skip envVars.CACHED_INFER_PROTOCOL
 ## @skip envVars.PADDLE_GRPC_ENDPOINT
 ## @skip envVars.PADDLE_HTTP_ENDPOINT
+## @skip envVars.PADDLE_INFER_PROTOCOL
 ## @skip envVars.YOLOX_GRPC_ENDPOINT
 ## @skip envVars.YOLOX_HTTP_ENDPOINT
+## @skip envVars.YOLOX_INFER_PROTOCOL
 ## @skip envVars.DEPLOT_GRPC_ENDPOINT
 ## @skip envVars.DEPLOT_HTTP_ENDPOINT
+## @skip envVars.DEPLOT_INFER_PROTOCOL
 envVars:
-  MESSAGE_CLIENT_HOST: "nv-ingest-ms-runtime"
-  MESSAGE_CLIENT_PORT: "7670"
+  MESSAGE_CLIENT_HOST: "nv-ingest-redis-master"
+  
MESSAGE_CLIENT_PORT: "6379"
 
   REDIS_MORPHEUS_TASK_QUEUE: "morpheus_task_queue"
 
   NV_INGEST_DEFAULT_TIMEOUT_MS: "1234"
@@ -271,17 +284,18 @@ envVars:
   MINIO_PUBLIC_ADDRESS: http://localhost:9000
   MINIO_BUCKET: nv-ingest
 
-  TABLE_DETECTION_GRPC_TRITON: nv-ingest-yolox:8001
-  TABLE_DETECTION_HTTP_TRITON: ""
-
   CACHED_GRPC_ENDPOINT: nv-ingest-cached:8001
-  CACHED_HTTP_ENDPOINT: ""
+  CACHED_HTTP_ENDPOINT: http://nv-ingest-cached:8000/v1/infer
+  CACHED_INFER_PROTOCOL: grpc
   PADDLE_GRPC_ENDPOINT: nv-ingest-paddle:8001
-  PADDLE_HTTP_ENDPOINT: ""
+  PADDLE_HTTP_ENDPOINT: http://nv-ingest-paddle:8000/v1/infer
+  PADDLE_INFER_PROTOCOL: grpc
   YOLOX_GRPC_ENDPOINT: nv-ingest-yolox:8001
-  YOLOX_HTTP_ENDPOINT: ""
+  YOLOX_HTTP_ENDPOINT: http://nv-ingest-yolox:8000/v1/infer
+  YOLOX_INFER_PROTOCOL: grpc
   DEPLOT_GRPC_ENDPOINT: ""
   DEPLOT_HTTP_ENDPOINT: http://nv-ingest-deplot:8000/v1/chat/completions
+  DEPLOT_INFER_PROTOCOL: http
 
   EMBEDDING_NIM_ENDPOINT: "http://nv-ingest-embedding:8000/v1"
   MILVUS_ENDPOINT: "http://nv-ingest-milvus:19530"
@@ -478,7 +492,7 @@ readinessProbe:
 ## @param service.labels [object] Specifies additional labels to be added to service.
 service:
   type: ClusterIP
-  port: 8000
+  port: 7670
   annotations: {}
   labels: {}
   name: ""  # override the default service name
diff --git a/requirements.txt b/requirements.txt
index 91a15572..e0cd4296 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,33 +3,34 @@ charset-normalizer
 click
 dataclasses
 farm-haystack[ocr,inference,pdf,preprocessing,file-conversion]
+fastapi==0.109.1
 fastparquet==2024.2.0
 fsspec
+gunicorn==22.0.0
 minio~=7.2.5
 more_itertools
 nltk==3.9.1
-openai==1.40.6
+numpy
 olefile==0.47
-onnx==1.16.0
+openai==1.40.6
+onnx==1.17.0
 opencv-python==4.10.0.84
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation
-opentelemetry-instrumentation-fastapi
 opentelemetry-instrumentation-asgi
+opentelemetry-instrumentation-fastapi
 opentelemetry-sdk
 pandas~=1.5.3
 pydantic==1.10.14
 pyinstrument
 pypdfium2
 python-docx
+python-multipart
 python-pptx==0.6.23
 redis~=5.0.1
 setuptools==70.0.0
 tabulate
 torchvision==0.18.0
 unstructured-client==0.23.3
-fastapi==0.109.1
 uvicorn==0.24.0-post.1
-gunicorn==22.0.0
-python-multipart
diff --git a/src/nv_ingest/api/main.py b/src/nv_ingest/api/main.py
index 7800a61e..9beba4f6 100644
--- a/src/nv_ingest/api/main.py
+++ b/src/nv_ingest/api/main.py
@@ -9,6 +9,7 @@
 # its affiliates is strictly prohibited.
 
 import logging
+import os
 
 from opentelemetry import trace
 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
@@ -24,7 +25,8 @@
 trace.set_tracer_provider(TracerProvider())
 tracer = trace.get_tracer(__name__)
 
-exporter = OTLPSpanExporter(endpoint="otel-collector:4317", insecure=True)
+otel_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "otel-collector:4317")
+exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
 span_processor = BatchSpanProcessor(exporter)
 trace.get_tracer_provider().add_span_processor(span_processor)
 
diff --git a/src/nv_ingest/api/v1/health.py b/src/nv_ingest/api/v1/health.py
index 6fe4a9cd..cdc76906 100644
--- a/src/nv_ingest/api/v1/health.py
+++ b/src/nv_ingest/api/v1/health.py
@@ -66,14 +66,14 @@ async def get_ready_state() -> dict:
     # for now to assume that if nv-ingest is running so is
     # the pipeline.
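With the exporter endpoint now env-driven in `main.py`, deployments can repoint tracing without a code change. A small sketch of the override; the endpoint value here is illustrative:

```python
import os

# e.g. set in a Helm values override or compose file:
#   OTEL_EXPORTER_OTLP_ENDPOINT=my-collector.observability:4317
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "my-collector.observability:4317"  # placeholder

# Same lookup pattern as the diff above: env var wins, otherwise the compose default.
otel_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "otel-collector:4317")
print(otel_endpoint)  # my-collector.observability:4317
```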
morpheus_pipeline_ready = True - + # We give the users an option to disable checking all distributed services for "readiness" check_all_components = os.getenv("READY_CHECK_ALL_COMPONENTS", "True").lower() if check_all_components in ['1', 'true', 'yes']: - yolox_ready = is_ready(os.getenv("YOLOX_HEALTH_ENDPOINT", None), "/v1/health/ready") - deplot_ready = is_ready(os.getenv("DEPLOT_HEALTH_ENDPOINT", None), "/v1/health/ready") - cached_ready = is_ready(os.getenv("CACHED_HEALTH_ENDPOINT", None), "/v1/health/ready") - paddle_ready = is_ready(os.getenv("PADDLE_HEALTH_ENDPOINT", None), "/v1/health/ready") + yolox_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready") + deplot_ready = is_ready(os.getenv("DEPLOT_HTTP_ENDPOINT", None), "/v1/health/ready") + cached_ready = is_ready(os.getenv("CACHED_HTTP_ENDPOINT", None), "/v1/health/ready") + paddle_ready = is_ready(os.getenv("PADDLE_HTTP_ENDPOINT", None), "/v1/health/ready") if (ingest_ready and morpheus_pipeline_ready diff --git a/src/nv_ingest/api/v1/ingest.py b/src/nv_ingest/api/v1/ingest.py index 25a97818..0edc9511 100644 --- a/src/nv_ingest/api/v1/ingest.py +++ b/src/nv_ingest/api/v1/ingest.py @@ -10,26 +10,26 @@ # pylint: skip-file +from io import BytesIO +from typing import Annotated import base64 import json -from io import BytesIO import logging import time import traceback -from typing import Annotated -from opentelemetry import trace -from nv_ingest_client.primitives.jobs.job_spec import JobSpec -from fastapi import File, UploadFile from fastapi import APIRouter from fastapi import Depends +from fastapi import File, UploadFile from fastapi import HTTPException -from nv_ingest_client.primitives.tasks.extract import ExtractTask +from nv_ingest_client.primitives.jobs.job_spec import JobSpec +from opentelemetry import trace +from redis import RedisError +from nv_ingest_client.primitives.tasks.extract import ExtractTask from nv_ingest.schemas.message_wrapper_schema import MessageWrapper from nv_ingest.service.impl.ingest.redis_ingest_service import RedisIngestService from nv_ingest.service.meta.ingest.ingest_service_meta import IngestServiceMeta -from nv_ingest.schemas.ingest_job_schema import DocumentTypeEnum logger = logging.getLogger("uvicorn") tracer = trace.get_tracer(__name__) @@ -127,7 +127,6 @@ async def submit_job(job_spec: MessageWrapper, ingest_service: INGEST_SERVICE_T) # will be able to trace across uvicorn -> morpheus current_trace_id = trace.get_current_span().get_span_context().trace_id - # Recreate the JobSpec to test what is going on .... job_spec_dict = json.loads(job_spec.payload) job_spec_dict['tracing_options']['trace_id'] = current_trace_id job_spec_dict['tracing_options']['ts_http_done'] = time.time_ns() @@ -135,7 +134,7 @@ async def submit_job(job_spec: MessageWrapper, ingest_service: INGEST_SERVICE_T) updated_job_spec = MessageWrapper( payload=json.dumps(job_spec_dict) ) - + submitted_job_id = await ingest_service.submit_job(updated_job_spec) return submitted_job_id except Exception as ex: diff --git a/src/nv_ingest/extraction_workflows/docx/docxreader.py b/src/nv_ingest/extraction_workflows/docx/docxreader.py index 449aeb45..87f569cf 100644 --- a/src/nv_ingest/extraction_workflows/docx/docxreader.py +++ b/src/nv_ingest/extraction_workflows/docx/docxreader.py @@ -333,7 +333,7 @@ def _construct_image_metadata(self, image, para_idx, caption, base_unified_metad # For docx there is no bounding box. The paragraph that follows the image is typically # the caption. 
Add that para to the page nearby for now. fixme - bbox = (-1, -1, -1, -1) + bbox = (0, 0, 0, 0) page_nearby_blocks = { "text": {"content": [], "bbox": []}, "images": {"content": [], "bbox": []}, diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index 30027c45..b0239e93 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -123,9 +123,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table } pages = [] + page_sizes = [] for page_idx in range(pdf_metadata.page_count): page = doc.get_page(page_idx) pages.append(page) + page_width, page_height = doc.get_page_size(page_idx) + page_sizes.append((page_width, page_height)) # Split into batches. i = 0 @@ -147,6 +150,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table for page_idx, raw_text, bbox_offset in responses: page_image = None + page_width, page_height = page_sizes[page_idx] classes, bboxes, texts = doughnut_utils.extract_classes_bboxes(raw_text) @@ -173,7 +177,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table except UnicodeDecodeError: pass bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset) - table = LatexTable(latex=txt, bbox=bbox) + table = LatexTable(latex=txt, bbox=bbox, max_width=page_width, max_height=page_height) accumulated_tables.append(table) elif extract_images and (cls == "Picture"): @@ -190,7 +194,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table base64_img = numpy_to_base64(img_numpy) bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset) image = Base64Image( - image=base64_img, bbox=bbox, width=img_numpy.shape[1], height=img_numpy.shape[0] + image=base64_img, + bbox=bbox, + width=img_numpy.shape[1], + height=img_numpy.shape[0], + max_width=page_width, + max_height=page_height, ) accumulated_images.append(image) diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index ad22ce3b..7a1de0f1 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -17,6 +17,8 @@ # limitations under the License. 
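The doughnut changes above record each page's width and height and thread them into `LatexTable` and `Base64Image` as `max_width`/`max_height`, matching the new "location max dimensions" metadata fields documented earlier. A minimal sketch of why that matters downstream, with illustrative page dimensions (612x792 points, US Letter):

```python
# With max dimensions stored alongside a bbox, any consumer can convert
# absolute page coordinates into resolution-independent relative ones.
def to_relative(bbox, max_width, max_height):
    x1, y1, x2, y2 = bbox
    return (x1 / max_width, y1 / max_height, x2 / max_width, y2 / max_height)


print(to_relative((61, 102, 305, 204), max_width=612, max_height=792))
# ≈ (0.0997, 0.1288, 0.4984, 0.2576)
```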
import logging +import traceback + from math import log from typing import List from typing import Optional @@ -30,15 +32,12 @@ from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.schemas.metadata_schema import TextTypeEnum from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema -from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output from nv_ingest.util.image_processing.transforms import crop_image from nv_ingest.util.image_processing.transforms import numpy_to_base64 -from nv_ingest.util.nim.helpers import call_image_inference_model from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.nim.helpers import perform_model_inference from nv_ingest.util.pdf.metadata_aggregators import Base64Image -from nv_ingest.util.pdf.metadata_aggregators import ImageChart -from nv_ingest.util.pdf.metadata_aggregators import ImageTable +from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata @@ -47,9 +46,6 @@ from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy from nv_ingest.util.pdf.pdfium import pdfium_try_get_bitmap_as_numpy - -PADDLE_MIN_WIDTH = 32 -PADDLE_MIN_HEIGHT = 32 YOLOX_MAX_BATCH_SIZE = 8 YOLOX_MAX_WIDTH = 1536 YOLOX_MAX_HEIGHT = 1536 @@ -63,18 +59,16 @@ def extract_tables_and_charts_using_image_ensemble( - pages: List[libpdfium.PdfPage], - config: PDFiumConfigSchema, - max_batch_size: int = YOLOX_MAX_BATCH_SIZE, - num_classes: int = YOLOX_NUM_CLASSES, - conf_thresh: float = YOLOX_CONF_THRESHOLD, - iou_thresh: float = YOLOX_IOU_THRESHOLD, - min_score: float = YOLOX_MIN_SCORE, - final_thresh: float = YOLOX_FINAL_SCORE, - extract_tables: bool = True, - extract_charts: bool = True, - trace_info: Optional[List] = None, -) -> List[Tuple[int, ImageTable]]: + pages: List[libpdfium.PdfPage], + config: PDFiumConfigSchema, + max_batch_size: int = YOLOX_MAX_BATCH_SIZE, + num_classes: int = YOLOX_NUM_CLASSES, + conf_thresh: float = YOLOX_CONF_THRESHOLD, + iou_thresh: float = YOLOX_IOU_THRESHOLD, + min_score: float = YOLOX_MIN_SCORE, + final_thresh: float = YOLOX_FINAL_SCORE, + trace_info: Optional[List] = None, +) -> List[Tuple[int, CroppedImageWithContent]]: """ Extract tables and charts from a series of document pages using an ensemble of image-based models. 
@@ -134,66 +128,49 @@ def extract_tables_and_charts_using_image_ensemble( """ tables_and_charts = [] - if not extract_tables and not extract_charts: - logger.debug("Nothing to do since both extract_tables and extract_charts are set to false.") - return tables_and_charts - - yolox_client = paddle_client = deplot_client = cached_client = None + yolox_client = None try: yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token) - if extract_tables: - paddle_client = create_inference_client(config.paddle_endpoints, config.auth_token) - if extract_charts: - cached_client = create_inference_client(config.cached_endpoints, config.auth_token) - deplot_client = create_inference_client(config.deplot_endpoints, config.auth_token) batches = [] i = 0 while i < len(pages): batch_size = min(2 ** int(log(len(pages) - i, 2)), max_batch_size) - batches.append(pages[i : i + batch_size]) # noqa: E203 + batches.append(pages[i: i + batch_size]) # noqa: E203 i += batch_size - page_idx = 0 + page_index = 0 for batch in batches: original_images, _ = pdfium_pages_to_numpy( batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT), trace_info=trace_info ) + # original images is an implicitly indexed list of pages original_image_shapes = [image.shape for image in original_images] input_array = prepare_images_for_inference(original_images) output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info) - results = process_inference_results( + # Get back inference results + yolox_annotated_detections = process_inference_results( output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh ) - for annotation_dict, original_image in zip(results, original_images): - handle_table_chart_extraction( + for annotation_dict, original_image in zip(yolox_annotated_detections, original_images): + extract_table_and_chart_images( annotation_dict, original_image, - page_idx, - paddle_client, - deplot_client, - cached_client, + page_index, tables_and_charts, - extract_tables=extract_tables, - extract_charts=extract_charts, - trace_info=trace_info, ) - page_idx += 1 + page_index += 1 + except Exception as e: - logger.error(f"Error during table/chart extraction: {str(e)}") - raise + logger.error(f"Unhandled error during table/chart extraction: {str(e)}") + traceback.print_exc() + raise e finally: - if isinstance(paddle_client, grpcclient.InferenceServerClient): - paddle_client.close() - if isinstance(cached_client, grpcclient.InferenceServerClient): - cached_client.close() - if isinstance(deplot_client, grpcclient.InferenceServerClient): - deplot_client.close() if isinstance(yolox_client, grpcclient.InferenceServerClient): yolox_client.close() @@ -234,13 +211,13 @@ def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray: def process_inference_results( - output_array: np.ndarray, - original_image_shapes: List[Tuple[int, int]], - num_classes: int, - conf_thresh: float, - iou_thresh: float, - min_score: float, - final_thresh: float, + output_array: np.ndarray, + original_image_shapes: List[Tuple[int, int]], + num_classes: int, + conf_thresh: float, + iou_thresh: float, + min_score: float, + final_thresh: float, ): """ Process the model output to generate detection results and expand bounding boxes. 
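The batching loop above splits the page list into power-of-two batches capped at `YOLOX_MAX_BATCH_SIZE`. A standalone sketch of that arithmetic, using the same `min(2 ** int(log(...)), max_batch_size)` expression:

```python
from math import log


def plan_batches(num_pages: int, max_batch_size: int = 8) -> list:
    """Mirror of the batching loop: largest power-of-two batch sizes, capped."""
    sizes = []
    i = 0
    while i < num_pages:
        batch_size = min(2 ** int(log(num_pages - i, 2)), max_batch_size)
        sizes.append(batch_size)
        i += batch_size
    return sizes


print(plan_batches(10))  # [8, 2]
print(plan_batches(7))   # [4, 2, 1]
```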
@@ -289,6 +266,7 @@ def process_inference_results( annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results] inference_results = [] + # Filter out bounding boxes below the final threshold for annotation_dict in annotation_dicts: new_dict = {} if "table" in annotation_dict: @@ -303,17 +281,11 @@ def process_inference_results( # Handle individual table/chart extraction and model inference -def handle_table_chart_extraction( - annotation_dict, - original_image, - page_idx, - paddle_client, - deplot_client, - cached_client, - tables_and_charts, - extract_tables=True, - extract_charts=True, - trace_info=None, +def extract_table_and_chart_images( + annotation_dict, + original_image, + page_idx, + tables_and_charts, ): """ Handle the extraction of tables and charts from the inference results and run additional model inference. @@ -326,12 +298,6 @@ def handle_table_chart_extraction( The original image from which objects were detected. page_idx : int The index of the current page being processed. - paddle_client : grpcclient.InferenceServerClient - The gRPC client for the paddle model used to process tables. - deplot_client : grpcclient.InferenceServerClient - The gRPC client for the deplot model used to process charts. - cached_client : grpcclient.InferenceServerClient - The gRPC client for the cached model used to process charts. tables_and_charts : List[Tuple[int, ImageTable]] A list to which extracted tables and charts will be appended. @@ -346,8 +312,7 @@ def handle_table_chart_extraction( >>> annotation_dict = {"table": [], "chart": []} >>> original_image = np.random.rand(1536, 1536, 3) >>> tables_and_charts = [] - >>> handle_table_chart_extraction(annotation_dict, original_image, 0, paddle_client, deplot_client, cached_client, - tables_and_charts) + >>> extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts) """ width, height, *_ = original_image.shape @@ -360,42 +325,26 @@ def handle_table_chart_extraction( *bbox, _ = bboxes h1, w1, h2, w2 = bbox * np.array([height, width, height, width]) - if extract_tables and label == "table": - # PaddleOCR NIM enforces minimum dimensions for TRT engines. 
- cropped = crop_image( - original_image, - (h1, w1, h2, w2), - min_width=PADDLE_MIN_WIDTH, - min_height=PADDLE_MIN_HEIGHT, - ) - base64_img = numpy_to_base64(cropped) - - table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info) - table_data = ImageTable(table_content, base64_img, (w1, h1, w2, h2)) - tables_and_charts.append((page_idx, table_data)) - elif extract_charts and label == "chart": - cropped = crop_image(original_image, (h1, w1, h2, w2)) - base64_img = numpy_to_base64(cropped) + cropped = crop_image(original_image, (h1, w1, h2, w2)) + base64_img = numpy_to_base64(cropped) - deplot_result = call_image_inference_model( - deplot_client, "google/deplot", cropped, trace_info=trace_info - ) - cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info) - chart_content = join_cached_and_deplot_output(cached_result, deplot_result) - chart_data = ImageChart(chart_content, base64_img, (w1, h1, w2, h2)) - tables_and_charts.append((page_idx, chart_data)) + table_data = CroppedImageWithContent( + content="", image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, + max_height=height, type_string=label + ) + tables_and_charts.append((page_idx, table_data)) # Define a helper function to use unstructured-io to extract text from a base64 -# encoded bytestram PDF +# encoded bytestream PDF def pdfium( - pdf_stream, - extract_text: bool, - extract_images: bool, - extract_tables: bool, - extract_charts: bool, - trace_info=None, - **kwargs, + pdf_stream, + extract_text: bool, + extract_images: bool, + extract_tables: bool, + extract_charts: bool, + trace_info=None, + **kwargs, ): """ Helper function to use pdfium to extract text from a bytestream PDF. @@ -472,6 +421,7 @@ def pdfium( text_depth = text_depth if text_depth == TextTypeEnum.PAGE else TextTypeEnum.DOCUMENT for page_idx in range(pdf_metadata.page_count): page = doc.get_page(page_idx) + page_width, page_height = doc.get_page_size(page_idx) # https://pypdfium2.readthedocs.io/en/stable/python_api.html#module-pypdfium2._helpers.textpage if extract_text: @@ -503,11 +453,14 @@ def pdfium( if obj_type == "IMAGE": try: # Attempt to retrieve the image bitmap - image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) + image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) # noqa image_base64: str = numpy_to_base64(image_numpy) image_bbox = obj.get_pos() image_size = obj.get_size() - image_data = Base64Image(image_base64, image_bbox, image_size[0], image_size[1]) + image_data = Base64Image( + image=image_base64, bbox=image_bbox, width=image_size[0], height=image_size[1], + max_width=page_width, max_height=page_height + ) extracted_image_data = construct_image_metadata( image_data, @@ -519,7 +472,7 @@ def pdfium( extracted_data.append(extracted_image_data) except Exception as e: - logger.error(f"Error extracting image: {e}") + logger.error(f"Unhandled error extracting image: {e}") pass # Pdfium failed to extract the image associated with this object - corrupt or missing. 
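`extract_table_and_chart_images` denormalizes the YOLOX box against the page size before cropping. A self-contained sketch of that step, with plain numpy slicing standing in for `crop_image`:

```python
import numpy as np


def crop_normalized(image: np.ndarray, bbox) -> np.ndarray:
    """Scale a normalized [h1, w1, h2, w2] bbox to pixels and crop."""
    height, width = image.shape[:2]
    h1, w1, h2, w2 = (np.array(bbox) * np.array([height, width, height, width])).astype(int)
    return image[max(h1, 0):h2, max(w1, 0):w2]


page = np.zeros((1536, 1536, 3), dtype=np.uint8)
crop = crop_normalized(page, [0.1, 0.2, 0.5, 0.8])
print(crop.shape)  # (615, 921, 3)
```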
# Table and chart collection @@ -544,11 +497,9 @@ def pdfium( if extract_tables or extract_charts: for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble( - pages, - pdfium_config, - extract_tables=extract_tables, - extract_charts=extract_charts, - trace_info=trace_info, + pages, + pdfium_config, + trace_info=trace_info, ): extracted_data.append( construct_table_and_chart_metadata( diff --git a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py b/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py index dfc0c632..e8458ea0 100644 --- a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py +++ b/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py @@ -113,6 +113,7 @@ def postprocess_results(results, original_image_shapes, min_score=0.0): out.append(annotation_dict) + # {label: [[x1, y1, x2, y2, confidence], ...], ...} return out diff --git a/src/nv_ingest/modules/filters/image_dedup.py b/src/nv_ingest/modules/filters/image_dedup.py index de817236..4d82d499 100644 --- a/src/nv_ingest/modules/filters/image_dedup.py +++ b/src/nv_ingest/modules/filters/image_dedup.py @@ -137,7 +137,7 @@ def _apply_dedup_filter(ctrl_msg: ControlMessage, filter_flag): gdf.drop(labels=["info_message_metadata", "metadata"], inplace=True, axis=1) gdf["info_message_metadata"] = duplicate_images_gdf["info_message_metadata"] gdf.loc[duplicate_images_gdf["document_type"].index, "document_type"] = ContentTypeEnum.INFO_MSG.value - gdf["metadata"] = gdf[exploded_metadata_cols + ["info_message_metadata"]].to_struct() + gdf["metadata"] = gdf[exploded_metadata_cols].to_struct() gdf.drop(labels=gdf.columns.difference(base_cols), inplace=True, axis=1) message_meta = MessageMeta(df=gdf) diff --git a/src/nv_ingest/modules/filters/image_filter.py b/src/nv_ingest/modules/filters/image_filter.py index 74579132..cfb98276 100644 --- a/src/nv_ingest/modules/filters/image_filter.py +++ b/src/nv_ingest/modules/filters/image_filter.py @@ -153,7 +153,7 @@ def _apply_filter(ctrl_msg: ControlMessage, task_params: dict): mdf.loc[ filtered_images_gdf["document_type"].index, "document_type" ] = ContentTypeEnum.INFO_MSG.value # noqa - mdf["metadata"] = mdf[exploded_metadata_cols + ["info_message_metadata"]].to_struct() # noqa + mdf["metadata"] = mdf[exploded_metadata_cols].to_struct() # noqa mdf.drop(labels=mdf.columns.difference(base_cols), inplace=True, axis=1) # noqa diff --git a/src/nv_ingest/modules/sinks/redis_task_sink.py b/src/nv_ingest/modules/sinks/redis_task_sink.py index e4fd047b..86f6ac25 100644 --- a/src/nv_ingest/modules/sinks/redis_task_sink.py +++ b/src/nv_ingest/modules/sinks/redis_task_sink.py @@ -50,8 +50,9 @@ def extract_data_frame(message: ControlMessage) -> Tuple[Any, Dict[str, Any]]: with message.payload().mutable_dataframe() as mdf: logger.debug(f"Redis Sink Received DataFrame with {len(mdf)} rows.") keep_cols = ["document_type", "metadata"] - return mdf, mdf[keep_cols].to_dict(orient="records") - except Exception: + return mdf, mdf[keep_cols].to_pandas().to_dict(orient="records") + except Exception as err: + logger.warning(f"Failed to extract DataFrame from message payload: {err}") return None, None @@ -118,7 +119,7 @@ def create_json_payload(message: ControlMessage, df_json: Dict[str, Any]) -> Lis df_json_size = sys.getsizeof(df_json_str) # 256 MB size limit (in bytes) - size_limit = 256 * 1024 * 1024 + size_limit = 128 * 1024 * 1024 # If df_json is larger than the size limit, split it into chunks if df_json_size > size_limit: @@ -306,13 +307,13 @@ def process_and_forward(message: 
ControlMessage, redis_client: RedisClient) -> C
         annotate_cm(message, message="Pushed")
         push_to_redis(redis_client, response_channel, json_payloads)
     except RedisError as e:
-        mdf_size = len(mdf) if mdf else 0
+        mdf_size = len(mdf) if mdf is not None and not mdf.empty else 0
         handle_failure(redis_client, response_channel, json_result_fragments, e, mdf_size)
     except Exception as e:
         traceback.print_exc()
         logger.error(f"Critical error processing message: {e}")
-        mdf_size = len(mdf) if mdf else 0
+        mdf_size = len(mdf) if mdf is not None and not mdf.empty else 0
         handle_failure(redis_client, response_channel, json_result_fragments, e, mdf_size)
 
     return message
diff --git a/src/nv_ingest/modules/sinks/vdb_task_sink.py b/src/nv_ingest/modules/sinks/vdb_task_sink.py
index cddd3adc..d133a0e0 100644
--- a/src/nv_ingest/modules/sinks/vdb_task_sink.py
+++ b/src/nv_ingest/modules/sinks/vdb_task_sink.py
@@ -10,9 +10,9 @@
 import mrc
 from morpheus.messages import ControlMessage
-from morpheus.service.vdb.milvus_client import DATA_TYPE_MAP
-from morpheus.service.vdb.utils import VectorDBServiceFactory
-from morpheus.service.vdb.vector_db_service import VectorDBService
+from morpheus_llm.service.vdb.milvus_client import DATA_TYPE_MAP
+from morpheus_llm.service.vdb.utils import VectorDBServiceFactory
+from morpheus_llm.service.vdb.vector_db_service import VectorDBService
 from morpheus.utils.control_message_utils import cm_skip_processing_if_failed
 from morpheus.utils.module_ids import WRITE_TO_VECTOR_DB
 from morpheus.utils.module_utils import ModuleLoaderFactory
@@ -211,6 +211,7 @@ def extract_df(ctrl_msg: ControlMessage, filter_errors: bool):
 
         mdf["embedding"] = mdf["metadata"].struct.field("embedding")
         mdf["_source_metadata"] = mdf["metadata"].struct.field("source_metadata")
+        mdf["_content_metadata"] = mdf["metadata"].struct.field("content_metadata")
 
         df = mdf[mdf["_contains_embeddings"]].copy()
         df = df[
@@ -218,9 +219,10 @@ def extract_df(ctrl_msg: ControlMessage, filter_errors: bool):
                 "embedding",
                 "_content",
                 "_source_metadata",
+                "_content_metadata",
             ]
         ]
-        df.columns = ["vector", "text", "source"]
+        df.columns = ["vector", "text", "source", "content_metadata"]
 
     return df, resource_name
diff --git a/src/nv_ingest/modules/sources/redis_task_source.py b/src/nv_ingest/modules/sources/redis_task_source.py
index 085d9045..142b48ae 100644
--- a/src/nv_ingest/modules/sources/redis_task_source.py
+++ b/src/nv_ingest/modules/sources/redis_task_source.py
@@ -4,10 +4,12 @@
 
 import logging
+import time
 import traceback
 from datetime import datetime
 from functools import partial
 from typing import Dict
+import copy, json
 
 import cudf
 import mrc
@@ -30,7 +32,6 @@
 MODULE_NAMESPACE = "nv_ingest"
 
 RedisTaskSourceLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE)
-
 
 def fetch_and_process_messages(redis_client: RedisClient, validated_config: RedisTaskSourceSchema):
     """Fetch messages from the Redis list and process them."""
@@ -53,10 +54,12 @@ def process_message(job: Dict, ts_fetched: datetime) -> ControlMessage:
     Fetch messages from the Redis list (task queue) and yield as ControlMessage.
""" + if logger.isEnabledFor(logging.DEBUG): + no_payload = copy.deepcopy(job) + no_payload["job_payload"]["content"] = ["[...]"] # Redact the payload for logging + logger.debug("Job: %s", json.dumps(no_payload, indent=2)) + validate_ingest_job(job) - # no_payload = copy.deepcopy(job) - # no_payload["job_payload"]["content"] = ["[...]"] # Redact the payload for logging - # logger.debug("Job: %s", json.dumps(no_payload, indent=2)) control_message = ControlMessage() try: diff --git a/src/nv_ingest/schemas/chart_extractor_schema.py b/src/nv_ingest/schemas/chart_extractor_schema.py new file mode 100644 index 00000000..ce6dd0cc --- /dev/null +++ b/src/nv_ingest/schemas/chart_extractor_schema.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Optional, Tuple + +from pydantic import BaseModel, root_validator, validator + +logger = logging.getLogger(__name__) + + +class ChartExtractorConfigSchema(BaseModel): + """ + Configuration schema for chart extraction service endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + cached_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the cached endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + deplot_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the deplot endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the paddle endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + + cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + cached_infer_protocol: str = "" + + deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + deplot_infer_protocol: str = "" + + ## NOTE: Paddle isn't currently called independently of the cached NIM, but will be in the future. + paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + paddle_infer_protocol: str = "" + + @root_validator(pre=True) + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Ensures that at least one service (either gRPC or HTTP) is provided + for each endpoint in the configuration. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. 
+        """
+
+        def clean_service(service):
+            """Set service to None if it's an empty string or contains only spaces or quotes."""
+            if service is None or not service.strip() or service.strip(" \"'") == "":
+                return None
+            return service
+
+        for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints"]:
+            grpc_service, http_service = values.get(endpoint_name, (None, None))
+            grpc_service = clean_service(grpc_service)
+            http_service = clean_service(http_service)
+
+            if not grpc_service and not http_service:
+                raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")
+
+            values[endpoint_name] = (grpc_service, http_service)
+
+        return values
+
+    class Config:
+        extra = "forbid"
+
+
+class ChartExtractorSchema(BaseModel):
+    """
+    Configuration schema for chart extraction processing settings.
+
+    Parameters
+    ----------
+    max_queue_size : int, default=1
+        The maximum number of items allowed in the processing queue.
+
+    n_workers : int, default=2
+        The number of worker threads to use for processing.
+
+    raise_on_failure : bool, default=False
+        A flag indicating whether to raise an exception if a failure occurs during chart extraction.
+
+    stage_config : Optional[ChartExtractorConfigSchema], default=None
+        Configuration for the chart extraction stage, including cached, deplot, and paddle service endpoints.
+    """
+
+    max_queue_size: int = 1
+    n_workers: int = 2
+    raise_on_failure: bool = False
+
+    stage_config: Optional[ChartExtractorConfigSchema] = None
+
+    @validator('max_queue_size', 'n_workers', pre=True, always=True)
+    def check_positive(cls, v, field):
+        if v <= 0:
+            raise ValueError(f"{field.name} must be greater than 0.")
+        return v
+
+    class Config:
+        extra = "forbid"
diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py
index b657f75b..6864551b 100644
--- a/src/nv_ingest/schemas/ingest_job_schema.py
+++ b/src/nv_ingest/schemas/ingest_job_schema.py
@@ -44,6 +44,8 @@ class TaskTypeEnum(str, Enum):
     split = "split"
     store = "store"
     vdb_upload = "vdb_upload"
+    table_data_extract = "table_data_extract"
+    chart_data_extract = "chart_data_extract"
 
 
 class FilterTypeEnum(str, Enum):
@@ -129,6 +131,12 @@ class IngestTaskVdbUploadSchema(BaseModelNoExt):
     filter_errors: bool = True
 
 
+class IngestTaskTableExtraction(BaseModelNoExt):
+    params: Dict = {}
+
+class IngestChartTableExtraction(BaseModelNoExt):
+    params: Dict = {}
+
 class IngestTaskSchema(BaseModelNoExt):
     type: TaskTypeEnum
     task_properties: Union[
@@ -140,6 +148,8 @@ class IngestTaskSchema(BaseModelNoExt):
         IngestTaskDedupSchema,
         IngestTaskFilterSchema,
         IngestTaskVdbUploadSchema,
+        IngestTaskTableExtraction,
+        IngestChartTableExtraction
     ]
     raise_on_failure: bool = False
 
@@ -156,6 +166,8 @@ def check_task_properties_type(cls, values):
             TaskTypeEnum.split: IngestTaskSplitSchema,
             TaskTypeEnum.store: IngestTaskStoreSchema,
             TaskTypeEnum.vdb_upload: IngestTaskVdbUploadSchema,
+            TaskTypeEnum.table_data_extract: IngestTaskTableExtraction,
+            TaskTypeEnum.chart_data_extract: IngestChartTableExtraction,
         }.get(task_type.lower())
 
         # logger.debug(f"Checking task_properties type for task type '{task_type}'")
diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py
index 53305dde..6a51c83d 100644
--- a/src/nv_ingest/schemas/metadata_schema.py
+++ b/src/nv_ingest/schemas/metadata_schema.py
@@ -4,6 +4,7 @@
 
 from datetime import datetime
+import logging
 from enum import Enum
 from typing import Any
 from typing import Dict
@@ -17,6 +18,8 @@
from nv_ingest.schemas.base_model_noext import BaseModelNoExt from nv_ingest.util.converters import datetools +logger = logging.getLogger(__name__) + # Do we want types and similar items to be enums or just strings? class SourceTypeEnum(str, Enum): @@ -246,22 +249,53 @@ class TextMetadataSchema(BaseModelNoExt): text_location: tuple = (0, 0, 0, 0) +import logging +from pydantic import validator + +# Set up logging +logger = logging.getLogger(__name__) + + class ImageMetadataSchema(BaseModelNoExt): image_type: Union[ImageTypeEnum, str] structured_image_type: ImageTypeEnum = ImageTypeEnum.image_type_1 caption: str = "" text: str = "" image_location: tuple = (0, 0, 0, 0) + image_location_max_dimensions: tuple = (0, 0) uploaded_image_url: str = "" width: int = 0 height: int = 0 + @validator("image_type", pre=True, always=True) + def validate_image_type(cls, v): + if not isinstance(v, (ImageTypeEnum, str)): + raise ValueError("image_type must be a string or ImageTypeEnum") + return v + + @validator("width", "height", pre=True, always=True) + def clamp_non_negative(cls, v, field): + if v < 0: + logger.warning(f"{field.name} is negative; clamping to 0. Original value: {v}") + return 0 + return v + class TableMetadataSchema(BaseModelNoExt): caption: str = "" table_format: TableFormatEnum table_content: str = "" table_location: tuple = (0, 0, 0, 0) + table_location_max_dimensions: tuple = (0, 0) + uploaded_image_uri: str = "" + + +class ChartMetadataSchema(BaseModelNoExt): + caption: str = "" + table_format: TableFormatEnum + table_content: str = "" + table_location: tuple = (0, 0, 0, 0) + table_location_max_dimensions: tuple = (0, 0) uploaded_image_uri: str = "" @@ -289,6 +323,7 @@ class MetadataSchema(BaseModelNoExt): text_metadata: Optional[TextMetadataSchema] = None image_metadata: Optional[ImageMetadataSchema] = None table_metadata: Optional[TableMetadataSchema] = None + chart_metadata: Optional[ChartMetadataSchema] = None error_metadata: Optional[ErrorMetadataSchema] = None info_message_metadata: Optional[InfoMessageMetadataSchema] = None debug_metadata: Optional[Dict[str, Any]] = None diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py index 2826bee4..9d0f538a 100644 --- a/src/nv_ingest/schemas/pdf_extractor_schema.py +++ b/src/nv_ingest/schemas/pdf_extractor_schema.py @@ -22,25 +22,10 @@ class PDFiumConfigSchema(BaseModel): auth_token : Optional[str], default=None Authentication token required for secure services. - cached_endpoints : Tuple[str, str] - A tuple containing the gRPC and HTTP services for the cached endpoint. - Either the gRPC or HTTP service can be empty, but not both. - - deplot_endpoints : Tuple[str, str] - A tuple containing the gRPC and HTTP services for the deplot endpoint. - Either the gRPC or HTTP service can be empty, but not both. - - paddle_endpoints : Tuple[str, str] - A tuple containing the gRPC and HTTP services for the paddle endpoint. - Either the gRPC or HTTP service can be empty, but not both. - yolox_endpoints : Tuple[str, str] A tuple containing the gRPC and HTTP services for the yolox endpoint. Either the gRPC or HTTP service can be empty, but not both. - identify_nearby_objects : bool, default=False - A flag indicating whether to identify nearby objects during processing. 
-
     Methods
     -------
     validate_endpoints(values)
@@ -59,12 +44,8 @@ class PDFiumConfigSchema(BaseModel):
 
     auth_token: Optional[str] = None
 
-    cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-
-    identify_nearby_objects: bool = False
+    yolox_infer_protocol: str = ""
 
     @root_validator(pre=True)
     def validate_endpoints(cls, values):
@@ -93,7 +74,8 @@ def clean_service(service):
             return None
         return service
 
-        for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints", "yolox_endpoints"]:
+        for model_name in ["yolox"]:
+            endpoint_name = f"{model_name}_endpoints"
             grpc_service, http_service = values.get(endpoint_name)
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
@@ -103,6 +85,13 @@ def clean_service(service):
 
             values[endpoint_name] = (grpc_service, http_service)
 
+            protocol_name = f"{model_name}_infer_protocol"
+            protocol_value = values.get(protocol_name)
+            if not protocol_value:
+                protocol_value = "http" if http_service else "grpc" if grpc_service else ""
+            protocol_value = protocol_value.lower()
+            values[protocol_name] = protocol_value
+
         return values
 
     class Config:
diff --git a/src/nv_ingest/schemas/table_extractor_schema.py b/src/nv_ingest/schemas/table_extractor_schema.py
new file mode 100644
index 00000000..043a055e
--- /dev/null
+++ b/src/nv_ingest/schemas/table_extractor_schema.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+
+from typing import Optional, Tuple
+from pydantic import BaseModel, root_validator, validator
+
+logger = logging.getLogger(__name__)
+
+
+class TableExtractorConfigSchema(BaseModel):
+    """
+    Configuration schema for the table extraction stage settings.
+
+    Parameters
+    ----------
+    auth_token : Optional[str], default=None
+        Authentication token required for secure services.
+
+    paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the paddle endpoint.
+        Either the gRPC or HTTP service can be empty, but not both.
+
+    Methods
+    -------
+    validate_endpoints(values)
+        Validates that at least one of the gRPC or HTTP services is provided for the paddle endpoint.
+
+    Raises
+    ------
+    ValueError
+        If both gRPC and HTTP services are empty for the paddle endpoint.
+
+    Config
+    ------
+    extra : str
+        Pydantic config option to forbid extra fields.
+    """
+
+    auth_token: Optional[str] = None
+
+    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    paddle_infer_protocol: str = ""
+
+    @root_validator(pre=True)
+    def validate_endpoints(cls, values):
+        """
+        Validates the gRPC and HTTP services for the paddle endpoint.
+
+        Parameters
+        ----------
+        values : dict
+            Dictionary containing the values of the attributes for the class.
+
+        Returns
+        -------
+        dict
+            The validated dictionary of values.
+
+        Raises
+        ------
+        ValueError
+            If both gRPC and HTTP services are empty for the paddle endpoint.
+        """
+
+        def clean_service(service):
+            """Set service to None if it's an empty string or contains only spaces or quotes."""
+            if service is None or not service.strip() or service.strip(" \"'") == "":
+                return None
+            return service
+
+        grpc_service, http_service = values.get("paddle_endpoints", (None, None))
+        grpc_service = clean_service(grpc_service)
+        http_service = clean_service(http_service)
+
+        if not grpc_service and not http_service:
+            raise ValueError("Both gRPC and HTTP services cannot be empty for paddle_endpoints.")
+
+        values["paddle_endpoints"] = (grpc_service, http_service)
+
+        return values
+
+    class Config:
+        extra = "forbid"
+
+
+class TableExtractorSchema(BaseModel):
+    """
+    Configuration schema for the table extraction processing settings.
+
+    Parameters
+    ----------
+    max_queue_size : int, default=1
+        The maximum number of items allowed in the processing queue.
+
+    n_workers : int, default=2
+        The number of worker threads to use for processing.
+
+    raise_on_failure : bool, default=False
+        A flag indicating whether to raise an exception if a failure occurs during table extraction.
+
+    stage_config : Optional[TableExtractorConfigSchema], default=None
+        Configuration for the table extraction stage, including paddle service endpoints.
+    """
+
+    max_queue_size: int = 1
+    n_workers: int = 2
+    raise_on_failure: bool = False
+
+    @validator("max_queue_size", "n_workers", pre=True, always=True)
+    def check_positive(cls, v, field):
+        if v <= 0:
+            raise ValueError(f"{field.name} must be greater than 0.")
+        return v
+
+    stage_config: Optional[TableExtractorConfigSchema] = None
+
+    class Config:
+        extra = "forbid"
diff --git a/src/nv_ingest/schemas/vdb_task_sink_schema.py b/src/nv_ingest/schemas/vdb_task_sink_schema.py
index c5faadee..231c62bb 100644
--- a/src/nv_ingest/schemas/vdb_task_sink_schema.py
+++ b/src/nv_ingest/schemas/vdb_task_sink_schema.py
@@ -37,10 +37,11 @@ def build_default_milvus_config(embedding_size: int = 1024) -> typing.Dict[str,
         "index_conf": {
             "field_name": "vector",
             "metric_type": "L2",
-            "index_type": "HNSW",
+            "index_type": "GPU_CAGRA",
             "params": {
-                "M": 8,
-                "efConstruction": 64,
+                "intermediate_graph_degree": 128,
+                "graph_degree": 64,
+                "build_algo": "NN_DESCENT",
             },
         },
         "schema_conf": {
@@ -67,6 +68,11 @@ def build_default_milvus_config(embedding_size: int = 1024) -> typing.Dict[str,
                 dtype=pymilvus.DataType.JSON,
                 description="Source document and raw data extracted content",
             ).to_dict(),
+            pymilvus.FieldSchema(
+                name="content_metadata",
+                dtype=pymilvus.DataType.JSON,
+                description="Content metadata",
+            ).to_dict(),
         ],
         "description": "NV-INGEST collection schema",
     },
diff --git a/src/nv_ingest/stages/nim/__init__.py b/src/nv_ingest/stages/nim/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py
new file mode 100644
index 00000000..46339228
--- /dev/null
+++ b/src/nv_ingest/stages/nim/chart_extraction.py
@@ -0,0 +1,191 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import functools +import pandas as pd +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +import tritonclient.grpc as grpcclient +from morpheus.config import Config + +from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema +from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client + +logger = logging.getLogger(f"morpheus.{__name__}") + + +def _update_metadata(row: pd.Series, cached_client: Any, deplot_client: Any, trace_info: Dict) -> Dict: + """ + Modifies the metadata of a row if the conditions for chart extraction are met. + + Parameters + ---------- + row : pd.Series + A row from the DataFrame containing metadata for the chart extraction. + + cached_client : Any + The client used to call the cached inference model. + + deplot_client : Any + The client used to call the deplot inference model. + + trace_info : Dict + Trace information used for logging or debugging. + + Returns + ------- + Dict + The modified metadata if conditions are met, otherwise the original metadata. + + Raises + ------ + ValueError + If critical information (such as metadata) is missing from the row. + """ + metadata = row.get("metadata") + if metadata is None: + logger.error("Row does not contain 'metadata'.") + raise ValueError("Row does not contain 'metadata'.") + + base64_image = metadata.get("content") + content_metadata = metadata.get("content_metadata", {}) + chart_metadata = metadata.get("table_metadata") + + # Only modify if content type is structured and subtype is 'chart' and chart_metadata exists + if ((content_metadata.get("type") != "structured") or + (content_metadata.get("subtype") != "chart") or + (chart_metadata is None)): + return metadata + + # Modify chart metadata with the result from the inference model + try: + image_array = base64_to_numpy(base64_image) + + deplot_result = call_image_inference_model(deplot_client, "deplot", image_array, trace_info=trace_info) + cached_result = call_image_inference_model(cached_client, "cached", image_array, trace_info=trace_info) + chart_content = join_cached_and_deplot_output(cached_result, deplot_result) + + chart_metadata["table_content"] = chart_content + except Exception as e: + logger.error(f"Unhandled error calling image inference model: {e}", exc_info=True) + raise + + return metadata + + +def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any], + validated_config: Any, trace_info: Optional[Dict] = None) -> Tuple[pd.DataFrame, Dict]: + """ + Extracts chart data from a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing the content from which chart data is to be extracted. + + task_props : Dict[str, Any] + Dictionary containing task properties and configurations. + + validated_config : Any + The validated configuration object for chart extraction. + + trace_info : Optional[Dict], optional + Optional trace information for debugging or logging. Defaults to None. + + Returns + ------- + Tuple[pd.DataFrame, Dict] + A tuple containing the updated DataFrame and the trace information. + + Raises + ------ + Exception + If any error occurs during the chart data extraction process. 
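+
+    Examples
+    --------
+    A hedged sketch with hypothetical values; ``validated_config`` is assumed to
+    be a validated ``ChartExtractorSchema``. Only rows whose ``content_metadata``
+    marks them as structured charts are updated:
+
+    >>> df = pd.DataFrame([{"metadata": {"content": "<base64 image>",
+    ...                                  "content_metadata": {"type": "structured", "subtype": "chart"},
+    ...                                  "table_metadata": {}}}])
+    >>> df, trace = _extract_chart_data(df, {}, validated_config)  # doctest: +SKIP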
+ """ + + _ = task_props # unused + + deplot_client = create_inference_client( + validated_config.stage_config.deplot_endpoints, + validated_config.stage_config.auth_token, + validated_config.stage_config.deplot_infer_protocol + ) + + cached_client = create_inference_client( + validated_config.stage_config.cached_endpoints, + validated_config.stage_config.auth_token, + validated_config.stage_config.cached_infer_protocol + ) + + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + try: + # Apply the _update_metadata function to each row in the DataFrame + df["metadata"] = df.apply(_update_metadata, axis=1, args=(cached_client, deplot_client, trace_info)) + + return df, trace_info + + except Exception as e: + logger.error("Error occurred while extracting chart data.", exc_info=True) + raise + finally: + if (isinstance(cached_client, grpcclient.InferenceServerClient)): + cached_client.close() + if (isinstance(deplot_client, grpcclient.InferenceServerClient)): + deplot_client.close() + + +def generate_chart_extractor_stage( + c: Config, + stage_config: Dict[str, Any], + task: str = "chart_data_extract", + task_desc: str = "chart_data_extraction", + pe_count: int = 1, +): + """ + Generates a multiprocessing stage to perform chart data extraction from PDF content. + + Parameters + ---------- + c : Config + Morpheus global configuration object. + + stage_config : Dict[str, Any] + Configuration parameters for the chart content extractor, passed as a dictionary + validated against the `ChartExtractorSchema`. + + task : str, optional + The task name for the stage worker function, defining the specific chart extraction process. + Default is "chart_data_extract". + + task_desc : str, optional + A descriptor used for latency tracing and logging during chart extraction. + Default is "chart_data_extraction". + + pe_count : int, optional + The number of process engines to use for chart data extraction. This value controls + how many worker processes will run concurrently. Default is 1. + + Returns + ------- + MultiProcessingBaseStage + A configured Morpheus stage with an applied worker function that handles chart data extraction + from PDF content. + """ + + validated_config = ChartExtractorSchema(**stage_config) + _wrapped_process_fn = functools.partial(_extract_chart_data, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn + ) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py new file mode 100644 index 00000000..9154e11d --- /dev/null +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import functools +import pandas as pd +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +import tritonclient.grpc as grpcclient +from morpheus.config import Config +from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema +from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.image_processing.transforms import check_numpy_image_size + +logger = logging.getLogger(f"morpheus.{__name__}") + +PADDLE_MIN_WIDTH = 32 +PADDLE_MIN_HEIGHT = 32 + + +def _update_metadata(row: pd.Series, paddle_client: Any, trace_info: Dict) -> Dict: + """ + Modifies the metadata of a row if the conditions for table extraction are met. + + Parameters + ---------- + row : pd.Series + A row from the DataFrame containing metadata for the table extraction. + + paddle_client : Any + The client used to call the image inference model. + + trace_info : Dict + Trace information used for logging or debugging. + + Returns + ------- + Dict + The modified metadata if conditions are met, otherwise the original metadata. + + Raises + ------ + ValueError + If critical information (such as metadata) is missing from the row. + """ + + metadata = row.get("metadata") + if metadata is None: + logger.error("Row does not contain 'metadata'.") + raise ValueError("Row does not contain 'metadata'.") + + base64_image = metadata.get("content") + content_metadata = metadata.get("content_metadata", {}) + table_metadata = metadata.get("table_metadata") + + # Only modify if content type is structured and subtype is 'table' and table_metadata exists + if ((content_metadata.get("type") != "structured") or + (content_metadata.get("subtype") != "table") or + (table_metadata is None)): + return metadata + + # Modify table metadata with the result from the inference model + try: + image_array = base64_to_numpy(base64_image) + paddle_result = "" + if check_numpy_image_size(image_array, PADDLE_MIN_WIDTH, PADDLE_MIN_HEIGHT): + paddle_result = call_image_inference_model(paddle_client, "paddle", image_array, trace_info=trace_info) + + table_metadata["table_content"] = paddle_result + except Exception as e: + logger.error(f"Unhandled error calling image inference model: {e}", exc_info=True) + raise + + return metadata + + +def _extract_table_data(df: pd.DataFrame, task_props: Dict[str, Any], + validated_config: Any, trace_info: Optional[Dict] = None) -> Tuple[pd.DataFrame, Dict]: + """ + Extracts table data from a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing the content from which table data is to be extracted. + + task_props : Dict[str, Any] + Dictionary containing task properties and configurations. + + validated_config : Any + The validated configuration object for table extraction. + + trace_info : Optional[Dict], optional + Optional trace information for debugging or logging. Defaults to None. + + Returns + ------- + Tuple[pd.DataFrame, Dict] + A tuple containing the updated DataFrame and the trace information. + + Raises + ------ + Exception + If any error occurs during the table data extraction process. 
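+
+    Examples
+    --------
+    A hedged sketch with hypothetical values; ``validated_config`` is assumed to
+    be a validated ``TableExtractorSchema`` with paddle endpoints configured:
+
+    >>> df = pd.DataFrame([{"metadata": {"content": "<base64 image>",
+    ...                                  "content_metadata": {"type": "structured", "subtype": "table"},
+    ...                                  "table_metadata": {}}}])
+    >>> df, trace = _extract_table_data(df, {}, validated_config)  # doctest: +SKIP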
+ """ + + _ = task_props # unused + + paddle_client = create_inference_client( + validated_config.stage_config.paddle_endpoints, + validated_config.stage_config.auth_token, + validated_config.stage_config.paddle_infer_protocol + ) + + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + try: + # Apply the _update_metadata function to each row in the DataFrame + df["metadata"] = df.apply(_update_metadata, axis=1, args=(paddle_client, trace_info)) + + return df, trace_info + + except Exception as e: + logger.error("Error occurred while extracting table data.", exc_info=True) + raise + finally: + if (isinstance(paddle_client, grpcclient.InferenceServerClient)): + paddle_client.close() + + +def generate_table_extractor_stage( + c: Config, + stage_config: Dict[str, Any], + task: str = "table_data_extract", + task_desc: str = "table_data_extraction", + pe_count: int = 1, +): + """ + Generates a multiprocessing stage to perform table data extraction from PDF content. + + Parameters + ---------- + c : Config + Morpheus global configuration object. + + stage_config : Dict[str, Any] + Configuration parameters for the table content extractor, passed as a dictionary + validated against the `TableExtractorSchema`. + + task : str, optional + The task name for the stage worker function, defining the specific table extraction process. + Default is "table_data_extract". + + task_desc : str, optional + A descriptor used for latency tracing and logging during table extraction. + Default is "table_data_extraction". + + pe_count : int, optional + The number of process engines to use for table data extraction. This value controls + how many worker processes will run concurrently. Default is 1. + + Returns + ------- + MultiProcessingBaseStage + A configured Morpheus stage with an applied worker function that handles table data extraction + from PDF content. 
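+
+    Examples
+    --------
+    A minimal wiring sketch; ``morpheus_config`` and ``pipe`` are assumed to exist
+    in the caller, and the endpoint value is hypothetical:
+
+    >>> stage_config = {"stage_config": {"paddle_endpoints": ("paddle:8001", None)}}
+    >>> stage = generate_table_extractor_stage(morpheus_config, stage_config, pe_count=2)  # doctest: +SKIP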
+ """ + + validated_config = TableExtractorSchema(**stage_config) + _wrapped_process_fn = functools.partial(_extract_table_data, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn + ) diff --git a/src/nv_ingest/stages/pdf_extractor_stage.py b/src/nv_ingest/stages/pdf_extractor_stage.py index 072fce49..c5aaef85 100644 --- a/src/nv_ingest/stages/pdf_extractor_stage.py +++ b/src/nv_ingest/stages/pdf_extractor_stage.py @@ -61,7 +61,7 @@ def decode_and_extract( try: base64_content = base64_row["content"] except KeyError: - log_error_message = f"NO CONTENT FOUND IN ROW:\n{base64_row}" + log_error_message = f"Unhandled error processing row, no content was found:\n{base64_row}" logger.error(log_error_message) raise diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py index 9e79c9d7..72e73934 100644 --- a/src/nv_ingest/util/image_processing/table_and_chart.py +++ b/src/nv_ingest/util/image_processing/table_and_chart.py @@ -43,7 +43,13 @@ def join_cached_and_deplot_output(cached_text, deplot_text): if (cached_text is not None): try: - cached_text_dict = json.loads(cached_text) + if isinstance(cached_text, str): + cached_text_dict = json.loads(cached_text) + elif isinstance(cached_text, dict): + cached_text_dict = cached_text + else: + cached_text_dict = {} + chart_content += cached_text_dict.get("chart_title", "") if (deplot_text is not None): diff --git a/src/nv_ingest/util/image_processing/transforms.py b/src/nv_ingest/util/image_processing/transforms.py index 6b0acb03..d441db72 100644 --- a/src/nv_ingest/util/image_processing/transforms.py +++ b/src/nv_ingest/util/image_processing/transforms.py @@ -2,6 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import base64 from io import BytesIO from math import ceil from math import floor @@ -10,6 +11,7 @@ import numpy as np from PIL import Image +from PIL import UnidentifiedImageError from nv_ingest.util.converters import bytetools @@ -18,7 +20,11 @@ def pad_image( - array: np.ndarray, target_width: int = DEFAULT_MAX_WIDTH, target_height: int = DEFAULT_MAX_HEIGHT + array: np.ndarray, + target_width: int = DEFAULT_MAX_WIDTH, + target_height: int = DEFAULT_MAX_HEIGHT, + background_color: int = 255, + dtype=np.uint8, ) -> Tuple[np.ndarray, Tuple[int, int]]: """ Pads a NumPy array representing an image to the specified target dimensions. @@ -68,11 +74,29 @@ def pad_image( final_width = max(width, target_width) # Create the canvas and place the original image on it - canvas = 255 * np.ones((final_height, final_width, array.shape[2]), dtype=np.uint8) + canvas = background_color * np.ones((final_height, final_width, array.shape[2]), dtype=dtype) canvas[pad_height : pad_height + height, pad_width : pad_width + width] = array # noqa: E203 return canvas, (pad_width, pad_height) +def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) -> bool: + """ + Checks if the height and width of the image are larger than the specified minimum values. + + Parameters: + image (np.ndarray): The image array (assumed to be in shape (H, W, C) or (H, W)). + min_height (int): The minimum height required. + min_width (int): The minimum width required. + + Returns: + bool: True if the image dimensions are larger than or equal to the minimum size, False otherwise. 
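+
+    Example (hypothetical arrays):
+        >>> check_numpy_image_size(np.zeros((64, 48, 3), dtype=np.uint8), 32, 32)
+        True
+        >>> check_numpy_image_size(np.zeros((16, 48, 3), dtype=np.uint8), 32, 32)
+        False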
+    """
+    # Check if the image has at least 2 dimensions
+    if image.ndim < 2:
+        raise ValueError("The input array does not have sufficient dimensions for an image.")
+
+    height, width = image.shape[:2]
+    return height >= min_height and width >= min_width
 
 
 def crop_image(
     array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1
@@ -113,6 +137,66 @@ def crop_image(
     return cropped
 
 
+def normalize_image(
+    array: np.ndarray,
+    r_mean: float = 0.485,
+    g_mean: float = 0.456,
+    b_mean: float = 0.406,
+    r_std: float = 0.229,
+    g_std: float = 0.224,
+    b_std: float = 0.225,
+) -> np.ndarray:
+    """
+    Normalizes an RGB image by applying a mean and standard deviation to each channel.
+
+    Parameters:
+    ----------
+    array : np.ndarray
+        The input image array, which can be either grayscale or RGB. The image should have a shape of
+        (height, width, 3) for RGB images, or (height, width) or (height, width, 1) for grayscale images.
+        If a grayscale image is provided, it will be converted to RGB format by repeating the grayscale values
+        across all three channels (R, G, B).
+    r_mean : float, optional
+        The mean to be subtracted from the red channel (default is 0.485).
+    g_mean : float, optional
+        The mean to be subtracted from the green channel (default is 0.456).
+    b_mean : float, optional
+        The mean to be subtracted from the blue channel (default is 0.406).
+    r_std : float, optional
+        The standard deviation to divide the red channel by (default is 0.229).
+    g_std : float, optional
+        The standard deviation to divide the green channel by (default is 0.224).
+    b_std : float, optional
+        The standard deviation to divide the blue channel by (default is 0.225).
+
+    Returns:
+    -------
+    np.ndarray
+        A normalized image array with the same shape as the input, where the RGB channels have been normalized
+        by the given means and standard deviations.
+
+    Notes:
+    -----
+    The input pixel values should be in the range [0, 255], and the function scales these values to [0, 1]
+    before applying normalization.
+
+    If the input image is grayscale, it is converted to an RGB image by duplicating the grayscale values
+    across the three color channels.
+    """
+    # If the input is a grayscale image with shape (height, width) or (height, width, 1),
+    # convert it to RGB with shape (height, width, 3) by repeating the single channel,
+    # matching the behavior described in the docstring.
+    if array.ndim == 2 or array.shape[2] == 1:
+        array = np.dstack((array, array, array))
+
+    mean = np.array([r_mean, g_mean, b_mean]).reshape((1, 1, 3)).astype(np.float32)
+    std = np.array([r_std, g_std, b_std]).reshape((1, 1, 3)).astype(np.float32)
+    output_array = (array.astype("float32") / 255.0 - mean) / std
+
+    return output_array
+
+
 def numpy_to_base64(array: np.ndarray) -> str:
     """
     Converts a NumPy array representing an image to a base64-encoded string.
@@ -168,3 +252,51 @@ def numpy_to_base64(array: np.ndarray) -> str:
         raise RuntimeError(f"Failed to encode image to base64: {e}")
 
     return base64_img
+
+
+def base64_to_numpy(base64_string: str) -> np.ndarray:
+    """
+    Convert a base64-encoded image string to a NumPy array.
+
+    Parameters
+    ----------
+    base64_string : str
+        Base64-encoded string representing an image.
+
+    Returns
+    -------
+    numpy.ndarray
+        NumPy array representation of the decoded image.
+
+    Raises
+    ------
+    ValueError
+        If the base64 string is invalid or cannot be decoded into an image.
+    ImportError
+        If required libraries are not installed.
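+
+    Notes
+    -----
+    The shape of the returned array depends on the encoded image: RGB inputs
+    decode to ``(H, W, 3)``, single-channel images to ``(H, W)``.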
+
+    Examples
+    --------
+    >>> base64_str = '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBD...'
+    >>> img_array = base64_to_numpy(base64_str)
+    """
+    try:
+        # Decode the base64 string
+        image_data = base64.b64decode(base64_string)
+    except (base64.binascii.Error, ValueError) as e:
+        raise ValueError("Invalid base64 string") from e
+
+    try:
+        # Convert the bytes into a BytesIO object
+        image_bytes = BytesIO(image_data)
+
+        # Open the image using PIL
+        image = Image.open(image_bytes)
+        image.load()
+    except UnidentifiedImageError as e:
+        raise ValueError("Unable to decode image from base64 string") from e
+
+    # Convert the image to a NumPy array
+    image_array = np.array(image)
+
+    return image_array
diff --git a/src/nv_ingest/util/nim/decorators.py b/src/nv_ingest/util/nim/decorators.py
new file mode 100644
index 00000000..869c9e54
--- /dev/null
+++ b/src/nv_ingest/util/nim/decorators.py
@@ -0,0 +1,52 @@
+import logging
+from functools import wraps
+from multiprocessing import Lock
+from multiprocessing import Manager
+
+logger = logging.getLogger(__name__)
+
+# Create a shared manager and lock for thread-safe access
+manager = Manager()
+global_cache = manager.dict()
+lock = Lock()
+
+
+def multiprocessing_cache(max_calls):
+    """
+    A decorator that creates a global cache shared between multiple processes.
+    The cache is invalidated after `max_calls` number of accesses.
+
+    Args:
+        max_calls (int): The number of calls after which the cache is cleared.
+
+    Returns:
+        function: The decorated function with global cache and invalidation logic.
+    """
+
+    def decorator(func):
+        call_count = manager.Value("i", 0)  # Shared integer for call counting
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            key = (func.__name__, args, frozenset(kwargs.items()))
+
+            with lock:
+                call_count.value += 1
+
+                if call_count.value > max_calls:
+                    global_cache.clear()
+                    call_count.value = 0
+
+                if key in global_cache:
+                    return global_cache[key]
+
+            result = func(*args, **kwargs)
+
+            with lock:
+                global_cache[key] = result
+
+            return result
+
+        return wrapper
+
+    return decorator
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index 80af504c..99c4a506 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -3,21 +3,37 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+import re
+from typing import Any
+from typing import Dict
 from typing import Optional
 from typing import Tuple
 
+import backoff
+import cv2
 import numpy as np
-import re
+import packaging.version
 import requests
 import tritonclient.grpc as grpcclient
 
+from nv_ingest.util.image_processing.transforms import normalize_image
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.image_processing.transforms import pad_image
+from nv_ingest.util.nim.decorators import multiprocessing_cache
 from nv_ingest.util.tracing.tagging import traceable_func
 
 logger = logging.getLogger(__name__)
 
+DEPLOT_MAX_TOKENS = 128
+DEPLOT_TEMPERATURE = 1.0
+DEPLOT_TOP_P = 1.0
+
 
-def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str]):
+def create_inference_client(
+    endpoints: Tuple[str, str],
+    auth_token: Optional[str] = None,
+    infer_protocol: Optional[str] = None,
+):
     """
     Creates an inference client based on the provided endpoints.
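+
+    A hedged usage sketch (hypothetical endpoint values). With ``infer_protocol="http"``
+    a dict of endpoint URL and headers is returned; with "grpc", a
+    ``grpcclient.InferenceServerClient``:
+
+    >>> client = create_inference_client(("", "paddle:8000"), None, "http")  # doctest: +SKIP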
@@ -35,18 +51,26 @@ def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str
     -------
     grpcclient.InferenceServerClient or dict
         A gRPC client if the gRPC endpoint is provided, otherwise a dictionary containing the HTTP client details.
+
+    Notes
+    -----
+    `infer_protocol` selects "grpc" or "http" explicitly; when it is None, the
+    protocol is inferred from which endpoint is configured.
     """
-    if endpoints[0] and endpoints[0].strip():
-        logger.debug(f"Creating gRPC client with {endpoints}")
-        return grpcclient.InferenceServerClient(url=endpoints[0])
-    else:
-        logger.debug(f"Creating HTTP client with {endpoints}")
+    grpc_endpoint, http_endpoint = endpoints
+
+    if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()):
+        infer_protocol = "grpc"
+
+    if infer_protocol == "grpc":
+        logger.debug(f"Creating gRPC client with {grpc_endpoint}")
+        return grpcclient.InferenceServerClient(url=grpc_endpoint)
+    elif infer_protocol == "http":
+        url = generate_url(http_endpoint)
+
+        logger.debug(f"Creating HTTP client with {http_endpoint}")
         headers = {"accept": "application/json", "content-type": "application/json"}
         if auth_token:
             headers["Authorization"] = f"Bearer {auth_token}"
-        return {"endpoint_url": endpoints[1], "headers": headers}
+        return {"endpoint_url": url, "headers": headers}
 
 
 @traceable_func(trace_name="pdf_content_extractor::{model_name}")
@@ -76,62 +100,113 @@ def call_image_inference_model(client, model_name: str, image_data):
         If the HTTP request fails or if the response format is not as expected.
     """
     if isinstance(client, grpcclient.InferenceServerClient):
-        if image_data.ndim == 3:
-            image_data = np.expand_dims(image_data, axis=0)
-        inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")]
-        inputs[0].set_data_from_numpy(image_data.astype(np.float32))
+        response = _call_image_inference_grpc_client(client, model_name, image_data)
+    else:
+        response = _call_image_inference_http_client(client, model_name, image_data)
+
+    return response
+
+
+def _call_image_inference_grpc_client(client, model_name: str, image_data):
+    if image_data.ndim == 3:
+        image_data = np.expand_dims(image_data, axis=0)
+    inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")]
+    inputs[0].set_data_from_numpy(image_data.astype(np.float32))
+
+    outputs = [grpcclient.InferRequestedOutput("output")]
+
+    try:
+        result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
+        return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")])
+    except Exception as e:
+        err_msg = f"Inference failed for model {model_name}: {str(e)}"
+        logger.error(err_msg)
+        raise RuntimeError(err_msg)
+
+
+def _call_image_inference_http_client(client, model_name: str, image_data):
+    base64_img = numpy_to_base64(image_data)
+
+    if model_name == "deplot":
+        payload = _prepare_deplot_payload(base64_img)
+    elif model_name in {"paddle", "cached", "yolox"}:
+        payload = _prepare_nim_payload(base64_img)
+    else:
+        raise ValueError(f"Model {model_name} is not supported.")
+
+    try:
+        url = client["endpoint_url"]
+        headers = client["headers"]
+
+        response = requests.post(url, json=payload, headers=headers)
+        response.raise_for_status()  # Raise an exception for HTTP errors
 
-        outputs = [grpcclient.InferRequestedOutput("output")]
+        # Parse the JSON response
+        json_response = response.json()
 
-        try:
-            result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
-            return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")])
-        except Exception as e:
-            err_msg = f"Inference failed for model {model_name}: {str(e)}"
-            logger.error(err_msg)
-            raise
RuntimeError(err_msg) + except requests.exceptions.RequestException as e: + raise RuntimeError(f"HTTP request failed: {e}") + except KeyError as e: + raise RuntimeError(f"Missing expected key in response: {e}") + except Exception as e: + raise RuntimeError(f"An error occurred during inference: {e}") + if model_name == "deplot": + result = _extract_content_from_deplot_response(json_response) else: - base64_img = numpy_to_base64(image_data) - - try: - url = client["endpoint_url"] - headers = client["headers"] - - messages = [ - { - "role": "user", - "content": f"Generate the underlying data table of the figure below: " - f'', - } - ] - payload = { - "model": model_name, - "messages": messages, - "max_tokens": 128, - "stream": False, - "temperature": 1.0, - "top_p": 1.0, - } - - response = requests.post(url, json=payload, headers=headers) - response.raise_for_status() # Raise an exception for HTTP errors - - # Parse the JSON response - json_response = response.json() - - # Validate the response structure - if "choices" not in json_response or not json_response["choices"]: - raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") - - return json_response["choices"][0]["message"]["content"] - - except requests.exceptions.RequestException as e: - raise RuntimeError(f"HTTP request failed: {e}") - except KeyError as e: - raise RuntimeError(f"Missing expected key in response: {e}") - except Exception as e: - raise RuntimeError(f"An error occurred during inference: {e}") + result = _extract_content_from_nim_response(json_response) + + return result + + +def _prepare_deplot_payload( + base64_img: str, + max_tokens: int = DEPLOT_MAX_TOKENS, + temperature: float = DEPLOT_TEMPERATURE, + top_p: float = DEPLOT_TOP_P, +) -> Dict[str, Any]: + messages = [ + { + "role": "user", + "content": f"Generate the underlying data table of the figure below: " + f'', + } + ] + payload = { + "model": "google/deplot", + "messages": messages, + "max_tokens": max_tokens, + "stream": False, + "temperature": temperature, + "top_p": top_p, + } + + return payload + + +def _prepare_nim_payload(base64_img: str) -> Dict[str, Any]: + image_url = f"data:image/png;base64,{base64_img}" + image = {"type": "image_url", "image_url": {"url": image_url}} + + message = {"content": [image]} + payload = {"messages": [message]} + + return payload + + +def _extract_content_from_deplot_response(json_response): + # Validate the response structure + if "choices" not in json_response or not json_response["choices"]: + raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") + + return json_response["choices"][0]["message"]["content"] + + +def _extract_content_from_nim_response(json_response): + if "data" not in json_response or not json_response["data"]: + raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") + + return json_response["data"][0]["content"] # Perform inference and return predictions @@ -172,6 +247,63 @@ def perform_model_inference(client, model_name: str, input_array: np.ndarray): return query_response.as_numpy("output") +def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] = None) -> np.ndarray: + """ + Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding, + and transposing it into the required format. + + This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC. + It is not necessary when using the HTTP endpoint. 
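+
+    For PaddleOCR NIM versions older than 0.2.0-rc1, or when ``paddle_version`` is
+    not provided, the image is returned unchanged (see the version check below).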
+ + Steps: + ----- + 1. Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels. + 2. Normalizes the image using the `normalize_image` function. + 3. Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR. + 4. Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR. + + Parameters: + ---------- + array : np.ndarray + The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255]. + + Returns: + ------- + np.ndarray + A preprocessed image with the shape (channels, height, width) and normalized pixel values. + The image will be padded to have dimensions that are multiples of 32, with the padding color set to 0. + + Notes: + ----- + - The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio. + - After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is + a requirement for PaddleOCR. + - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image. + """ + if (not paddle_version) or (packaging.version.parse(paddle_version) < packaging.version.parse("0.2.0-rc1")): + return array + + height, width = array.shape[:2] + scale_factor = 960 / max(height, width) + new_height = int(height * scale_factor) + new_width = int(width * scale_factor) + resized = cv2.resize(array, (new_width, new_height)) + + normalized = normalize_image(resized) + + # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32. + new_height = (normalized.shape[0] + 31) // 32 * 32 + new_width = (normalized.shape[1] + 31) // 32 * 32 + padded, _ = pad_image( + normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32 + ) + + # PaddleOCR NIM (GRPC) requires input to be (channel, height, width). + transposed = padded.transpose((2, 0, 1)) + + return transposed + + def remove_url_endpoints(url) -> str: """Some configurations provide the full endpoint in the URL. Ex: http://deplot:8000/v1/chat/completions. 
For hitting the
@@ -185,8 +317,8 @@ def remove_url_endpoints(url) -> str:
     Returns:
         str: URL with just the hostname:port portion remaining
     """
-    if '/v1' in url:
-        url = url.split('/v1')[0]
+    if "/v1" in url:
+        url = url.split("/v1")[0]
 
     return url
 
@@ -204,26 +336,28 @@ def generate_url(url) -> str:
     Returns:
         str: Fully validated URL
     """
-    if not re.match(r'^https?://', url):
+    if not re.match(r"^https?://", url):
        # Add the default `http://` if it's not already present in the URL
        url = f"http://{url}"
 
-    url = remove_url_endpoints(url)
-
     return url
 
 
 def is_ready(http_endpoint, ready_endpoint) -> bool:
     # If the url is empty or None that means the service was not configured
     # and is therefore automatically marked as "ready"
-    if http_endpoint is None or http_endpoint == '':
+    if http_endpoint is None or http_endpoint == "":
+        return True
+
+    # If the url is for build.nvidia.com, it is automatically assumed "ready"
+    if "ai.api.nvidia.com" in http_endpoint:
         return True
 
     url = generate_url(http_endpoint)
+    url = remove_url_endpoints(url)
 
-    if not ready_endpoint.startswith('/') and not url.endswith('/'):
-        ready_endpoint = '/' + ready_endpoint
+    if not ready_endpoint.startswith("/") and not url.endswith("/"):
+        ready_endpoint = "/" + ready_endpoint
 
     url = url + ready_endpoint
 
@@ -258,3 +392,46 @@ def is_ready(http_endpoint, ready_endpoint) -> bool:
         # Don't let anything squeeze by
         logger.warning(f"Exception: {ex}")
         return False
+
+
+@backoff.on_predicate(backoff.expo, max_value=5)
+@multiprocessing_cache(max_calls=100)
+def get_version(http_endpoint, metadata_endpoint="/v1/metadata", version_field="version") -> str:
+    if http_endpoint is None or http_endpoint == "":
+        return ""
+
+    url = generate_url(http_endpoint)
+    url = remove_url_endpoints(url)
+
+    if not metadata_endpoint.startswith("/") and not url.endswith("/"):
+        metadata_endpoint = "/" + metadata_endpoint
+
+    url = url + metadata_endpoint
+
+    # Call the metadata endpoint of the NIM
+    try:
+        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+        resp = requests.get(url, timeout=5)
+        if resp.status_code == 200:
+            return resp.json().get(version_field, "")
+        else:
+            # Any other code is confusing.
We should log it with a warning
+            # as it could be something that might hold up ready state
+            logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.json()}")
+            return ""
+    except requests.HTTPError as http_err:
+        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+        return ""
+    except requests.Timeout:
+        logger.warning(f"'{url}' request timed out")
+        return ""
+    except ConnectionError:
+        logger.warning(f"A connection error for '{url}' occurred")
+        return ""
+    except requests.RequestException as err:
+        logger.warning(f"An error occurred: {err} for '{url}'")
+        return ""
+    except Exception as ex:
+        # Don't let anything squeeze by
+        logger.warning(f"Exception: {ex}")
+        return ""
diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py
index 68fc39ba..851a931b 100644
--- a/src/nv_ingest/util/pdf/metadata_aggregators.py
+++ b/src/nv_ingest/util/pdf/metadata_aggregators.py
@@ -10,7 +10,6 @@
 from typing import Dict
 from typing import List
 from typing import Tuple
-from typing import Union
 
 import pandas as pd
 import pypdfium2 as pdfium
@@ -26,30 +25,23 @@
 from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler
 
 
+# TODO(Devin): Shift to this, since there is no difference between ImageTable and ImageChart
 @dataclass
-class DataFrameTable:
-    df: pd.DataFrame
-    bbox: Tuple[int, int, int, int]
-
-
-@dataclass
-class ImageTable:
-    content: str
-    image: str
-    bbox: Tuple[int, int, int, int]
-
-
-@dataclass
-class ImageChart:
+class CroppedImageWithContent:
     content: str
     image: str
     bbox: Tuple[int, int, int, int]
+    max_width: int
+    max_height: int
+    type_string: str
 
 
 @dataclass
 class LatexTable:
     latex: pd.DataFrame
     bbox: Tuple[int, int, int, int]
+    max_width: int
+    max_height: int
 
 
 @dataclass
@@ -58,6 +50,8 @@ class Base64Image:
     bbox: Tuple[int, int, int, int]
     width: int
     height: int
+    max_width: int
+    max_height: int
 
 
 @dataclass
@@ -137,16 +131,16 @@ def extract_pdf_metadata(doc: pdfium.PdfDocument, source_id: str) -> PDFMetadata
 
 
 def construct_text_metadata(
-        accumulated_text,
-        keywords,
-        page_idx,
-        block_idx,
-        line_idx,
-        span_idx,
-        page_count,
-        text_depth,
-        source_metadata,
-        base_unified_metadata,
+    accumulated_text,
+    keywords,
+    page_idx,
+    block_idx,
+    line_idx,
+    span_idx,
+    page_count,
+    text_depth,
+    source_metadata,
+    base_unified_metadata,
 ):
     extracted_text = " ".join(accumulated_text)
 
@@ -193,11 +187,11 @@ def construct_text_metadata(
 
 
 def construct_image_metadata(
-        image_base64: Base64Image,
-        page_idx: int,
-        page_count: int,
-        source_metadata: Dict[str, Any],
-        base_unified_metadata: Dict[str, Any],
+    image_base64: Base64Image,
+    page_idx: int,
+    page_count: int,
+    source_metadata: Dict[str, Any],
+    base_unified_metadata: Dict[str, Any],
 ) -> List[Any]:
     """
     Extracts image data from a PdfImage object, converts it to a base64-encoded string,
@@ -252,7 +246,7 @@ def construct_image_metadata(
         "caption": "",
         "text": "",
         "image_location": image_base64.bbox,
+        "image_location_max_dimensions": (max(image_base64.max_width, 0), max(image_base64.max_height, 0)),
         "width": image_base64.width,
         "height": image_base64.height,
     }
 
@@ -275,11 +269,11 @@
 # TODO(Devin): Disambiguate tables and charts, create two distinct processing methods
 @pdfium_exception_handler(descriptor="pdfium")
 def construct_table_and_chart_metadata(
-        table: Union[DataFrameTable, ImageTable, ImageChart],
-        page_idx: int,
-        page_count: int,
-        source_metadata: Dict,
-        base_unified_metadata: Dict,
+
structured_image: CroppedImageWithContent, + page_idx: int, + page_count: int, + source_metadata: Dict, + base_unified_metadata: Dict, ): """ +--------------------------------+--------------------------+------------+---+ @@ -309,29 +303,25 @@ def construct_table_and_chart_metadata( +--------------------------------+--------------------------+------------+---+ """ - if isinstance(table, DataFrameTable): - content = table.df.to_markdown(index=False) - structured_content_text = content - table_format = TableFormatEnum.MARKDOWN - subtype = ContentSubtypeEnum.TABLE - description = StdContentDescEnum.PDF_TABLE - - elif isinstance(table, ImageTable): - content = table.image - structured_content_text = table.content + if (structured_image.type_string in ("table",)): + content = structured_image.image + structured_content_text = structured_image.content table_format = TableFormatEnum.IMAGE subtype = ContentSubtypeEnum.TABLE description = StdContentDescEnum.PDF_TABLE + meta_name = "table_metadata" - elif isinstance(table, ImageChart): - content = table.image - structured_content_text = table.content + elif (structured_image.type_string in ("chart",)): + content = structured_image.image + structured_content_text = structured_image.content table_format = TableFormatEnum.IMAGE subtype = ContentSubtypeEnum.CHART description = StdContentDescEnum.PDF_CHART + # TODO(Devin) swap this to chart_metadata after we confirm metadata schema changes. + meta_name = "table_metadata" else: - raise ValueError("Unknown table/chart type.") + raise ValueError(f"Unknown table/chart type: {structured_image.type_string}") content_metadata = { "type": ContentTypeEnum.STRUCTURED, @@ -346,11 +336,12 @@ def construct_table_and_chart_metadata( "subtype": subtype, } - table_metadata = { + structured_metadata = { "caption": "", "table_format": table_format, "table_content": structured_content_text, - "table_location": table.bbox, + "table_location": structured_image.bbox, + "table_location_max_dimensions": (structured_image.max_width, structured_image.max_height), } ext_unified_metadata = base_unified_metadata.copy() @@ -360,7 +351,7 @@ def construct_table_and_chart_metadata( "content": content, "source_metadata": source_metadata, "content_metadata": content_metadata, - "table_metadata": table_metadata, + meta_name: structured_metadata, } ) diff --git a/src/nv_ingest/util/pipeline/__init__.py b/src/nv_ingest/util/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py new file mode 100644 index 00000000..ee8ae3fe --- /dev/null +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -0,0 +1,501 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +import math +import os +import logging +import typing + +import click +from morpheus.messages import ControlMessage +from morpheus.stages.general.linear_modules_source import LinearModuleSourceStage +from morpheus.stages.general.linear_modules_stage import LinearModulesStage + +from nv_ingest.modules.injectors.metadata_injector import MetadataInjectorLoaderFactory +from nv_ingest.modules.sinks.redis_task_sink import RedisTaskSinkLoaderFactory +from nv_ingest.modules.sinks.vdb_task_sink import VDBTaskSinkLoaderFactory +from nv_ingest.modules.sources.redis_task_source import RedisTaskSourceLoaderFactory +from nv_ingest.modules.telemetry.job_counter import JobCounterLoaderFactory +from nv_ingest.modules.telemetry.otel_meter import OpenTelemetryMeterLoaderFactory +from nv_ingest.modules.telemetry.otel_tracer import OpenTelemetryTracerLoaderFactory +from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory +from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory +from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage +from nv_ingest.stages.filters import generate_dedup_stage +from nv_ingest.stages.filters import generate_image_filter_stage +from nv_ingest.stages.nim.chart_extraction import generate_chart_extractor_stage +from nv_ingest.stages.nim.table_extraction import generate_table_extractor_stage +from nv_ingest.stages.pdf_extractor_stage import generate_pdf_extractor_stage +from nv_ingest.stages.pptx_extractor_stage import generate_pptx_extractor_stage +from nv_ingest.stages.storages.image_storage_stage import ImageStorageStage +from nv_ingest.stages.transforms.image_caption_extraction import generate_caption_extraction_stage + +logger = logging.getLogger(__name__) + + +def validate_positive(ctx, param, value): + if value <= 0: + raise click.BadParameter("must be a positive integer") + return value + + +def get_message_provider_config(): + message_provider_host = os.environ.get("MESSAGE_CLIENT_HOST", "localhost") + message_provider_port = os.environ.get("MESSAGE_CLIENT_PORT", "6379") + + logger.info(f"MESSAGE_CLIENT_HOST: {message_provider_host}") + logger.info(f"MESSAGE_CLIENT_PORT: {message_provider_port}") + + return message_provider_host, message_provider_port + + +def get_caption_classifier_service(): + triton_service_caption_classifier = os.environ.get( + "CAPTION_CLASSIFIER_GRPC_TRITON", + "", + ) + triton_service_caption_classifier_name = os.environ.get( + "CAPTION_CLASSIFIER_MODEL_NAME", + "", + ) + + logger.info(f"CAPTION_CLASSIFIER_GRPC_TRITON: {triton_service_caption_classifier}") + + return triton_service_caption_classifier, triton_service_caption_classifier_name + + +def get_table_detection_service(env_var_prefix): + prefix = env_var_prefix.upper() + grpc_endpoint = os.environ.get( + f"{prefix}_GRPC_ENDPOINT", + "", + ) + http_endpoint = os.environ.get( + f"{prefix}_HTTP_ENDPOINT", + "", + ) + auth_token = os.environ.get( + "NVIDIA_BUILD_API_KEY", + "", + ) or os.environ.get( + "NGC_API_KEY", + "", + ) + infer_protocol = os.environ.get( + f"{prefix}_INFER_PROTOCOL", + "http" if http_endpoint else "grpc" if grpc_endpoint else "", + ) + + logger.info(f"{prefix}_GRPC_TRITON: {grpc_endpoint}") + logger.info(f"{prefix}_HTTP_TRITON: {http_endpoint}") + logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}") + + return grpc_endpoint, http_endpoint, auth_token, infer_protocol + + +def get_default_cpu_count(): + default_cpu_count = 
os.environ.get("NV_INGEST_MAX_UTIL", int(max(1, math.floor(len(os.sched_getaffinity(0)))))) + + return default_cpu_count + + +def add_source_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): + source_module_loader = RedisTaskSourceLoaderFactory.get_instance( + module_name="redis_listener", + module_config=ingest_config.get( + "redis_task_source", + { + "redis_client": { + "host": message_provider_host, + "port": message_provider_port, + } + }, + ), + ) + source_stage = pipe.add_stage( + LinearModuleSourceStage( + morpheus_pipeline_config, + source_module_loader, + output_type=ControlMessage, + output_port_name="output", + ) + ) + + return source_stage + + +def add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): + submitted_job_counter_loader = JobCounterLoaderFactory.get_instance( + module_name="submitted_job_counter", + module_config=ingest_config.get( + "submitted_job_counter_module", + { + "name": "submitted_jobs", + }, + ), + ) + submitted_job_counter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + submitted_job_counter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return submitted_job_counter_stage + + +def add_metadata_injector_stage(pipe, morpheus_pipeline_config): + metadata_injector_loader = MetadataInjectorLoaderFactory.get_instance( + module_name="metadata_injection", module_config={} + ) + metadata_injector_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + metadata_injector_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return metadata_injector_stage + + +def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") + pdf_content_extractor_config = ingest_config.get( + "pdf_content_extraction_module", + { + "pdfium_config": { + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, + "auth_token": yolox_auth, # All auth tokens are the same for the moment + } + }, + ) + pdf_extractor_stage = pipe.add_stage( + generate_pdf_extractor_stage( + morpheus_pipeline_config, + pdf_content_extractor_config, + pe_count=8, + task="extract", + task_desc="pdf_content_extractor", + ) + ) + + return pdf_extractor_stage + + +def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + _, _, yolox_auth, _ = get_table_detection_service("yolox") + paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") + table_content_extractor_config = ingest_config.get("table_content_extraction_module", + { + "stage_config": { + "paddle_endpoints": (paddle_grpc, paddle_http), + "paddle_infer_protocol": paddle_protocol, + "auth_token": yolox_auth, + } + }) + + table_extractor_stage = pipe.add_stage( + generate_table_extractor_stage( + morpheus_pipeline_config, + table_content_extractor_config, + pe_count=5 + ) + ) + + return table_extractor_stage + + +def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + _, _, yolox_auth, _ = get_table_detection_service("yolox") + + deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot") + cached_grpc, cached_http, cached_auth, cached_protocol = 
get_table_detection_service("cached") + # NOTE: Paddle isn't currently used directly by the chart extraction stage, but will be in the future. + paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") + table_content_extractor_config = ingest_config.get("table_content_extraction_module", + { + "stage_config": { + "cached_endpoints": (cached_grpc, cached_http), + "cached_infer_protocol": cached_protocol, + "deplot_endpoints": (deplot_grpc, deplot_http), + "deplot_infer_protocol": deplot_protocol, + "paddle_endpoints": (paddle_grpc, paddle_http), + "paddle_infer_protocol": paddle_protocol, + "auth_token": yolox_auth, + } + }) + + table_extractor_stage = pipe.add_stage( + generate_chart_extractor_stage( + morpheus_pipeline_config, + table_content_extractor_config, + pe_count=5 + ) + ) + + return table_extractor_stage + + +def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): + docx_extractor_stage = pipe.add_stage( + generate_docx_extractor_stage( + morpheus_pipeline_config, + pe_count=1, + task="extract", + task_desc="docx_content_extractor", + ) + ) + return docx_extractor_stage + + +def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): + pptx_extractor_stage = pipe.add_stage( + generate_pptx_extractor_stage( + morpheus_pipeline_config, + pe_count=1, + task="extract", + task_desc="pptx_content_extractor", + ) + ) + return pptx_extractor_stage + + +def add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + image_dedup_config = ingest_config.get("dedup_module", {}) + image_dedup_stage = pipe.add_stage( + generate_dedup_stage( + morpheus_pipeline_config, + image_dedup_config, + pe_count=2, + task="dedup", + task_desc="dedup_images", + ) + ) + return image_dedup_stage + + +def add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + image_filter_config = ingest_config.get("image_filter", {}) + image_filter_stage = pipe.add_stage( + generate_image_filter_stage( + morpheus_pipeline_config, + image_filter_config, + pe_count=2, + task="filter", + task_desc="filter_images", + ) + ) + return image_filter_stage + + +def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): + nemo_splitter_loader = NemoDocSplitterLoaderFactory.get_instance( + module_name="nemo_doc_splitter", + module_config=ingest_config.get("text_splitting_module", {}), + ) + nemo_splitter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + nemo_splitter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return nemo_splitter_stage + + +def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + endpoint_url, model_name = get_caption_classifier_service() + image_caption_config = ingest_config.get( + "image_caption_extraction_module", + { + "caption_classifier_model_name": model_name, + "endpoint_url": endpoint_url, + }, + ) + image_caption_stage = pipe.add_stage( + generate_caption_extraction_stage( + morpheus_pipeline_config, + image_caption_config, + pe_count=2, + task="caption", + task_desc="caption_ext", + ) + ) + + return image_caption_stage + + +def add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config): + api_key = os.getenv("NGC_API_KEY", "ngc_api_key") + embedding_nim_endpoint = os.getenv("EMBEDDING_NIM_ENDPOINT", "http://embedding:8000/v1") + + embed_extractions_loader = 
EmbedExtractionsLoaderFactory.get_instance( + module_name="embed_extractions", + module_config=ingest_config.get( + "embed_extractions_module", {"api_key": api_key, "embedding_nim_endpoint": embedding_nim_endpoint} + ), + ) + embed_extractions_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + embed_extractions_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return embed_extractions_stage + + +def add_image_storage_stage(pipe, morpheus_pipeline_config): + image_storage_stage = pipe.add_stage(ImageStorageStage(morpheus_pipeline_config)) + + return image_storage_stage + + +def add_sink_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): + sink_module_loader = RedisTaskSinkLoaderFactory.get_instance( + module_name="redis_task_sink", + module_config=ingest_config.get( + "redis_task_sink", + { + "redis_client": { + "host": message_provider_host, + "port": message_provider_port, + } + }, + ), + ) + sink_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + sink_module_loader, + input_type=typing.Any, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return sink_stage + + +def add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config): + endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + + otel_tracer_loader = OpenTelemetryTracerLoaderFactory.get_instance( + module_name="otel_tracer", + module_config=ingest_config.get( + "otel_tracer_module", + { + "otel_endpoint": endpoint, + }, + ), + ) + otel_tracer_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + otel_tracer_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return otel_tracer_stage + + +def add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): + endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + + otel_meter_loader = OpenTelemetryMeterLoaderFactory.get_instance( + module_name="otel_meter", + module_config=ingest_config.get( + "otel_meter_module", + { + "redis_client": { + "host": message_provider_host, + "port": message_provider_port, + }, + "otel_endpoint": endpoint, + }, + ), + ) + otel_meter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + otel_meter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + + return otel_meter_stage + + +def add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): + completed_job_counter_loader = JobCounterLoaderFactory.get_instance( + module_name="completed_job_counter", + module_config=ingest_config.get( + "completed_job_counter_module", + { + "name": "completed_jobs", + }, + ), + ) + completed_job_counter_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + completed_job_counter_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return completed_job_counter_stage + + +def add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config): + milvus_endpoint = os.getenv("MILVUS_ENDPOINT", "http://milvus:19530") + + vdb_task_sink_loader = VDBTaskSinkLoaderFactory.get_instance( + module_name="vdb_task_sink", + 
module_config=ingest_config.get( + "vdb_task_sink_module", + { + "service_kwargs": { + "uri": milvus_endpoint, + } + }, + ), + ) + vdb_task_sink_stage = pipe.add_stage( + LinearModulesStage( + morpheus_pipeline_config, + vdb_task_sink_loader, + input_type=ControlMessage, + output_type=ControlMessage, + input_port_name="input", + output_port_name="output", + ) + ) + return vdb_task_sink_stage diff --git a/src/pipeline.py b/src/pipeline.py index 03e7d857..4d8289d7 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -4,572 +4,90 @@ import json -import logging -import math -import os -import typing from datetime import datetime import click from morpheus.config import Config from morpheus.config import CppConfig from morpheus.config import PipelineModes -from morpheus.messages import ControlMessage from morpheus.pipeline.pipeline import Pipeline -from morpheus.stages.general.linear_modules_source import LinearModuleSourceStage -from morpheus.stages.general.linear_modules_stage import LinearModulesStage from morpheus.utils.logger import configure_logging from pydantic import ValidationError -from nv_ingest.modules.injectors.metadata_injector import MetadataInjectorLoaderFactory -from nv_ingest.modules.sinks.redis_task_sink import RedisTaskSinkLoaderFactory -from nv_ingest.modules.sinks.vdb_task_sink import VDBTaskSinkLoaderFactory -from nv_ingest.modules.sources.redis_task_source import RedisTaskSourceLoaderFactory -from nv_ingest.modules.telemetry.job_counter import JobCounterLoaderFactory -from nv_ingest.modules.telemetry.otel_meter import OpenTelemetryMeterLoaderFactory -from nv_ingest.modules.telemetry.otel_tracer import OpenTelemetryTracerLoaderFactory -from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory -from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory from nv_ingest.schemas.ingest_pipeline_config_schema import IngestPipelineConfigSchema -from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage -from nv_ingest.stages.filters import generate_dedup_stage -from nv_ingest.stages.filters import generate_image_filter_stage -from nv_ingest.stages.pdf_extractor_stage import generate_pdf_extractor_stage -from nv_ingest.stages.pptx_extractor_stage import generate_pptx_extractor_stage -from nv_ingest.stages.storages.image_storage_stage import ImageStorageStage -from nv_ingest.stages.transforms.image_caption_extraction import generate_caption_extraction_stage from nv_ingest.util.converters.containers import merge_dict from nv_ingest.util.logging.configuration import LogLevel from nv_ingest.util.logging.configuration import configure_logging as configure_local_logging from nv_ingest.util.schema.schema_validator import validate_schema +from nv_ingest.util.pipeline.stage_builders import * logger = logging.getLogger(__name__) local_log_level = os.getenv("INGEST_LOG_LEVEL", "INFO") -if (local_log_level in ("DEFAULT")): +if (local_log_level in ("DEFAULT",)): local_log_level = "INFO" configure_local_logging(logger, local_log_level) -def validate_positive(ctx, param, value): - if value <= 0: - raise click.BadParameter("must be a positive integer") - return value - - -def get_message_provider_config(): - message_provider_host = os.environ.get("MESSAGE_CLIENT_HOST", "localhost") - message_provider_port = os.environ.get("MESSAGE_CLIENT_PORT", "6379") - - logger.info(f"MESSAGE_CLIENT_HOST: {message_provider_host}") - logger.info(f"MESSAGE_CLIENT_PORT: {message_provider_port}") - - return message_provider_host, 
message_provider_port - - -def get_caption_classifier_service(): - triton_service_caption_classifier = os.environ.get( - "CAPTION_CLASSIFIER_GRPC_TRITON", - "", - ) - triton_service_caption_classifier_name = os.environ.get( - "CAPTION_CLASSIFIER_MODEL_NAME", - "", - ) - - logger.info(f"CAPTION_CLASSIFIER_GRPC_TRITON: {triton_service_caption_classifier}") - - return triton_service_caption_classifier, triton_service_caption_classifier_name - - -def get_yolox_service_table_detection(): - grpc_endpoint = os.environ.get( - "TABLE_DETECTION_GRPC_TRITON", - "", - ) - http_endpoint = os.environ.get( - "TABLE_DETECTION_HTTP_TRITON", - "", - ) - auth_token = os.environ.get( - "NVIDIA_BUILD_API_KEY", - "", - ) or os.environ.get( - "NGC_API_KEY", - "", - ) - - logger.info(f"TABLE_DETECTION_GRPC_TRITON: {grpc_endpoint}") - logger.info(f"TABLE_DETECTION_HTTP_TRITON: {http_endpoint}") - - return grpc_endpoint, http_endpoint, auth_token - - -def get_paddle_service_table_detection(): - grpc_endpoint = os.environ.get( - "PADDLE_GRPC_ENDPOINT", - "", - ) - http_endpoint = os.environ.get( - "PADDLE_HTTP_ENDPOINT", - "", - ) - auth_token = os.environ.get( - "NVIDIA_BUILD_API_KEY", - "", - ) or os.environ.get( - "NGC_API_KEY", - "", - ) - - logger.info(f"PADDLE_GRPC_ENDPOINT: {grpc_endpoint}") - logger.info(f"PADDLE_HTTP_ENDPOINT: {http_endpoint}") - - return grpc_endpoint, http_endpoint, auth_token - - -def get_deplot_service_table_detection(): - grpc_endpoint = os.environ.get( - "DEPLOT_GRPC_ENDPOINT", - "", - ) - http_endpoint = os.environ.get( - "DEPLOT_HTTP_ENDPOINT", - "", - ) - auth_token = os.environ.get( - "NVIDIA_BUILD_API_KEY", - "", - ) or os.environ.get( - "NGC_API_KEY", - "", - ) - - logger.info(f"DEPLOT_GRPC_ENDPOINT: {grpc_endpoint}") - logger.info(f"DEPLOT_HTTP_ENDPOINT: {http_endpoint}") - - return grpc_endpoint, http_endpoint, auth_token - - -def get_cached_service_table_detection(): - grpc_endpoint = os.environ.get( - "CACHED_GRPC_ENDPOINT", - "", - ) - http_endpoint = os.environ.get( - "CACHED_HTTP_ENDPOINT", - "", - ) - auth_token = os.environ.get( - "NVIDIA_BUILD_API_KEY", - "", - ) or os.environ.get( - "NGC_API_KEY", - "", - ) - - logger.info(f"CACHED_GRPC_ENDPOINT: {grpc_endpoint}") - logger.info(f"CACHED_HTTP_ENDPOINT: {http_endpoint}") - - return grpc_endpoint, http_endpoint, auth_token - - -def get_default_cpu_count(): - default_cpu_count = os.environ.get("NV_INGEST_MAX_UTIL", int(max(1, math.floor(len(os.sched_getaffinity(0)))))) - - return default_cpu_count - - -def add_source_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - source_module_loader = RedisTaskSourceLoaderFactory.get_instance( - module_name="redis_listener", - module_config=ingest_config.get( - "redis_task_source", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - } - }, - ), - ) - source_stage = pipe.add_stage( - LinearModuleSourceStage( - morpheus_pipeline_config, - source_module_loader, - output_type=ControlMessage, - output_port_name="output", - ) - ) - - return source_stage - - -def add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config): - submitted_job_counter_loader = JobCounterLoaderFactory.get_instance( - module_name="submitted_job_counter", - module_config=ingest_config.get( - "submitted_job_counter_module", - { - "name": "submitted_jobs", - }, - ), - ) - submitted_job_counter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - submitted_job_counter_loader, - 
input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return submitted_job_counter_stage - - -def add_metadata_injector_stage(pipe, morpheus_pipeline_config): - metadata_injector_loader = MetadataInjectorLoaderFactory.get_instance( - module_name="metadata_injection", module_config={} - ) - metadata_injector_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - metadata_injector_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return metadata_injector_stage - - -def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - yolox_grpc, yolox_http, yolox_auth = get_yolox_service_table_detection() - paddle_grpc, paddle_http, paddle_auth = get_paddle_service_table_detection() - deplot_grpc, deplot_http, deplot_auth = get_deplot_service_table_detection() - cached_grpc, cached_http, cached_auth = get_cached_service_table_detection() - pdf_content_extractor_config = ingest_config.get( - "pdf_content_extraction_module", - { - "pdfium_config": { - "cached_endpoints": (cached_grpc, cached_http), - "deplot_endpoints": (deplot_grpc, deplot_http), - "paddle_endpoints": (paddle_grpc, paddle_http), - "yolox_endpoints": (yolox_grpc, yolox_http), - "auth_token": yolox_auth, # All auth tokens are the same for the moment - } - }, - ) - pdf_extractor_stage = pipe.add_stage( - generate_pdf_extractor_stage( - morpheus_pipeline_config, - pdf_content_extractor_config, - pe_count=8, - task="extract", - task_desc="pdf_content_extractor", - ) - ) - - return pdf_extractor_stage - - -def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): - docx_extractor_stage = pipe.add_stage( - generate_docx_extractor_stage( - morpheus_pipeline_config, - pe_count=1, - task="extract", - task_desc="docx_content_extractor", - ) - ) - return docx_extractor_stage - - -def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): - pptx_extractor_stage = pipe.add_stage( - generate_pptx_extractor_stage( - morpheus_pipeline_config, - pe_count=1, - task="extract", - task_desc="pptx_content_extractor", - ) - ) - return pptx_extractor_stage - - -def add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - image_dedup_config = ingest_config.get("dedup_module", {}) - image_dedup_stage = pipe.add_stage( - generate_dedup_stage( - morpheus_pipeline_config, - image_dedup_config, - pe_count=2, - task="dedup", - task_desc="dedup_images", - ) - ) - return image_dedup_stage - - -def add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - image_filter_config = ingest_config.get("image_filter", {}) - image_filter_stage = pipe.add_stage( - generate_image_filter_stage( - morpheus_pipeline_config, - image_filter_config, - pe_count=2, - task="filter", - task_desc="filter_images", - ) - ) - return image_filter_stage - - -def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): - nemo_splitter_loader = NemoDocSplitterLoaderFactory.get_instance( - module_name="nemo_doc_splitter", - module_config=ingest_config.get("text_splitting_module", {}), - ) - nemo_splitter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - nemo_splitter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return nemo_splitter_stage - - -def 
add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - endpoint_url, model_name = get_caption_classifier_service() - image_caption_config = ingest_config.get( - "image_caption_extraction_module", - { - "caption_classifier_model_name": model_name, - "endpoint_url": endpoint_url, - }, - ) - image_caption_stage = pipe.add_stage( - generate_caption_extraction_stage( - morpheus_pipeline_config, - image_caption_config, - pe_count=2, - task="caption", - task_desc="caption_ext", - ) - ) - - return image_caption_stage - - -def add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config): - api_key = os.getenv("NGC_API_KEY", "ngc_api_key") - embedding_nim_endpoint = os.getenv("EMBEDDING_NIM_ENDPOINT", "http://embedding:8000/v1") - - embed_extractions_loader = EmbedExtractionsLoaderFactory.get_instance( - module_name="embed_extractions", - module_config=ingest_config.get( - "embed_extractions_module", {"api_key": api_key, "embedding_nim_endpoint": embedding_nim_endpoint} - ), - ) - embed_extractions_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - embed_extractions_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return embed_extractions_stage - - -def add_image_storage_stage(pipe, morpheus_pipeline_config): - image_storage_stage = pipe.add_stage(ImageStorageStage(morpheus_pipeline_config)) - - return image_storage_stage - - -def add_sink_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - sink_module_loader = RedisTaskSinkLoaderFactory.get_instance( - module_name="redis_task_sink", - module_config=ingest_config.get( - "redis_task_sink", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - } - }, - ), - ) - sink_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - sink_module_loader, - input_type=typing.Any, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return sink_stage - - -def add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config): - endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") - - otel_tracer_loader = OpenTelemetryTracerLoaderFactory.get_instance( - module_name="otel_tracer", - module_config=ingest_config.get( - "otel_tracer_module", - { - "otel_endpoint": endpoint, - }, - ), - ) - otel_tracer_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - otel_tracer_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return otel_tracer_stage - - -def add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port): - endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") - - otel_meter_loader = OpenTelemetryMeterLoaderFactory.get_instance( - module_name="otel_meter", - module_config=ingest_config.get( - "otel_meter_module", - { - "redis_client": { - "host": message_provider_host, - "port": message_provider_port, - }, - "otel_endpoint": endpoint, - }, - ), - ) - otel_meter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - otel_meter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - - return otel_meter_stage - - -def add_completed_job_counter_stage(pipe, 
morpheus_pipeline_config, ingest_config): - completed_job_counter_loader = JobCounterLoaderFactory.get_instance( - module_name="completed_job_counter", - module_config=ingest_config.get( - "completed_job_counter_module", - { - "name": "completed_jobs", - }, - ), - ) - completed_job_counter_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - completed_job_counter_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return completed_job_counter_stage - - -def add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config): - milvus_endpoint = os.getenv("MILVUS_ENDPOINT", "http://milvus:19530") - - vdb_task_sink_loader = VDBTaskSinkLoaderFactory.get_instance( - module_name="vdb_task_sink", - module_config=ingest_config.get( - "vdb_task_sink_module", - { - "service_kwargs": { - "uri": milvus_endpoint, - } - }, - ), - ) - vdb_task_sink_stage = pipe.add_stage( - LinearModulesStage( - morpheus_pipeline_config, - vdb_task_sink_loader, - input_type=ControlMessage, - output_type=ControlMessage, - input_port_name="input", - output_port_name="output", - ) - ) - return vdb_task_sink_stage - - def setup_ingestion_pipeline( - pipe: Pipeline, morpheus_pipeline_config: Config, ingest_config: typing.Dict[str, typing.Any] + pipe: Pipeline, morpheus_pipeline_config: Config, ingest_config: typing.Dict[str, typing.Any] ): message_provider_host, message_provider_port = get_message_provider_config() default_cpu_count = get_default_cpu_count() - # Pre-processing stages + ######################################################################################################## + ## Insertion and Pre-processing stages + ######################################################################################################## source_stage = add_source_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) submitted_job_counter_stage = add_submitted_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config) metadata_injector_stage = add_metadata_injector_stage(pipe, morpheus_pipeline_config) + ######################################################################################################## - # Primitive extraction + ######################################################################################################## + ## Primitive extraction + ######################################################################################################## pdf_extractor_stage = add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) + ######################################################################################################## - # Post-processing + ######################################################################################################## + ## Post-processing + ######################################################################################################## image_dedup_stage = add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) image_filter_stage = add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + table_extraction_stage = add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + 
chart_extraction_stage = add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + ######################################################################################################## - # Transforms and data synthesis + ######################################################################################################## + ## Transforms and data synthesis + ######################################################################################################## nemo_splitter_stage = add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config) embed_extractions_stage = add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## - # Storage and output + ######################################################################################################## + ## Storage and output + ######################################################################################################## image_storage_stage = add_image_storage_stage(pipe, morpheus_pipeline_config) sink_stage = add_sink_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## - # Telemetry (Note: everything after the sync stage is out of the hot path, please keep it that way) + ####################################################################################################### + ## Telemetry (Note: everything after the sync stage is out of the hot path, please keep it that way) ## + ####################################################################################################### otel_tracer_stage = add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config) otel_meter_stage = add_otel_meter_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) completed_job_counter_stage = add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config) + ######################################################################################################## # Add edges pipe.add_edge(source_stage, submitted_job_counter_stage) @@ -579,7 +97,9 @@ def setup_ingestion_pipeline( pipe.add_edge(docx_extractor_stage, pptx_extractor_stage) pipe.add_edge(pptx_extractor_stage, image_dedup_stage) pipe.add_edge(image_dedup_stage, image_filter_stage) - pipe.add_edge(image_filter_stage, nemo_splitter_stage) + pipe.add_edge(image_filter_stage, table_extraction_stage) + pipe.add_edge(table_extraction_stage, chart_extraction_stage) + pipe.add_edge(chart_extraction_stage, nemo_splitter_stage) pipe.add_edge(nemo_splitter_stage, embed_extractions_stage) pipe.add_edge(embed_extractions_stage, image_storage_stage) pipe.add_edge(image_storage_stage, vdb_task_sink_stage) @@ -616,7 +136,8 @@ def pipeline(morpheus_pipeline_config, ingest_config) -> float: @click.command() @click.option( - "--ingest_config_path", type=str, envvar="NV_INGEST_CONFIG_PATH", help="Path to the JSON configuration file." 
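For orientation, here is a minimal driver in the spirit of the `pipeline()` helper this file keeps. It is a sketch only: it assumes a working Morpheus installation, a reachable Redis at `MESSAGE_CLIENT_HOST`/`MESSAGE_CLIENT_PORT`, and illustrative config values.

```python
# Hedged sketch: build and run the ingestion pipeline much like pipeline() does.
from morpheus.config import Config
from morpheus.pipeline.pipeline import Pipeline

morpheus_pipeline_config = Config()
morpheus_pipeline_config.num_threads = 4            # illustrative values
morpheus_pipeline_config.pipeline_batch_size = 256

pipe = Pipeline(morpheus_pipeline_config)
setup_ingestion_pipeline(pipe, morpheus_pipeline_config, ingest_config={})
pipe.run()  # blocks until the pipeline is stopped
```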
+ "--ingest_config_path", type=str, envvar="NV_INGEST_CONFIG_PATH", help="Path to the JSON configuration file.", + hidden=True ) @click.option("--use_cpp", is_flag=True, help="Use C++ backend.") @click.option("--pipeline_batch_size", default=256, type=int, help="Batch size for the pipeline.") @@ -645,16 +166,16 @@ def pipeline(morpheus_pipeline_config, ingest_config) -> float: help="Log level.", ) def cli( - ingest_config_path, - caption_batch_size, - use_cpp, - pipeline_batch_size, - enable_monitor, - feature_length, - num_threads, - model_max_batch_size, - mode, - log_level, + ingest_config_path, + caption_batch_size, + use_cpp, + pipeline_batch_size, + enable_monitor, + feature_length, + num_threads, + model_max_batch_size, + mode, + log_level, ): """ Command line interface for configuring and running the pipeline with specified options. @@ -672,11 +193,11 @@ def cli( env_log_level = os.getenv("INGEST_LOG_LEVEL") if env_log_level: log_level = env_log_level - if (log_level in ("DEFAULT")): + if (log_level in ("DEFAULT",)): log_level = "INFO" log_level = log_level_mapping.get(log_level.upper(), logging.INFO) - logging.basicConfig(level=log_level) + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") configure_logging(log_level=log_level) CppConfig.set_should_use_cpp(use_cpp) diff --git a/test-requirements.txt b/test-requirements.txt index 8c82bf4b..fe60055f 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ +autoflake==2.3.1 black==23.11.0 flake8==7.0.0 isort==5.13.2 @@ -5,6 +6,5 @@ pre-commit==3.5.0 pytest==7.4.3 pytest-cov==4.1.0 pytest-mock -yapf==0.40.2 pytest-mock==3.14.0 -autoflake==2.3.1 +yapf==0.40.2 diff --git a/tests/functional/test_ingest_pipeline.py b/tests/functional/test_ingest_pipeline.py index da486206..82c6fb2a 100644 --- a/tests/functional/test_ingest_pipeline.py +++ b/tests/functional/test_ingest_pipeline.py @@ -8,7 +8,6 @@ import pytest from nv_ingest_client.client import NvIngestClient -from nv_ingest_client.message_clients.redis.redis_client import RedisClient # type: ignore from nv_ingest_client.primitives import JobSpec from nv_ingest_client.primitives.tasks import EmbedTask from nv_ingest_client.primitives.tasks import ExtractTask @@ -47,85 +46,3 @@ def remove_keys(data, keys_to_remove): return [remove_keys(item, keys_to_remove) for item in data] else: return data - - -@pytest.mark.skip(reason="Test environment is not running nv-ingest and redis services.") -def test_ingest_pipeline(): - client = NvIngestClient( - message_client_allocator=RedisClient, - message_client_hostname=_DEFAULT_REDIS_HOST, - message_client_port=_DEFAULT_REDIS_PORT, - message_client_kwargs=None, - msg_counter_id="nv-ingest-message-id", - worker_pool_size=1, - ) - - file_content, file_type = extract_file_content(_VALIDATION_PDF) - - job_spec = JobSpec( - document_type=file_type, - payload=file_content, - source_id=_VALIDATION_PDF, - source_name=_VALIDATION_PDF, - extended_options={ - "tracing_options": { - "trace": True, - "ts_send": time.time_ns(), - } - }, - ) - - extract_task = ExtractTask( - document_type=file_type, - extract_text=True, - extract_images=True, - extract_tables=True, - text_depth=_DEFAULT_EXTRACT_PAGE_DEPTH, - extract_tables_method=_DEFAULT_EXTRACT_TABLES_METHOD, - ) - - split_task = SplitTask( - split_by=_DEFAULT_SPLIT_BY, - split_length=_DEFAULT_SPLIT_LENGTH, - split_overlap=_DEFAULT_SPLIT_OVERLAP, - max_character_length=_DEFAULT_SPLIT_MAX_CHARACTER_LENGTH, - 
sentence_window_size=_DEFAULT_SPLIT_SENTENCE_WINDOW_SIZE, - ) - - embed_task = EmbedTask( - text=True, - tables=True, - ) - - job_spec.add_task(extract_task) - job_spec.add_task(split_task) - job_spec.add_task(embed_task) - job_id = client.add_job(job_spec) - - client.submit_job(job_id, _DEFAULT_TASK_QUEUE) - generated_metadata = client.fetch_job_result(job_id, timeout=_DEFAULT_JOB_TIMEOUT)[0][0] - - with open(_VALIDATION_JSON, "r") as f: - expected_metadata = json.load(f)[0][0] - - keys_to_remove = ["date_created", "last_modified", "table_content"] - generated_metadata_cleaned = remove_keys(generated_metadata, keys_to_remove) - expected_metadata_cleaned = remove_keys(expected_metadata, keys_to_remove) - - for extraction_idx in range(len(generated_metadata_cleaned)): - content_type = generated_metadata_cleaned[extraction_idx]["metadata"]["content_metadata"]["type"] - - if content_type == "text": - assert generated_metadata_cleaned[extraction_idx] == expected_metadata_cleaned[extraction_idx] - - elif content_type == "image": - assert generated_metadata_cleaned[extraction_idx] == expected_metadata_cleaned[extraction_idx] - - elif content_type == "structured": - generated_embedding = generated_metadata_cleaned[extraction_idx]["metadata"]["embedding"] - expected_embedding = expected_metadata_cleaned[extraction_idx]["metadata"]["embedding"] - assert cosine_similarity([generated_embedding], [expected_embedding])[0] > 0.98 - - cleaned_generated_table_metadata = remove_keys(generated_metadata_cleaned, ["embedding", "table_content"]) - cleaned_expected_table_metadata = remove_keys(expected_metadata_cleaned, ["embedding", "table_content"]) - assert cleaned_generated_table_metadata == cleaned_expected_table_metadata diff --git a/tests/nv_ingest/modules/filters/test_image_dedup.py b/tests/nv_ingest/modules/filters/test_image_dedup.py index 125f4ff0..52a245cc 100644 --- a/tests/nv_ingest/modules/filters/test_image_dedup.py +++ b/tests/nv_ingest/modules/filters/test_image_dedup.py @@ -90,7 +90,7 @@ def test_apply_dedup(should_filter, expected0, expected1, expected2): payload_list = [] for _ in range(3): - payload_list.append(valid_image_dedup_payload("test", 1, 1)) + payload_list.append(valid_image_dedup_payload(f"test", 1, 1)) extracted_df = pd.DataFrame(payload_list, columns=["document_type", "metadata"]) extracted_gdf = cudf.from_pandas(extracted_df) diff --git a/tests/nv_ingest/modules/sources/test_redis_task_source.py b/tests/nv_ingest/modules/sources/test_redis_task_source.py index 2ccb84dd..5c67bd86 100644 --- a/tests/nv_ingest/modules/sources/test_redis_task_source.py +++ b/tests/nv_ingest/modules/sources/test_redis_task_source.py @@ -80,18 +80,19 @@ def test_process_message(job_payload, add_trace_tagging, trace_id, ts_send, ts_f payload = json.loads(job_payload) # Update tracing options based on parameters + job_id = "abc12345678910213123" + payload["job_id"] = job_id payload["tracing_options"] = {"trace": add_trace_tagging, "ts_send": int(ts_send.timestamp() * 1e9)} if trace_id is not None: payload["tracing_options"]["trace_id"] = trace_id - modified_payload = json.dumps(payload) - result = process_message(modified_payload, ts_fetched) + + result = process_message(payload, ts_fetched) # Basic type check for the returned object assert isinstance(result, ControlMessage) # Check for correct handling of tracing options - assert result.get_metadata("response_channel") == f"response_{payload['job_id']}" - assert result.get_metadata("job_id") == payload["job_id"] + assert 
result.get_metadata("response_channel") == f"response_{job_id}" if add_trace_tagging: assert result.get_metadata("config::add_trace_tagging") is True assert result.get_timestamp(f"trace::entry::{MODULE_NAME}") is not None diff --git a/tests/nv_ingest/schemas/test_chart_extractor_schema.py b/tests/nv_ingest/schemas/test_chart_extractor_schema.py new file mode 100644 index 00000000..5deb16b5 --- /dev/null +++ b/tests/nv_ingest/schemas/test_chart_extractor_schema.py @@ -0,0 +1,112 @@ +import pytest +from pydantic import ValidationError +from nv_ingest.schemas.chart_extractor_schema import ChartExtractorConfigSchema, ChartExtractorSchema # Adjust the import as per your file structure + +# Test cases for ChartExtractorConfigSchema +def test_valid_config_with_grpc_only(): + config = ChartExtractorConfigSchema( + auth_token="valid_token", + cached_endpoints=("grpc://cached_service", None), + deplot_endpoints=("grpc://deplot_service", None), + paddle_endpoints=("grpc://paddle_service", None) + ) + assert config.auth_token == "valid_token" + assert config.cached_endpoints == ("grpc://cached_service", None) + assert config.deplot_endpoints == ("grpc://deplot_service", None) + assert config.paddle_endpoints == ("grpc://paddle_service", None) + +def test_valid_config_with_http_only(): + config = ChartExtractorConfigSchema( + auth_token="valid_token", + cached_endpoints=(None, "http://cached_service"), + deplot_endpoints=(None, "http://deplot_service"), + paddle_endpoints=(None, "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.cached_endpoints == (None, "http://cached_service") + assert config.deplot_endpoints == (None, "http://deplot_service") + assert config.paddle_endpoints == (None, "http://paddle_service") + +def test_invalid_config_with_empty_services(): + with pytest.raises(ValidationError) as excinfo: + ChartExtractorConfigSchema( + cached_endpoints=(None, None), + deplot_endpoints=(None, None), + paddle_endpoints=(None, None) + ) + assert "Both gRPC and HTTP services cannot be empty" in str(excinfo.value) + +def test_valid_config_with_both_grpc_and_http(): + config = ChartExtractorConfigSchema( + auth_token="another_token", + cached_endpoints=("grpc://cached_service", "http://cached_service"), + deplot_endpoints=("grpc://deplot_service", "http://deplot_service"), + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token == "another_token" + assert config.cached_endpoints == ("grpc://cached_service", "http://cached_service") + assert config.deplot_endpoints == ("grpc://deplot_service", "http://deplot_service") + assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") + +def test_invalid_auth_token_none(): + config = ChartExtractorConfigSchema( + cached_endpoints=("grpc://cached_service", None), + deplot_endpoints=("grpc://deplot_service", None), + paddle_endpoints=("grpc://paddle_service", None) + ) + assert config.auth_token is None + +def test_invalid_endpoint_format(): + with pytest.raises(ValidationError): + ChartExtractorConfigSchema( + cached_endpoints=("invalid_endpoint", None), + deplot_endpoints=(None, "invalid_endpoint") + ) + +# Test cases for ChartExtractorSchema +def test_chart_extractor_schema_defaults(): + config = ChartExtractorSchema() + assert config.max_queue_size == 1 + assert config.n_workers == 2 + assert config.raise_on_failure is False + assert config.stage_config is None + +def test_chart_extractor_schema_with_custom_values(): + stage_config = 
ChartExtractorConfigSchema( + cached_endpoints=("grpc://cached_service", "http://cached_service"), + deplot_endpoints=("grpc://deplot_service", None), + paddle_endpoints=(None, "http://paddle_service") + ) + config = ChartExtractorSchema( + max_queue_size=10, + n_workers=5, + raise_on_failure=True, + stage_config=stage_config + ) + assert config.max_queue_size == 10 + assert config.n_workers == 5 + assert config.raise_on_failure is True + assert config.stage_config == stage_config + +def test_chart_extractor_schema_without_stage_config(): + config = ChartExtractorSchema( + max_queue_size=3, + n_workers=1, + raise_on_failure=False + ) + assert config.max_queue_size == 3 + assert config.n_workers == 1 + assert config.raise_on_failure is False + assert config.stage_config is None + +def test_invalid_chart_extractor_schema_negative_queue_size(): + with pytest.raises(ValidationError): + ChartExtractorSchema( + max_queue_size=-1 + ) + +def test_invalid_chart_extractor_schema_zero_workers(): + with pytest.raises(ValidationError): + ChartExtractorSchema( + n_workers=0 + ) \ No newline at end of file diff --git a/tests/nv_ingest/schemas/test_ingest_job_schema.py b/tests/nv_ingest/schemas/test_ingest_job_schema.py index bb570393..4f42867d 100644 --- a/tests/nv_ingest/schemas/test_ingest_job_schema.py +++ b/tests/nv_ingest/schemas/test_ingest_job_schema.py @@ -202,6 +202,18 @@ def test_multiple_task_types(): "params": {"filter": True}, }, }, + { + "type": "table_data_extract", + "task_properties": { + "params": {}, + } + }, + { + "type": "chart_data_extract", + "task_properties": { + "params": {}, + } + } ], } diff --git a/tests/nv_ingest/schemas/test_metadata_schema.py b/tests/nv_ingest/schemas/test_metadata_schema.py new file mode 100644 index 00000000..2f1e76f5 --- /dev/null +++ b/tests/nv_ingest/schemas/test_metadata_schema.py @@ -0,0 +1,183 @@ +import pytest +from pydantic import ValidationError +from datetime import datetime +from nv_ingest.schemas.metadata_schema import ( + SourceMetadataSchema, + NearbyObjectsSchema, + ContentHierarchySchema, + ContentMetadataSchema, + TextMetadataSchema, + ImageMetadataSchema, + TableMetadataSchema, + ChartMetadataSchema, + ErrorMetadataSchema, + InfoMessageMetadataSchema, + TableFormatEnum, +) + + +# Test cases for SourceMetadataSchema +def test_source_metadata_schema_defaults(): + config = SourceMetadataSchema( + source_name="Test Source", + source_id="1234", + source_type="TestType" + ) + assert config.source_location == "" + assert config.collection_id == "" + assert config.partition_id == -1 + assert config.access_level == -1 + + +def test_source_metadata_schema_invalid_date(): + with pytest.raises(ValidationError): + SourceMetadataSchema( + source_name="Test Source", + source_id="1234", + source_type="TestType", + date_created="invalid_date" + ) + + +# Test cases for NearbyObjectsSchema +def test_nearby_objects_schema_defaults(): + config = NearbyObjectsSchema() + assert config.text.content == [] + assert config.images.content == [] + assert config.structured.content == [] + + +# Test cases for ContentHierarchySchema +def test_content_hierarchy_schema_defaults(): + config = ContentHierarchySchema() + assert config.page_count == -1 + assert config.page == -1 + assert config.block == -1 + assert config.line == -1 + assert config.span == -1 + + +def test_content_hierarchy_schema_with_nearby_objects(): + config = ContentHierarchySchema( + nearby_objects=NearbyObjectsSchema( + text={"content": ["sample text"]}, + 
images={"content": ["sample image"]} + ) + ) + assert config.nearby_objects.text.content == ["sample text"] + assert config.nearby_objects.images.content == ["sample image"] + + +# Test cases for ContentMetadataSchema +def test_content_metadata_schema_defaults(): + config = ContentMetadataSchema(type="text") + print(config) + assert config.description == "" + assert config.page_number == -1 + + +def test_content_metadata_schema_invalid_type(): + with pytest.raises(ValidationError): + ContentMetadataSchema(type="InvalidType") + + +# Test cases for TextMetadataSchema +def test_text_metadata_schema_defaults(): + config = TextMetadataSchema(text_type="document") + assert config.summary == "" + assert config.keywords == "" + assert config.language == "en" + assert config.text_location == (0, 0, 0, 0) + + +def test_text_metadata_schema_with_keywords(): + config = TextMetadataSchema(text_type="body", keywords=["keyword1", "keyword2"]) + assert config.keywords == ["keyword1", "keyword2"] + + +# Test cases for ImageMetadataSchema +def test_image_metadata_schema_defaults(): + config = ImageMetadataSchema(image_type="image") + assert config.caption == "" + assert config.width == 0 + assert config.height == 0 + + +def test_image_metadata_schema_invalid_type(): + with pytest.raises(ValidationError): + ImageMetadataSchema(image_type=3.14) # Using a float value + +def test_image_metadata_schema_invalid_type(): + with pytest.raises(ValidationError): + ImageMetadataSchema(image_type=3.14) + + +# Test cases for TableMetadataSchema +@pytest.mark.parametrize("table_format", ["html", "markdown", "latex", "image"]) +def test_table_metadata_schema_defaults(table_format): + config = TableMetadataSchema(table_format=table_format) + assert config.caption == "" + assert config.table_content == "" + + +def test_table_metadata_schema_with_location(): + config = TableMetadataSchema( + table_format="latex", + table_location=(1, 2, 3, 4) + ) + assert config.table_location == (1, 2, 3, 4) + + +@pytest.mark.parametrize("schema_class", [TableMetadataSchema, ChartMetadataSchema]) +@pytest.mark.parametrize("table_format", + [TableFormatEnum.HTML, TableFormatEnum.MARKDOWN, TableFormatEnum.LATEX, TableFormatEnum.IMAGE]) +def test_schema_valid_table_format(schema_class, table_format): + config = schema_class(table_format=table_format) + assert config.caption == "" + assert config.table_content == "" + + +def test_table_metadata_schema_invalid_table_format(): + with pytest.raises(ValidationError): + TableMetadataSchema(table_format="invalid_format") + + +# Test cases for ChartMetadataSchema +def test_chart_metadata_schema_defaults(): + config = ChartMetadataSchema(table_format="html") + assert config.caption == "" + assert config.table_content == "" + + +# Test cases for ErrorMetadataSchema +def test_error_metadata_schema_defaults(): + config = ErrorMetadataSchema( + task="embed", + status="error", + error_msg="An error occurred." + ) + assert config.source_id == "" + + +def test_error_metadata_schema_invalid_status(): + with pytest.raises(ValidationError): + ErrorMetadataSchema( + task="TaskType1", + status="InvalidStatus", + error_msg="An error occurred." 
+ ) + + +# Test cases for InfoMessageMetadataSchema +def test_info_message_metadata_schema_defaults(): + config = InfoMessageMetadataSchema( + task="transform", + status="success", + message="This is an info message.", + filter=False + ) + assert config.filter is False + + +def test_info_message_metadata_schema_invalid_task(): + with pytest.raises(ValidationError): + InfoMessageMetadataSchema( + task="InvalidTaskType", + status="InfoStatus", + message="This is an info message." + ) diff --git a/tests/nv_ingest/schemas/test_table_extractor_schema.py b/tests/nv_ingest/schemas/test_table_extractor_schema.py new file mode 100644 index 00000000..72d7c057 --- /dev/null +++ b/tests/nv_ingest/schemas/test_table_extractor_schema.py @@ -0,0 +1,129 @@ +import pytest +from pydantic import ValidationError +from nv_ingest.schemas.table_extractor_schema import TableExtractorConfigSchema, \ + TableExtractorSchema + + +# Test cases for TableExtractorConfigSchema +def test_valid_config_with_grpc_only(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", None) + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == ("grpc://paddle_service", None) + + +def test_valid_config_with_http_only(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=(None, "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == (None, "http://paddle_service") + + +def test_valid_config_with_both_services(): + config = TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token == "valid_token" + assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") + + +def test_invalid_config_empty_endpoints(): + with pytest.raises(ValidationError) as exc_info: + TableExtractorConfigSchema( + paddle_endpoints=(None, None) + ) + assert "Both gRPC and HTTP services cannot be empty for paddle_endpoints" in str(exc_info.value) + + +def test_invalid_extra_fields(): + with pytest.raises(ValidationError) as exc_info: + TableExtractorConfigSchema( + auth_token="valid_token", + paddle_endpoints=("grpc://paddle_service", None), + extra_field="invalid" + ) + assert "extra fields not permitted" in str(exc_info.value) + + +def test_cleaning_empty_strings_in_endpoints(): + config = TableExtractorConfigSchema( + paddle_endpoints=(" ", "http://paddle_service") + ) + assert config.paddle_endpoints == (None, "http://paddle_service") + + config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "") + ) + assert config.paddle_endpoints == ("grpc://paddle_service", None) + + +def test_auth_token_is_none_by_default(): + config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + assert config.auth_token is None + + +# Test cases for TableExtractorSchema +def test_table_extractor_schema_defaults(): + config = TableExtractorSchema() + assert config.max_queue_size == 1 + assert config.n_workers == 2 + assert config.raise_on_failure is False + assert config.stage_config is None + + +def test_table_extractor_schema_with_custom_values(): + stage_config = TableExtractorConfigSchema( + paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + ) + config = TableExtractorSchema( + max_queue_size=15, + n_workers=12, + raise_on_failure=True, + stage_config=stage_config + ) + assert 
config.max_queue_size == 15 + assert config.n_workers == 12 + assert config.raise_on_failure is True + assert config.stage_config == stage_config + + +def test_table_extractor_schema_without_stage_config(): + config = TableExtractorSchema( + max_queue_size=20, + n_workers=5, + raise_on_failure=True + ) + assert config.max_queue_size == 20 + assert config.n_workers == 5 + assert config.raise_on_failure is True + assert config.stage_config is None + + +def test_invalid_table_extractor_schema_negative_queue_size(): + with pytest.raises(ValidationError): + TableExtractorSchema( + max_queue_size=-5 + ) + + +def test_invalid_table_extractor_schema_zero_workers(): + with pytest.raises(ValidationError): + TableExtractorSchema( + n_workers=0 + ) + + +def test_invalid_extra_fields_in_table_extractor_schema(): + with pytest.raises(ValidationError): + TableExtractorSchema( + max_queue_size=10, + n_workers=5, + extra_field="invalid" + ) diff --git a/tests/nv_ingest/util/image_processing/test_transforms.py b/tests/nv_ingest/util/image_processing/test_transforms.py index b9b7e276..ad86c191 100644 --- a/tests/nv_ingest/util/image_processing/test_transforms.py +++ b/tests/nv_ingest/util/image_processing/test_transforms.py @@ -1,6 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest import numpy as np +from PIL import Image +import base64 +from io import BytesIO +from unittest import mock + +from nv_ingest.util.image_processing.transforms import numpy_to_base64, base64_to_numpy, check_numpy_image_size + + +# Helper function to create a base64-encoded string from an image +def create_base64_image(width, height, color="white"): + img = Image.new('RGB', (width, height), color=color) + buffered = BytesIO() + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +# Fixture for a valid base64-encoded image string +@pytest.fixture +def valid_base64_image(): + return create_base64_image(64, 64) + -from nv_ingest.util.image_processing.transforms import numpy_to_base64 +# Fixture for a corrupted base64 string +@pytest.fixture +def corrupted_base64_image(): + return "not_a_valid_base64_string" + + +# Fixture for a base64 string that decodes but is not a valid image +@pytest.fixture +def non_image_base64(): + return base64.b64encode(b"This is not an image").decode('utf-8') def test_numpy_to_base64_valid_rgba_image(): @@ -25,3 +60,50 @@ def test_numpy_to_base64_grayscale_redundant_axis(): assert isinstance(result, str) assert len(result) > 0 + + +# Tests for base64_to_numpy +def test_base64_to_numpy_valid(valid_base64_image): + img_array = base64_to_numpy(valid_base64_image) + assert isinstance(img_array, np.ndarray) + assert img_array.shape[0] == 64 # Height + assert img_array.shape[1] == 64 # Width + + +def test_base64_to_numpy_invalid_string(corrupted_base64_image): + with pytest.raises(ValueError, match="Invalid base64 string"): + base64_to_numpy(corrupted_base64_image) + + +def test_base64_to_numpy_non_image(non_image_base64): + with pytest.raises(ValueError, match="Unable to decode image from base64 string"): + base64_to_numpy(non_image_base64) + + +def test_base64_to_numpy_import_error(monkeypatch, valid_base64_image): + # Simulate ImportError for PIL by patching import_module + with mock.patch("PIL.Image.open", side_effect=ImportError("PIL library not available")): + with pytest.raises(ImportError): + base64_to_numpy(valid_base64_image) + + +# Tests for 
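The endpoint handling these schema tests pin down (empty or whitespace-only endpoints coerced to None, and at least one of gRPC/HTTP required) follows a common pydantic-v1 pattern. Below is a hedged reconstruction of that pattern, not the actual source; the real validators live in `nv_ingest.schemas` and may differ in detail.

```python
# Reconstruction only (pydantic v1 assumed); not the actual nv_ingest schema.
from typing import Optional, Tuple

from pydantic import BaseModel, validator


class ExampleExtractorConfig(BaseModel):
    auth_token: Optional[str] = None
    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)

    class Config:
        extra = "forbid"  # yields "extra fields not permitted" on unknown keys

    @validator("paddle_endpoints")
    def clean_and_require_one(cls, value):
        # Coerce empty/whitespace-only strings to None, then require a service.
        grpc, http = ((v.strip() or None) if isinstance(v, str) else v for v in value)
        if not grpc and not http:
            raise ValueError("Both gRPC and HTTP services cannot be empty for paddle_endpoints.")
        return (grpc, http)
```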
check_numpy_image_size +def test_check_numpy_image_size_valid(): + img = np.zeros((100, 100, 3), dtype=np.uint8) + assert check_numpy_image_size(img, 50, 50) is True + + +def test_check_numpy_image_size_too_small_height(): + img = np.zeros((40, 100, 3), dtype=np.uint8) # Height less than min + assert check_numpy_image_size(img, 50, 50) is False + + +def test_check_numpy_image_size_too_small_width(): + img = np.zeros((100, 40, 3), dtype=np.uint8) # Width less than min + assert check_numpy_image_size(img, 50, 50) is False + + +def test_check_numpy_image_size_invalid_dimensions(): + img = np.zeros((100,), dtype=np.uint8) # 1D array + with pytest.raises(ValueError, match="The input array does not have sufficient dimensions for an image."): + check_numpy_image_size(img, 50, 50) diff --git a/tests/nv_ingest/util/nim/test_decorators.py b/tests/nv_ingest/util/nim/test_decorators.py new file mode 100644 index 00000000..58a697d8 --- /dev/null +++ b/tests/nv_ingest/util/nim/test_decorators.py @@ -0,0 +1,78 @@ +from multiprocessing import Manager +from multiprocessing import Process +from multiprocessing import Queue + +import pytest + +from nv_ingest.util.nim.decorators import multiprocessing_cache + + +@pytest.fixture +def shared_manager(): + """Fixture to create a shared multiprocessing manager.""" + return Manager() + + +def test_global_cache_with_same_arguments(shared_manager): + queue = Queue() + + @multiprocessing_cache(3) + def add(x, y): + queue.put(1) # Track each function call + return x + y + + def worker(val1, val2): + add(val1, val2) + + processes = [ + Process(target=worker, args=(1, 2)), # called 1st time + Process(target=worker, args=(1, 2)), + Process(target=worker, args=(1, 2)), + Process(target=worker, args=(1, 2)), # called 2nd time + ] + + for p in processes: + p.start() + + for p in processes: + p.join() + + total_calls = 0 + while not queue.empty(): + total_calls += queue.get() + + assert total_calls == 2 + + +def test_global_cache_with_different_arguments(shared_manager): + queue = Queue() + + @multiprocessing_cache(3) + def add(x, y): + queue.put(1) # Track each function call + return x + y + + def worker(val1, val2): + add(val1, val2) + + processes = [ + Process(target=worker, args=(1, 2)), # called 1st time + Process(target=worker, args=(3, 4)), # called 2nd time + Process(target=worker, args=(1, 2)), + Process(target=worker, args=(3, 4)), + Process(target=worker, args=(1, 2)), + Process(target=worker, args=(3, 4)), + Process(target=worker, args=(3, 4)), # called 3rd time + ] + + for p in processes: + p.start() + + for p in processes: + p.join() + + total_calls = 0 + while not queue.empty(): + total_calls += queue.get() + + assert total_calls == 3 diff --git a/tests/nv_ingest_client/cli/util/test_click.py b/tests/nv_ingest_client/cli/util/test_click.py index 2c9938fd..a3d9d531 100644 --- a/tests/nv_ingest_client/cli/util/test_click.py +++ b/tests/nv_ingest_client/cli/util/test_click.py @@ -8,7 +8,6 @@ import click import pytest -from nv_ingest_client.cli.util.click import _generate_matching_files from nv_ingest_client.cli.util.click import click_match_and_validate_files from nv_ingest_client.cli.util.click import click_validate_batch_size from nv_ingest_client.cli.util.click import click_validate_file_exists @@ -230,29 +229,13 @@ def test_empty_file_list(tmp_path): assert files == [], "Expected an empty list of files" -@pytest.mark.parametrize( - "patterns, mock_files, expected", - [ - (["*.txt"], ["test1.txt", "test2.txt"], ["test1.txt", "test2.txt"]), - (["*.txt"], [], 
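The decorator exercised in the tests above reads as a process-safe memoizer: identical arguments are served from a cache shared across worker processes. A usage sketch follows; the meaning of the integer argument (how long entries survive before the wrapped function re-executes) is inferred from those tests rather than from the decorator's source, so treat it as illustrative.

```python
# Usage sketch for the shared, cross-process cache tested above.
from multiprocessing import Process

from nv_ingest.util.nim.decorators import multiprocessing_cache


@multiprocessing_cache(3)  # bound on cache lifetime, inferred from the tests
def resolve_endpoint(name):
    print(f"performing expensive lookup for {name}")  # runs only on cache misses
    return f"http://{name}:8000"


if __name__ == "__main__":
    # Per the tests, three identical calls trigger only one real execution.
    workers = [Process(target=resolve_endpoint, args=("paddle",)) for _ in range(3)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```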
[]), - (["*.md"], ["README.md"], ["README.md"]), - (["docs/*.md"], ["docs/README.md", "docs/CHANGES.md"], ["docs/README.md", "docs/CHANGES.md"]), - ], -) -def test_generate_matching_files(patterns, mock_files, expected): - with patch( - "glob.glob", side_effect=lambda pattern, recursive: [f for f in mock_files if f.startswith(pattern[:-5])] - ), patch("os.path.isfile", return_value=True): - assert list(_generate_matching_files(patterns)) == expected - - def test_click_match_and_validate_files_found(): - with patch(f"{_MODULE_UNDER_TEST}._generate_matching_files", return_value=iter(["file1.txt", "file2.txt"])): + with patch(f"{_MODULE_UNDER_TEST}.generate_matching_files", return_value=iter(["file1.txt", "file2.txt"])): result = click_match_and_validate_files(None, None, ["*.txt"]) assert result == ["file1.txt", "file2.txt"] def test_click_match_and_validate_files_not_found(): - with patch(f"{_MODULE_UNDER_TEST}._generate_matching_files", return_value=iter([])): + with patch(f"{_MODULE_UNDER_TEST}.generate_matching_files", return_value=iter([])): result = click_match_and_validate_files(None, None, ["*.txt"]) assert result == [] diff --git a/tests/nv_ingest_client/client/test_client.py b/tests/nv_ingest_client/client/test_client.py index 1432f1e6..a4d447e2 100644 --- a/tests/nv_ingest_client/client/test_client.py +++ b/tests/nv_ingest_client/client/test_client.py @@ -9,15 +9,21 @@ from concurrent.futures import Future from concurrent.futures import as_completed from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import mock_open +from unittest.mock import patch import pytest from nv_ingest_client.client import NvIngestClient from nv_ingest_client.primitives.jobs import JobSpec from nv_ingest_client.primitives.jobs import JobState from nv_ingest_client.primitives.jobs import JobStateEnum +from nv_ingest_client.primitives.tasks import ExtractTask from nv_ingest_client.primitives.tasks import SplitTask from nv_ingest_client.primitives.tasks import TaskType +MODULE_UNDER_TEST = "nv_ingest_client.client.client" + class MockClient: def __init__(self, host, port): @@ -525,3 +531,79 @@ def test_futures_reflect_submission_outcome(nv_ingest_client_with_jobs, job_id): # for future in as_completed(futures.keys()): # result = future.result()[0] # assert result[0] == {"result": "success"}, f"The fetched job result for {job_id} should be successful" + + +@pytest.fixture +def mock_create_job_specs_for_batch(): + with patch(f"{MODULE_UNDER_TEST}.create_job_specs_for_batch") as mock_create: + yield mock_create + + +@pytest.fixture +def tasks(): + """Fixture for common tasks.""" + return { + "split": SplitTask(), + "extract_pdf": ExtractTask(document_type="pdf"), + } + + +def test_create_jobs_for_batch_success(nv_ingest_client, tasks, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec, mock_job_spec] + + files = ["file1.pdf", "file2.pdf"] + + job_ids = nv_ingest_client.create_jobs_for_batch(files, tasks) + + assert job_ids == ["0", "1"] + + +def test_create_jobs_for_batch_invalid_task(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pdf"] + invalid_tasks = { + "invalid_task": None, + } + + with pytest.raises(ValueError, match="Invalid task type: 'invalid_task'"): + nv_ingest_client.create_jobs_for_batch(files, invalid_tasks) + + +def 
test_create_jobs_for_batch_duplicate_task(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pdf"] + duplicate_tasks = { + "split": SplitTask(split_by="sentence"), + "store": SplitTask(split_by="sentence"), # Duplicate task + } + + with pytest.raises(ValueError, match="Duplicate task detected"): + nv_ingest_client.create_jobs_for_batch(files, duplicate_tasks) + + +def test_create_jobs_for_batch_extract_mismatch(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_job_spec.document_type = "pptx" + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pptx"] + tasks = { + "split": SplitTask(), + "extract_pdf": ExtractTask(document_type="pdf"), # Mismatch with pptx file + "extract_pptx": ExtractTask(document_type="pptx"), + } + + job_ids = nv_ingest_client.create_jobs_for_batch(files, tasks) + + assert job_ids == ["0"] + + # Check that extract_pdf was NOT called by inspecting the mock call list + calls = [call[0][0] for call in mock_job_spec.add_task.call_args_list] + assert tasks["split"] in calls + assert tasks["extract_pptx"] in calls + assert tasks["extract_pdf"] not in calls, "extract_pdf should not have been added" diff --git a/tests/nv_ingest_client/primitives/jobs/test_job_spec.py b/tests/nv_ingest_client/primitives/jobs/test_job_spec.py index bc9fd5ba..668fde3c 100644 --- a/tests/nv_ingest_client/primitives/jobs/test_job_spec.py +++ b/tests/nv_ingest_client/primitives/jobs/test_job_spec.py @@ -2,24 +2,33 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import json +import logging import uuid from typing import Dict +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch import pytest +from nv_ingest_client.primitives.jobs.job_spec import BatchJobSpec from nv_ingest_client.primitives.jobs.job_spec import JobSpec from nv_ingest_client.primitives.tasks import Task +MODULE_UNDER_TEST = "nv_ingest_client.primitives.jobs.job_spec" + # Assuming the Task class has a to_dict method class MockTask(Task): def to_dict(self) -> Dict: - return {"task": "mocktask"} + return {"document_type": "pdf", "task": "mocktask"} # Fixture to create a JobSpec instance @pytest.fixture def job_spec_fixture() -> JobSpec: return JobSpec( + document_type="pdf", payload={"key": "value"}, tasks=[MockTask()], source_id="source123", @@ -28,6 +37,22 @@ def job_spec_fixture() -> JobSpec: ) +def create_json_file(tmp_path, content): + file_path = tmp_path / "dataset.json" + with open(file_path, "w") as f: + json.dump(content, f) + return str(file_path) + + +@pytest.fixture +def dataset(tmp_path): + content = {"sampled_files": ["file1.txt", "file2.txt", "file3.txt"]} + file_path = tmp_path / "dataset.json" + with open(file_path, "w") as f: + json.dump(content, f) + return str(file_path) + + # Test initialization def test_job_spec_initialization(): job_spec = JobSpec( @@ -92,3 +117,120 @@ def test_set_properties(): job_spec.source_name = "source456.pdf" assert job_spec.source_name == "source456.pdf" + + +@pytest.fixture +def batch_job_spec_fixture(job_spec_fixture) -> BatchJobSpec: + batch_job_spec = BatchJobSpec() + batch_job_spec.add_job_spec(job_spec_fixture) + return batch_job_spec + + +def test_init_with_job_specs(job_spec_fixture): + batch_job_spec = BatchJobSpec([job_spec_fixture]) + + assert "pdf" in batch_job_spec._file_type_to_job_spec + assert job_spec_fixture 
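A usage sketch for the batch API these client tests cover. It assumes a running nv-ingest service reachable over Redis on localhost and uses only names that appear in this diff (`NvIngestClient`, `create_jobs_for_batch`, the task primitives); defaults and return shape are taken from the tests above.

```python
# Hedged example: batch job creation mirroring the tests above.
from nv_ingest_client.client import NvIngestClient
from nv_ingest_client.primitives.tasks import ExtractTask, SplitTask

client = NvIngestClient(message_client_hostname="localhost", message_client_port=6379)
tasks = {
    "split": SplitTask(),
    # Per the mismatch test above, extract tasks attach only to matching file types.
    "extract_pdf": ExtractTask(document_type="pdf"),
}
job_ids = client.create_jobs_for_batch(["report.pdf", "manual.pdf"], tasks)
```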
in batch_job_spec._file_type_to_job_spec["pdf"]
+
+
+def test_init_with_files(mocker, job_spec_fixture):
+    mocker.patch("nv_ingest_client.util.util.generate_matching_files", return_value=["file1.pdf"])
+    mocker.patch("nv_ingest_client.util.util.create_job_specs_for_batch", return_value=[job_spec_fixture])
+
+    batch_job_spec = BatchJobSpec(["file1.pdf"])
+
+    # Verify that the files were processed and job specs were created
+    assert "pdf" in batch_job_spec._file_type_to_job_spec
+    assert len(batch_job_spec._file_type_to_job_spec["pdf"]) > 0
+
+
+def test_add_task_to_specific_document_type(batch_job_spec_fixture):
+    task = MockTask()
+
+    # Add task to jobs with document_type 'pdf'
+    batch_job_spec_fixture.add_task(task, document_type="pdf")
+
+    # Assert that the task was added to the JobSpec with document_type 'pdf'
+    for job_spec in batch_job_spec_fixture._file_type_to_job_spec["pdf"]:
+        assert task in job_spec._tasks
+
+
+def test_add_task_to_inferred_document_type(batch_job_spec_fixture):
+    task = MockTask()
+
+    # Add a task without specifying document_type; it should be inferred from the task's to_dict
+    batch_job_spec_fixture.add_task(task)
+
+    # Assert that the task was added to the JobSpec with the inferred document_type 'pdf'
+    for job_spec in batch_job_spec_fixture._file_type_to_job_spec["pdf"]:
+        assert task in job_spec._tasks
+
+
+def test_add_task_to_all_job_specs(batch_job_spec_fixture):
+    # Mock a task without a document_type
+    task = MockTask()
+    task.to_dict = Mock(return_value={"task": "mocktask"})  # No document_type returned
+
+    # Without a document_type, the task should be added to every job spec
+    batch_job_spec_fixture.add_task(task)
+
+    # Assert that the task was added to all job specs in the batch
+    for job_specs in batch_job_spec_fixture._file_type_to_job_spec.values():
+        for job_spec in job_specs:
+            assert task in job_spec._tasks
+
+
+def test_add_task_raises_value_error_for_invalid_task(batch_job_spec_fixture):
+    # Create an invalid task that doesn't derive from Task
+    invalid_task = object()
+
+    # Expect a ValueError when adding an invalid task
+    with pytest.raises(ValueError, match="Task must derive from nv_ingest_client.primitives.Task class"):
+        batch_job_spec_fixture.add_task(invalid_task)
+
+
+def test_batch_job_spec_to_dict(batch_job_spec_fixture):
+    result = batch_job_spec_fixture.to_dict()
+
+    assert isinstance(result, dict)
+    assert "pdf" in result
+    assert len(result["pdf"]) > 0
+
+
+def test_batch_job_spec_str_method(batch_job_spec_fixture):
+    result = str(batch_job_spec_fixture)
+
+    assert "pdf" in result
+    assert "source123" in result
+
+
+@patch(f"{MODULE_UNDER_TEST}.get_dataset_files")
+@patch(f"{MODULE_UNDER_TEST}.get_dataset_statistics")
+@patch(f"{MODULE_UNDER_TEST}.logger")
+def test__from_dataset(mock_logger, mock_get_dataset_statistics, mock_get_dataset_files, dataset):
+    mock_get_dataset_files.return_value = ["file1.txt", "file2.txt", "file3.txt"]
+    mock_get_dataset_statistics.return_value = "Statistics info"
+
+    batch_job_spec = BatchJobSpec()
+
+    batch_job_spec.from_files = MagicMock()
+
+    batch_job_spec._from_dataset(dataset)
+
+    mock_get_dataset_files.assert_called_once()
+
+    mock_get_dataset_statistics.assert_called_once()
+
+    batch_job_spec.from_files.assert_called_once_with(["file1.txt", "file2.txt", "file3.txt"])
+
+    if mock_logger.isEnabledFor(logging.DEBUG):  # always truthy on a mocked logger
+        mock_logger.debug.assert_called_once_with("Statistics info")
+
+
+@patch(f"{MODULE_UNDER_TEST}.BatchJobSpec._from_dataset")
+def test_from_dataset(mock__from_dataset, dataset):
+    batch_job_spec = BatchJobSpec.from_dataset(dataset, shuffle_dataset=False)
+
+    assert isinstance(batch_job_spec, BatchJobSpec)
+
+    mock__from_dataset.assert_called_once_with(dataset, shuffle_dataset=False)
diff --git a/tests/nv_ingest_client/cli/util/test_dataset.py b/tests/nv_ingest_client/util/test_dataset.py
similarity index 94%
rename from tests/nv_ingest_client/cli/util/test_dataset.py
rename to tests/nv_ingest_client/util/test_dataset.py
index 65b6e3ec..6edee02a 100644
--- a/tests/nv_ingest_client/cli/util/test_dataset.py
+++ b/tests/nv_ingest_client/util/test_dataset.py
@@ -6,8 +6,8 @@
 from io import BytesIO
 
 import pytest
-from nv_ingest_client.cli.util.dataset import get_dataset_files
-from nv_ingest_client.cli.util.dataset import get_dataset_statistics
+from nv_ingest_client.util.dataset import get_dataset_files
+from nv_ingest_client.util.dataset import get_dataset_statistics
 
 
 @pytest.fixture
diff --git a/tests/nv_ingest_client/util/test_util.py b/tests/nv_ingest_client/util/test_util.py
index 1e1dab22..5f421fc4 100644
--- a/tests/nv_ingest_client/util/test_util.py
+++ b/tests/nv_ingest_client/util/test_util.py
@@ -2,6 +2,10 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from unittest.mock import patch
+
+import pytest
+from nv_ingest_client.cli.util.click import generate_matching_files
 
 _MODULE_UNDER_TEST = "nv_ingest_client.util.util"
 
@@ -66,3 +70,19 @@
 # with patch(f"{_MODULE_UNDER_TEST}.os.path.splitext", return_value=("", ".pdf")):
 #     with patch(f"{_MODULE_UNDER_TEST}.fitz.open", side_effect=Exception("Some error")):
 #         assert estimate_page_count(file_path) == 0
+
+
+@pytest.mark.parametrize(
+    "patterns, mock_files, expected",
+    [
+        (["*.txt"], ["test1.txt", "test2.txt"], ["test1.txt", "test2.txt"]),
+        (["*.txt"], [], []),
+        (["*.md"], ["README.md"], ["README.md"]),
+        (["docs/*.md"], ["docs/README.md", "docs/CHANGES.md"], ["docs/README.md", "docs/CHANGES.md"]),
+    ],
+)
+def test_generate_matching_files(patterns, mock_files, expected):
+    # Stub out the filesystem: the fake glob drops the 5-char glob tail (e.g. "*.txt") and prefix-matches.
+    fake_glob = lambda pattern, recursive: [f for f in mock_files if f.startswith(pattern[:-5])]
+    with patch("glob.glob", side_effect=fake_glob), patch("os.path.isfile", return_value=True):
+        assert list(generate_matching_files(patterns)) == expected
diff --git a/tests/stages/__init__.py b/tests/stages/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/stages/nims/__init__.py b/tests/stages/nims/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/stages/nims/test_chart_extraction.py b/tests/stages/nims/test_chart_extraction.py
new file mode 100644
index 00000000..8b1d3c5b
--- /dev/null
+++ b/tests/stages/nims/test_chart_extraction.py
@@ -0,0 +1,279 @@
+import pytest
+import pandas as pd
+from unittest.mock import Mock, patch
+from nv_ingest.stages.nim.chart_extraction import _update_metadata
+from nv_ingest.stages.nim.chart_extraction import _extract_chart_data
+import requests
+
+MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction"
+
+
+# Sample data for testing
+@pytest.fixture
+def base64_encoded_image():
+    # Create a simple image and encode it to base64
+    from PIL import Image
+    from io import BytesIO
+    import base64
+
+    img = Image.new('RGB', (64, 64), color='white')
+    buffered = BytesIO()
+    img.save(buffered, format="PNG")
+    img_bytes = buffered.getvalue()
+    base64_str = base64.b64encode(img_bytes).decode('utf-8')
+    return base64_str
+
+
+@pytest.fixture
+def sample_dataframe(base64_encoded_image):
+    data = {
+        "metadata": [{
+            "content": base64_encoded_image,
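+            # Metadata layout assumed by these tests: the base64 image in
+            # "content", routing hints in "content_metadata", and the slot the
+            # stage writes its result into in "table_metadata.table_content".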
"content_metadata": { + "type": "structured", + "subtype": "chart" + }, + "table_metadata": { + "table_content": "original_content" + } + }] + } + df = pd.DataFrame(data) + return df + + +@pytest.fixture +def dataframe_missing_metadata(): + data = { + "other_data": ["no metadata here"] + } + df = pd.DataFrame(data) + return df + + +@pytest.fixture +def dataframe_non_chart(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "text", # Not "structured" + "subtype": "paragraph" # Not "chart" + }, + "table_metadata": { + "table_content": "original_content" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Common mock fixtures +@pytest.fixture +def mock_clients_and_requests(): + # Dummy clients as dictionaries with 'endpoint_url' and 'headers' + deplot_client = { + 'endpoint_url': 'http://deplot_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + cached_client = { + 'endpoint_url': 'http://cached_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post (successful inference) + mock_response_deplot = Mock() + mock_response_deplot.raise_for_status = Mock() # Does nothing + mock_response_deplot.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': 'deplot_result_content', + 'object': 'string' + }], + 'model': 'deplot', + 'usage': None + } + + mock_response_cached = Mock() + mock_response_cached.raise_for_status = Mock() # Does nothing + mock_response_cached.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': 'cached_result_content', + 'object': 'string' + }], + 'model': 'cached', + 'usage': None + } + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client') as mock_create_client, \ + patch('requests.post') as mock_requests_post: + # Mock create_inference_client to return dummy clients + def side_effect_create_inference_client(endpoints, auth_token, protocol): + if 'deplot' in endpoints[0]: + return deplot_client + elif 'cached' in endpoints[0]: + return cached_client + else: + return None + + mock_create_client.side_effect = side_effect_create_inference_client + + # Mock requests.post to return different responses based on URL + def side_effect_requests_post(url, *args, **kwargs): + if 'deplot' in url: + return mock_response_deplot + elif 'cached' in url: + return mock_response_cached + else: + return Mock() + + mock_requests_post.side_effect = side_effect_requests_post + + yield deplot_client, cached_client, mock_create_client, mock_requests_post + + +@pytest.fixture +def mock_clients_and_requests_failure(): + # Dummy clients as dictionaries with 'endpoint_url' and 'headers' + deplot_client = { + 'endpoint_url': 'http://deplot_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + cached_client = { + 'endpoint_url': 'http://cached_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post to raise an HTTPError + mock_response_failure = Mock() + mock_response_failure.raise_for_status.side_effect = requests.exceptions.HTTPError("Inference error") + mock_response_failure.json.return_value = {} + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client') as mock_create_client, \ + patch('requests.post', return_value=mock_response_failure) as mock_requests_post: + # Mock create_inference_client to return dummy clients 
+ def side_effect_create_inference_client(endpoints, auth_token, protocol): + if 'deplot' in endpoints[0]: + return deplot_client + elif 'cached' in endpoints[0]: + return cached_client + else: + return None + + mock_create_client.side_effect = side_effect_create_inference_client + + yield deplot_client, cached_client, mock_create_client, mock_requests_post + + +# Tests for _update_metadata +def test_update_metadata_missing_metadata(dataframe_missing_metadata, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = dataframe_missing_metadata.iloc[0] + trace_info = {} + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _update_metadata(row, cached_client, deplot_client, trace_info) + + +def test_update_metadata_non_chart_content(dataframe_non_chart, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = dataframe_non_chart.iloc[0] + trace_info = {} + result = _update_metadata(row, cached_client, deplot_client, trace_info) + # The metadata should remain unchanged + assert result == row["metadata"] + + +@pytest.mark.xfail +def test_update_metadata_successful_update(sample_dataframe, mock_clients_and_requests): + deplot_client, cached_client, _, _ = mock_clients_and_requests + + row = sample_dataframe.iloc[0] + trace_info = {} + result = _update_metadata(row, cached_client, deplot_client, trace_info) + # The table_content should be updated with combined result + expected_content = 'Combined content: cached_result_content + deplot_result_content' + assert result["table_metadata"]["table_content"] == expected_content + + +@pytest.mark.xfail +def test_update_metadata_inference_failure(sample_dataframe, mock_clients_and_requests_failure): + deplot_client, cached_client, _, mock_requests_post = mock_clients_and_requests_failure + + row = sample_dataframe.iloc[0] + trace_info = {} + + with pytest.raises(RuntimeError, match="An error occurred during inference: Inference error"): + _update_metadata(row, cached_client, deplot_client, trace_info) + + # Verify that requests.post was called and raised an exception + assert mock_requests_post.call_count >= 1 # At least one call failed + + +@pytest.mark.xfail +def test_extract_chart_data_successful(sample_dataframe, mock_clients_and_requests): + deplot_client, cached_client, mock_create_client, mock_requests_post = mock_clients_and_requests + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + updated_df, trace_info_out = _extract_chart_data(sample_dataframe, {}, validated_config, trace_info) + + # Expected content from the combined results + expected_content = 'Combined content: cached_result_content + deplot_result_content' + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == expected_content + assert trace_info_out == trace_info + + # Verify that the mocked methods were called + assert mock_create_client.call_count == 2 # deplot and cached clients created + assert mock_requests_post.call_count == 2 # deplot and cached inference called + + +def test_extract_chart_data_missing_metadata(dataframe_missing_metadata, mock_clients_and_requests): + 
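+    # The stage_config values below are placeholders; only their presence
+    # matters here, since create_inference_client and requests.post are mocked.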
deplot_client, cached_client, _, _ = mock_clients_and_requests + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _extract_chart_data(dataframe_missing_metadata, {}, validated_config, trace_info) + + +@pytest.mark.xfail +def test_extract_chart_data_inference_failure(sample_dataframe, mock_clients_and_requests_failure): + deplot_client, cached_client, mock_create_client, mock_requests_post = mock_clients_and_requests_failure + + validated_config = Mock() + validated_config.stage_config.deplot_endpoints = ("http://deplot_endpoint", None) + validated_config.stage_config.cached_endpoints = ("http://cached_endpoint", None) + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.deplot_infer_protocol = "mock_protocol" + validated_config.stage_config.cached_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(RuntimeError, match="An error occurred during inference: Inference error"): + _extract_chart_data(sample_dataframe, {}, validated_config, trace_info) + + # Verify that the mocked methods were called + assert mock_create_client.call_count == 2 + assert mock_requests_post.call_count >= 1 # At least one call failed diff --git a/tests/stages/nims/test_table_extraction.py b/tests/stages/nims/test_table_extraction.py new file mode 100644 index 00000000..6af24097 --- /dev/null +++ b/tests/stages/nims/test_table_extraction.py @@ -0,0 +1,339 @@ +import pytest +import pandas as pd +import base64 +import requests +from unittest.mock import Mock, patch +from io import BytesIO +from PIL import Image +from nv_ingest.stages.nim.table_extraction import _update_metadata, _extract_table_data + +# Constants for minimum image size +PADDLE_MIN_WIDTH = 32 +PADDLE_MIN_HEIGHT = 32 + +MODULE_UNDER_TEST = "nv_ingest.stages.nim.table_extraction" + + +# Fixture for common mock setup +@pytest.fixture +def mock_paddle_client_and_requests(): + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost'), + 'object': 'string' + }], + 'model': 'paddleocr', + 'usage': None + } + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', return_value=paddle_client) as mock_create_client, \ + patch('requests.post', return_value=mock_response) as mock_requests_post: + yield paddle_client, mock_create_client, mock_requests_post + + +# Fixture for common mock setup (inference failure) +@pytest.fixture +def mock_paddle_client_and_requests_failure(): + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + 
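+    # (A plain dict stands in for the real client handle; these tests only
+    # ever read its 'endpoint_url' and 'headers' keys.)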
paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + + # Mock response for requests.post to raise an HTTPError + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Inference error") + mock_response.json.return_value = {} + + # Patching create_inference_client and requests.post + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', return_value=paddle_client) as mock_create_client, \ + patch('requests.post', return_value=mock_response) as mock_requests_post: + yield paddle_client, mock_create_client, mock_requests_post + + +# Fixture to create a sample image and encode it in base64 +@pytest.fixture +def base64_encoded_image(): + # Create a simple image using PIL + img = Image.new('RGB', (64, 64), color='white') + buffered = BytesIO() + img.save(buffered, format="PNG") + img_bytes = buffered.getvalue() + # Encode the image to base64 + base64_str = base64.b64encode(img_bytes).decode('utf-8') + return base64_str + + +# Fixture for a small image (below minimum size) +@pytest.fixture +def base64_encoded_small_image(): + img = Image.new('RGB', (16, 16), color='white') # Smaller than minimum size + buffered = BytesIO() + img.save(buffered, format="PNG") + img_bytes = buffered.getvalue() + base64_str = base64.b64encode(img_bytes).decode('utf-8') + return base64_str + + +# Fixture for a sample DataFrame +@pytest.fixture +def sample_dataframe(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Fixture for DataFrame with missing metadata +@pytest.fixture +def dataframe_missing_metadata(): + data = { + "other_data": ["no metadata here"] + } + df = pd.DataFrame(data) + return df + + +# Fixture for DataFrame where content_metadata doesn't meet conditions +@pytest.fixture +def dataframe_non_table(base64_encoded_image): + data = { + "metadata": [{ + "content": base64_encoded_image, + "content_metadata": { + "type": "text", # Not "structured" + "subtype": "paragraph" # Not "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + return df + + +# Dummy paddle client that simulates the external service +class DummyPaddleClient: + def infer(self, *args, **kwargs): + return "{'object': 'list', 'data': [{'index': 0, 'content': 'Chart 1 This chart shows some gadgets, and some very fictitious costs Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 $40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium desk fan Cost', 'object': 'string'}], 'model': 'paddleocr', 'usage': None}" + + def close(self): + pass + + +# Tests for _update_metadata +def test_update_metadata_missing_metadata(): + row = pd.Series({ + "other_data": "not metadata" + }) + paddle_client = DummyPaddleClient() + trace_info = {} + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _update_metadata(row, paddle_client, trace_info) + + +def test_update_metadata_non_table_content(dataframe_non_table): + row = dataframe_non_table.iloc[0] + paddle_client = DummyPaddleClient() + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + # The metadata should remain unchanged + assert result == row["metadata"] + + +def test_update_metadata_image_too_small(base64_encoded_small_image): + row = pd.Series({ + 
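+        # A 16x16 image sits below the 32-pixel minimums declared at the top
+        # of this file, so the stage is expected to skip OCR for this row.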
"metadata": { + "content": base64_encoded_small_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + } + }) + paddle_client = DummyPaddleClient() + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + # Since the image is too small, table_content should remain unchanged + assert result["table_metadata"]["table_content"] == "" + + +def test_update_metadata_successful_update(sample_dataframe, mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + row = sample_dataframe.iloc[0] + trace_info = {} + result = _update_metadata(row, paddle_client, trace_info) + + # Expected content from the mocked response + expected_content = ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost') + + # The table_content should be updated with expected_content + assert result["table_metadata"]["table_content"] == expected_content + + # Verify that requests.post was called + mock_requests_post.assert_called_once() + + +def test_update_metadata_inference_failure(sample_dataframe, mock_paddle_client_and_requests_failure): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests_failure + + row = sample_dataframe.iloc[0] + trace_info = {} + + with pytest.raises(RuntimeError, match="HTTP request failed: Inference error"): + _update_metadata(row, paddle_client, trace_info) + + # Verify that requests.post was called and raised an exception + mock_requests_post.assert_called_once() + + +# Tests for _extract_table_data +def test_extract_table_data_successful(sample_dataframe, mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + updated_df, trace_info_out = _extract_table_data(sample_dataframe, {}, validated_config, trace_info) + + # Expected content from the mocked response + expected_content = ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost') + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == expected_content + assert trace_info_out == trace_info + + # Verify that the mocked methods were called + mock_create_client.assert_called_once() + mock_requests_post.assert_called_once() + + +def test_extract_table_data_missing_metadata(dataframe_missing_metadata, mock_paddle_client_and_requests): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(ValueError, match="Row does not contain 'metadata'."): + _extract_table_data(dataframe_missing_metadata, {}, validated_config, trace_info) + + # Verify that the 
mocked methods were called + mock_create_client.assert_called_once() + # Since metadata is missing, requests.post should not be called + mock_requests_post.assert_not_called() + + +def test_extract_table_data_inference_failure(sample_dataframe, mock_paddle_client_and_requests_failure): + paddle_client, mock_create_client, mock_requests_post = mock_paddle_client_and_requests_failure + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + trace_info = {} + + with pytest.raises(RuntimeError, match="HTTP request failed: Inference error"): + _extract_table_data(sample_dataframe, {}, validated_config, trace_info) + + # Verify that create_inference_client was called + mock_create_client.assert_called_once() + # Verify that requests.post was called and raised an exception + mock_requests_post.assert_called_once() + + +def test_extract_table_data_image_too_small(base64_encoded_small_image): + data = { + "metadata": [{ + "content": base64_encoded_small_image, + "content_metadata": { + "type": "structured", + "subtype": "table" + }, + "table_metadata": { + "table_content": "" + } + }] + } + df = pd.DataFrame(data) + + validated_config = Mock() + validated_config.stage_config.paddle_endpoints = "mock_endpoint" + validated_config.stage_config.auth_token = "mock_token" + validated_config.stage_config.paddle_infer_protocol = "mock_protocol" + + # Dummy client as a dictionary with 'endpoint_url' and 'headers' + paddle_client = { + 'endpoint_url': 'http://mock_endpoint_url', + 'headers': {'Authorization': 'Bearer mock_token'} + } + trace_info = {} + + def mock_create_inference_client(endpoints, auth_token, protocol): + return paddle_client + + # Mock response to simulate requests.post behavior + mock_response = Mock() + mock_response.raise_for_status = Mock() # Does nothing + mock_response.json.return_value = { + 'object': 'list', + 'data': [{ + 'index': 0, + 'content': ('Chart 1 This chart shows some gadgets, and some very fictitious costs ' + 'Gadgets and their cost $160.00 $140.00 $120.00 $100.00 $80.00 $60.00 ' + '$40.00 $20.00 $- Hammer Powerdrill Bluetooth speaker Minifridge Premium ' + 'desk fan Cost'), + 'object': 'string' + }], + 'model': 'paddleocr', + 'usage': None + } + + with patch(f'{MODULE_UNDER_TEST}.create_inference_client', side_effect=mock_create_inference_client), \ + patch('requests.post', return_value=mock_response): + updated_df, _ = _extract_table_data(df, {}, validated_config, trace_info) + + # The table_content should remain unchanged because the image is too small + assert updated_df.loc[0, 'metadata']['table_metadata']['table_content'] == ""
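+
+
+# Illustrative sketch, not exercised by the tests above: the two base64 image
+# fixtures in this file could be built from a single helper along these lines.
+# `_make_base64_png` is a hypothetical name, not part of the module under test.
+def _make_base64_png(width: int, height: int) -> str:
+    """Encode a solid-white PNG of the given dimensions as a base64 string."""
+    buffered = BytesIO()
+    Image.new("RGB", (width, height), color="white").save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+# For example, _make_base64_png(16, 16) reproduces the payload of the
+# base64_encoded_small_image fixture above.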