diff --git a/client/src/nv_ingest_client/cli/util/click.py b/client/src/nv_ingest_client/cli/util/click.py index 6412e851..2da3f03f 100644 --- a/client/src/nv_ingest_client/cli/util/click.py +++ b/client/src/nv_ingest_client/cli/util/click.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 -import glob import json import logging import os @@ -29,6 +28,7 @@ from nv_ingest_client.primitives.tasks.split import SplitTaskSchema from nv_ingest_client.primitives.tasks.store import StoreTaskSchema from nv_ingest_client.primitives.tasks.vdb_upload import VdbUploadTaskSchema +from nv_ingest_client.util.util import generate_matching_files logger = logging.getLogger(__name__) @@ -137,6 +137,9 @@ def click_validate_task(ctx, param, value): else: raise ValueError(f"Unsupported task type: {task_id}") + if new_task_id in validated_tasks: + raise ValueError(f"Duplicate task detected: {new_task_id}") + logger.debug("Adding task: %s", new_task_id) validated_tasks[new_task_id] = new_task except ValueError as e: @@ -190,37 +193,6 @@ def pre_process_dataset(dataset_json: str, shuffle_dataset: bool): return file_source -def _generate_matching_files(file_sources): - """ - Generates a list of file paths that match the given patterns specified in file_sources. - - Parameters - ---------- - file_sources : list of str - A list containing the file source patterns to match against. - - Returns - ------- - generator - A generator yielding paths to files that match the specified patterns. - - Notes - ----- - This function utilizes glob pattern matching to find files that match the specified patterns. - It yields each matching file path, allowing for efficient processing of potentially large - sets of files. 
- """ - - files = [ - file_path - for pattern in file_sources - for file_path in glob.glob(pattern, recursive=True) - if os.path.isfile(file_path) - ] - for file_path in files: - yield file_path - - def click_match_and_validate_files(ctx, param, value): """ Matches and validates files based on the provided file source patterns. @@ -239,7 +211,7 @@ def click_match_and_validate_files(ctx, param, value): if not value: return [] - matching_files = list(_generate_matching_files(value)) + matching_files = list(generate_matching_files(value)) if not matching_files: logger.warning("No files found matching the specified patterns.") return [] diff --git a/client/src/nv_ingest_client/cli/util/processing.py b/client/src/nv_ingest_client/cli/util/processing.py index 746f95f9..132a505c 100644 --- a/client/src/nv_ingest_client/cli/util/processing.py +++ b/client/src/nv_ingest_client/cli/util/processing.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. # All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -import concurrent + import json import logging import os @@ -12,7 +12,8 @@ from concurrent.futures import as_completed from statistics import mean from statistics import median -from typing import Dict, Any +from typing import Any +from typing import Dict from typing import List from typing import Tuple from typing import Type @@ -23,9 +24,7 @@ from tqdm import tqdm from nv_ingest_client.client import NvIngestClient -from nv_ingest_client.primitives import JobSpec -from nv_ingest_client.util.file_processing.extract import extract_file_content -from nv_ingest_client.util.util import check_ingest_result +from nv_ingest_client.util.processing import handle_future_result from nv_ingest_client.util.util import estimate_page_count logger = logging.getLogger(__name__) @@ -131,7 +130,7 @@ def check_schema(schema: Type[BaseModel], options: dict, task_id: str, original_ def report_stage_statistics( - stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float + stage_elapsed_times: defaultdict(list), total_trace_elapsed: float, abs_elapsed: float ) -> None: """ Reports the statistics for each processing stage, including average, median, total time spent, @@ -206,10 +205,10 @@ def report_overall_speed(total_pages_processed: int, start_time_ns: int, total_f def report_statistics( - start_time_ns: int, - stage_elapsed_times: defaultdict, - total_pages_processed: int, - total_files: int, + start_time_ns: int, + stage_elapsed_times: defaultdict, + total_pages_processed: int, + total_files: int, ) -> None: """ Aggregates and reports statistics for the entire processing session. @@ -425,136 +424,20 @@ def save_response_data(response, output_directory): f.write(json.dumps(documents, indent=2)) -def create_job_specs_for_batch(files_batch: List[str], tasks: Dict[str, Any], client: NvIngestClient) -> List[str]: - """ - Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs. 
- - This function takes a batch of files, processes each file to extract its content and type, - creates a job specification (JobSpec) for each file, and adds tasks from the provided task - list. It then submits the jobs to the client and collects their job IDs. - - Parameters - ---------- - files_batch : List[str] - A list of file paths to be processed. Each file is assumed to be in a format compatible - with the `extract_file_content` function, which extracts the file's content and type. - - tasks : Dict[str, Any] - A dictionary of tasks to be added to each job. The keys represent task names, and the - values represent task specifications or configurations. Standard tasks include "split", - "extract", "store", "caption", "dedup", "filter", "embed", and "vdb_upload". - - client : NvIngestClient - An instance of NvIngestClient, which handles the job submission. The `add_job` method of - the client is used to submit each job specification. - - Returns - ------- - Tuple[List[JobSpec], List[str]] - A Tuple containing the list of JobSpecs and list of job IDs corresponding to the submitted jobs. - Each job ID is returned by the client's `add_job` method. - - Raises - ------ - ValueError - If there is an error extracting the file content or type from any of the files, a - ValueError will be logged, and the corresponding file will be skipped. - - Notes - ----- - - The function assumes that a utility function `extract_file_content` is defined elsewhere, - which extracts the content and type from the provided file paths. - - For each file, a `JobSpec` is created with relevant metadata, including document type and - file content. Various tasks are conditionally added based on the provided `tasks` dictionary. - - The job specification includes tracing options with a timestamp (in nanoseconds) for - diagnostic purposes. 
- - Examples - -------- - Suppose you have a batch of files and tasks to process: - - >>> files_batch = ["file1.txt", "file2.pdf"] - >>> tasks = {"split": ..., "extract_txt": ..., "store": ...} - >>> client = NvIngestClient() - >>> job_ids = create_job_specs_for_batch(files_batch, tasks, client) - >>> print(job_ids) - ['job_12345', 'job_67890'] - - In this example, jobs are created and submitted for the files in `files_batch`, with the - tasks in `tasks` being added to each job specification. The returned job IDs are then - printed. - - See Also - -------- - extract_file_content : Function that extracts the content and type of a file. - JobSpec : The class representing a job specification. - NvIngestClient : Client class used to submit jobs to a job processing system. - """ - - job_ids = [] - for file_name in files_batch: - try: - file_content, file_type = extract_file_content(file_name) # Assume these are defined - file_type = file_type.value - except ValueError as ve: - logger.error(f"Error extracting content from {file_name}: {ve}") - continue - - job_spec = JobSpec( - document_type=file_type, - payload=file_content, - source_id=file_name, - source_name=file_name, - extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}}, - ) - - logger.debug(f"Tasks: {tasks.keys()}") - for task in tasks: - logger.debug(f"Task: {task}") - - # TODO(Devin): Formalize this later, don't have time right now. 
- if "split" in tasks: - job_spec.add_task(tasks["split"]) - - if f"extract_{file_type}" in tasks: - job_spec.add_task(tasks[f"extract_{file_type}"]) - - if "store" in tasks: - job_spec.add_task(tasks["store"]) - - if "caption" in tasks: - job_spec.add_task(tasks["caption"]) - - if "dedup" in tasks: - job_spec.add_task(tasks["dedup"]) - - if "filter" in tasks: - job_spec.add_task(tasks["filter"]) - - if "embed" in tasks: - job_spec.add_task(tasks["embed"]) - - if "vdb_upload" in tasks: - job_spec.add_task(tasks["vdb_upload"]) - - job_id = client.add_job(job_spec) - job_ids.append(job_id) - - return job_ids - - def generate_job_batch_for_iteration( - client: Any, - pbar: Any, - files: List[str], - tasks: Dict, - processed: int, - batch_size: int, - retry_job_ids: List[str], - fail_on_error: bool = False + client: Any, + pbar: Any, + files: List[str], + tasks: Dict, + processed: int, + batch_size: int, + retry_job_ids: List[str], + fail_on_error: bool = False, ) -> Tuple[List[str], Dict[str, str], int]: """ - Generates a batch of job specifications for the current iteration of file processing. This function handles retrying failed jobs and creating new jobs for unprocessed files. The job specifications are then submitted for processing. + Generates a batch of job specifications for the current iteration of file processing. + This function handles retrying failed jobs and creating new jobs for unprocessed files. + The job specifications are then submitted for processing. 
Parameters ---------- @@ -599,9 +482,9 @@ def generate_job_batch_for_iteration( if (cur_job_count < batch_size) and (processed < len(files)): new_job_count = min(batch_size - cur_job_count, len(files) - processed) - batch_files = files[processed: processed + new_job_count] # noqa: E203 + batch_files = files[processed : processed + new_job_count] # noqa: E203 - new_job_indices = create_job_specs_for_batch(batch_files, tasks, client) + new_job_indices = client.create_jobs_for_batch(batch_files, tasks) if len(new_job_indices) != new_job_count: missing_jobs = new_job_count - len(new_job_indices) error_msg = f"Missing {missing_jobs} job specs -- this is likely due to bad reads or file corruption" @@ -620,93 +503,14 @@ def generate_job_batch_for_iteration( return job_indices, job_index_map_updates, processed -def handle_future_result( - future: concurrent.futures.Future, - futures_dict: Dict[concurrent.futures.Future, str], -) -> Dict[str, Any]: - """ - Handle the result of a completed future job, process annotations, and save the result. - - This function processes the result of a future, extracts annotations (if any), logs them, - checks the validity of the ingest result, and optionally saves the result to the provided - output directory. If the result indicates a failure, a retry list of job IDs is prepared. - - Parameters - ---------- - future : concurrent.futures.Future - A future object representing an asynchronous job. The result of this job will be - processed once it completes. - - futures_dict : Dict[concurrent.futures.Future, str] - A dictionary mapping future objects to job IDs. The job ID associated with the - provided future is retrieved from this dictionary. - - Returns - ------- - Dict[str, Any] - - Raises - ------ - RuntimeError - If the job result is invalid, this exception is raised with a description of the failure. 
- - Notes - ----- - - The `future.result()` is assumed to return a tuple where the first element is the actual - result (as a dictionary), and the second element (if present) can be ignored. - - Annotations in the result (if any) are logged for debugging purposes. - - The `check_ingest_result` function (assumed to be defined elsewhere) is used to validate - the result. If the result is invalid, a `RuntimeError` is raised. - - The function handles saving the result data to the specified output directory using the - `save_response_data` function. - - Examples - -------- - Suppose we have a future object representing a job, a dictionary of futures to job IDs, - and a directory for saving results: - - >>> future = concurrent.futures.Future() - >>> futures_dict = {future: "job_12345"} - >>> job_id_map = {"job_12345": {...}} - >>> output_directory = "/path/to/save" - >>> result, retry_job_ids = handle_future_result(future, futures_dict, job_id_map, output_directory) - - In this example, the function processes the completed job and saves the result to the - specified directory. If the job fails, it raises a `RuntimeError` and returns a list of - retry job IDs. - - See Also - -------- - check_ingest_result : Function to validate the result of the job. - save_response_data : Function to save the result to a directory. 
- """ - - try: - result, _ = future.result()[0] - if ("annotations" in result) and result["annotations"]: - annotations = result["annotations"] - for key, value in annotations.items(): - logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}") - - failed, description = check_ingest_result(result) - - if failed: - raise RuntimeError(f"{description}") - except Exception as e: - logger.debug(f"Error processing future result: {e}") - raise e - - return result - - def create_and_process_jobs( - files: List[str], - client: NvIngestClient, - tasks: Dict[str, Any], - output_directory: str, - batch_size: int, - timeout: int = 10, - fail_on_error: bool = False, + files: List[str], + client: NvIngestClient, + tasks: Dict[str, Any], + output_directory: str, + batch_size: int, + timeout: int = 10, + fail_on_error: bool = False, ) -> Tuple[int, Dict[str, List[float]], int]: """ Process a list of files, creating and submitting jobs for each file, then fetch and handle the results. @@ -807,7 +611,6 @@ def create_and_process_jobs( futures_dict = client.fetch_job_result_async(job_ids, timeout=timeout, data_only=False) for future in as_completed(futures_dict.keys()): - retry = False job_id = futures_dict[future] source_name = job_id_map[job_id] diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py index 868d147b..45c9d0e9 100644 --- a/client/src/nv_ingest_client/client/client.py +++ b/client/src/nv_ingest_client/client/client.py @@ -5,12 +5,15 @@ # pylint: disable=broad-except +import concurrent.futures import json import logging -import concurrent.futures +import math +import time from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -19,12 +22,16 @@ from typing import Union from nv_ingest_client.message_clients.rest.rest_client import 
RestClient +from nv_ingest_client.primitives import BatchJobSpec from nv_ingest_client.primitives import JobSpec from nv_ingest_client.primitives.jobs import JobState from nv_ingest_client.primitives.jobs import JobStateEnum from nv_ingest_client.primitives.tasks import Task from nv_ingest_client.primitives.tasks import TaskType +from nv_ingest_client.primitives.tasks import is_valid_task_type from nv_ingest_client.primitives.tasks import task_factory +from nv_ingest_client.util.processing import handle_future_result +from nv_ingest_client.util.util import create_job_specs_for_batch logger = logging.getLogger(__name__) @@ -53,13 +60,13 @@ class NvIngestClient: """ def __init__( - self, - message_client_allocator: Callable[..., RestClient] = RestClient, - message_client_hostname: Optional[str] = "localhost", - message_client_port: Optional[int] = 7670, - message_client_kwargs: Optional[Dict] = None, - msg_counter_id: Optional[str] = "nv-ingest-message-id", - worker_pool_size: int = 1, + self, + message_client_allocator: Callable[..., RestClient] = RestClient, + message_client_hostname: Optional[str] = "localhost", + message_client_port: Optional[int] = 7670, + message_client_kwargs: Optional[Dict] = None, + msg_counter_id: Optional[str] = "nv-ingest-message-id", + worker_pool_size: int = 1, ) -> None: """ Initializes the NvIngestClient with a client allocator, REST configuration, a message counter ID, @@ -142,20 +149,20 @@ def _pop_job_state(self, job_index: str) -> JobState: return job_state def _get_and_check_job_state( - self, - job_index: str, - required_state: Union[JobStateEnum, List[JobStateEnum]] = None, + self, + job_index: str, + required_state: Union[JobStateEnum, List[JobStateEnum]] = None, ) -> JobState: if required_state and not isinstance(required_state, list): required_state = [required_state] - job_state = self._job_states.get(job_index) if not job_state: raise ValueError(f"Job with ID {job_index} does not exist in JobStates: {self._job_states}") if 
required_state and (job_state.state not in required_state): raise ValueError( - f"Job with ID {job_state.job_spec.job_id} has invalid state {job_state.state}, expected {required_state}" + f"Job with ID {job_state.job_spec.job_id} has invalid state " + f"{job_state.state}, expected {required_state}" ) return job_state @@ -163,21 +170,35 @@ def _get_and_check_job_state( def job_count(self): return len(self._job_states) - def add_job(self, job_spec: JobSpec = None) -> str: + def _add_single_job(self, job_spec: JobSpec) -> str: job_index = self._generate_job_index() self._job_states[job_index] = JobState(job_spec=job_spec) return job_index + def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> str: + if isinstance(job_spec, JobSpec): + job_index = self._add_single_job(job_spec) + return job_index + elif isinstance(job_spec, BatchJobSpec): + job_indexes = [] + for _, job_specs in job_spec.job_specs.items(): + for job in job_specs: + job_index = self._add_single_job(job) + job_indexes.append(job_index) + return job_indexes + else: + raise ValueError(f"Unexpected type: {type(job_spec)}") + def create_job( - self, - payload: str, - source_id: str, - source_name: str, - document_type: str = None, - tasks: Optional[list] = None, - extended_options: Optional[dict] = None, + self, + payload: str, + source_id: str, + source_name: str, + document_type: str = None, + tasks: Optional[list] = None, + extended_options: Optional[dict] = None, ) -> str: """ Creates a new job with the specified parameters and adds it to the job tracking dictionary. @@ -228,10 +249,10 @@ def add_task(self, job_index: str, task: Task) -> None: job_state.job_spec.add_task(task) def create_task( - self, - job_index: Union[str, int], - task_type: TaskType, - task_params: dict = None, + self, + job_index: Union[str, int], + task_type: TaskType, + task_params: dict = None, ) -> None: """ Creates a task of the specified type with given parameters and associates it with the existing job. 
@@ -311,8 +332,7 @@ def fetch_job_result_cli(self, job_ids: Union[str, List[str]], timeout: float = job_ids = [job_ids] return [self._fetch_job_result(job_id, timeout, data_only) for job_id in job_ids] - - + # Nv-Ingest jobs are often "long running". Therefore after # submission we intermittently check if the job is completed. def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bool = True): @@ -321,21 +341,82 @@ def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bo return [self._fetch_job_result(job_id, timeout, data_only)] except TimeoutError: logger.debug("Job still processing ... aka HTTP 202 received") - + # This is the direct Python approach function for retrieving jobs which handles the timeouts directly # in the function itself instead of expecting the user to handle it themselves - # Note this method only supports fetching a single job result synchronously - def fetch_job_result(self, job_id: str, timeout: float = 100, data_only: bool = True): - # A thread pool executor is a simple approach to performing an action with a timeout + def fetch_job_result( + self, + job_ids: List[str], + timeout: float = 100, + max_retries: Optional[int] = None, + retry_delay: float = 1, + verbose: bool = False, + ) -> List[Tuple[Optional[Dict], str]]: + """ + Fetches job results for multiple job IDs concurrently with individual timeouts and retry logic. + + Args: + job_ids (List[str]): A list of job IDs to fetch results for. + timeout (float): Timeout for each fetch operation, in seconds. + max_retries (int): Maximum number of retries for jobs that are not ready yet. + retry_delay (float): Delay between retry attempts, in seconds. + + Returns: + List[Tuple[Optional[Dict], str]]: A list of tuples containing the job results and job IDs. + If a timeout or error occurs, the result will be None for that job. + + Raises: + ValueError: If there is an error in decoding the job result. 
+ TimeoutError: If the fetch operation times out. + Exception: For all other unexpected issues. + """ + results = [] + + def fetch_with_retries(job_id: str): + retries = 0 + while (max_retries is None) or (retries < max_retries): + try: + # Attempt to fetch the job result + result = self._fetch_job_result(job_id, timeout, data_only=False) + return result, job_id + except Exception as e: + # Check if the error is a retryable error + if "Job is not ready yet. Retry later." in str(e): + if verbose: + logger.info( + f"Job {job_id} is not ready. " + f"Retrying {retries + 1}/{max_retries if max_retries else '∞'} " + f"after {retry_delay} seconds." + ) + retries += 1 + time.sleep(retry_delay) # Wait before retrying + else: + # For any other error, log and break out of the retry loop + logger.error(f"Error while fetching result for job ID {job_id}: {e}") + return None, job_id + logger.error(f"Max retries exceeded for job {job_id}.") + return None, job_id + + # Use ThreadPoolExecutor to fetch results concurrently with ThreadPoolExecutor() as executor: - future = executor.submit(self._fetch_job_result_wait, job_id, timeout, data_only) - try: - # Wait for the result within the specified timeout - return future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - # Raise a standard Python TimeoutError which the client will be expecting - raise TimeoutError(f"Job processing did not complete within the specified {timeout} seconds") - + futures = {executor.submit(fetch_with_retries, job_id): job_id for job_id in job_ids} + + # Collect results as futures complete + for future in as_completed(futures): + job_id = futures[future] + try: + result = handle_future_result(future, futures, timeout) + results.append(result.get("data")) + except concurrent.futures.TimeoutError: + logger.error(f"Timeout while fetching result for job ID {job_id}") + except json.JSONDecodeError as e: + logger.error(f"Decoding while processing job ID {job_id}: {e}") + except RuntimeError as e: + 
logger.error(f"Error while processing job ID {job_id}: {e}") + except Exception as e: + logger.error(f"Error while fetching result for job ID {job_id}: {e}") + + return results def _ensure_submitted(self, job_ids: List[str]): if isinstance(job_ids, str): @@ -356,7 +437,7 @@ def _ensure_submitted(self, job_ids: List[str]): job_state.future = None def fetch_job_result_async( - self, job_ids: Union[str, List[str]], timeout: float = 10, data_only: bool = True + self, job_ids: Union[str, List[str]], timeout: float = 10, data_only: bool = True ) -> Dict[Future, str]: """ Fetches job results for a list or a single job ID asynchronously and returns a mapping of futures to job IDs. @@ -386,9 +467,9 @@ def fetch_job_result_async( return future_to_job_id def _submit_job( - self, - job_index: str, - job_queue_id: str, + self, + job_index: str, + job_queue_id: str, ) -> Optional[Dict]: """ Submits a job to a specified job queue and optionally waits for a response if blocking is True. @@ -425,18 +506,43 @@ def _submit_job( # Free up memory -- payload should never be used again, and we don't want to keep it around. 
job_state.job_spec.payload = None - + return x_trace_id except Exception as err: logger.error(f"Failed to submit job {job_index} to queue {job_queue_id}: {err}") job_state.state = JobStateEnum.FAILED raise - def submit_job(self, job_indices: Union[str, List[str]], job_queue_id: str) -> List[Union[Dict, None]]: + def submit_job( + self, job_indices: Union[str, List[str]], job_queue_id: str, batch_size: int = 10 + ) -> List[Union[Dict, None]]: if isinstance(job_indices, str): job_indices = [job_indices] - return [self._submit_job(job_id, job_queue_id) for job_id in job_indices] + results = [] + total_batches = math.ceil(len(job_indices) / batch_size) + + submission_errors = [] + for batch_num in range(total_batches): + batch_start = batch_num * batch_size + batch_end = batch_start + batch_size + batch = job_indices[batch_start:batch_end] + + # Submit each batch of jobs + for job_id in batch: + try: + x_trace_id = self._submit_job(job_id, job_queue_id) + except Exception as e: # Even if one fails, we should continue with the rest of the batch. + submission_errors.append(e) + continue + results.append(x_trace_id) + + if submission_errors: + error_msg = str(submission_errors[0]) + if len(submission_errors) > 1: + error_msg += f"... [{len(submission_errors) - 1} more messages truncated]" + raise type(submission_errors[0])(error_msg) + return results def submit_job_async(self, job_indices: Union[str, List[str]], job_queue_id: str) -> Dict[Future, str]: """ @@ -475,3 +581,95 @@ def submit_job_async(self, job_indices: Union[str, List[str]], job_queue_id: str future_to_job_index[future] = job_index return future_to_job_index + + def create_jobs_for_batch(self, files_batch: List[str], tasks: Dict[str, Any]) -> List[JobSpec]: + """ + Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs. 
+ This function takes a batch of files, processes each file to extract its content and type, + creates a job specification (JobSpec) for each file, and adds tasks from the provided task + list. It then submits the jobs to the client and collects their job IDs. + + Parameters + ---------- + files_batch : List[str] + A list of file paths to be processed. Each file is assumed to be in a format compatible + with the `extract_file_content` function, which extracts the file's content and type. + tasks : Dict[str, Any] + A dictionary of tasks to be added to each job. The keys represent task names, and the + values represent task specifications or configurations. Standard tasks include "split", + "extract", "store", "caption", "dedup", "filter", "embed", and "vdb_upload". + + Returns + ------- + List[str] + A list of job IDs corresponding to the submitted jobs. + Each job ID is returned by the client's `add_job` method. + + Raises + ------ + ValueError + If there is an error extracting the file content or type from any of the files, a + ValueError will be logged, and the corresponding file will be skipped. + + Notes + ----- + - The function assumes that a utility function `extract_file_content` is defined elsewhere, + which extracts the content and type from the provided file paths. + - For each file, a `JobSpec` is created with relevant metadata, including document type and + file content. Various tasks are conditionally added based on the provided `tasks` dictionary. + - The job specification includes tracing options with a timestamp (in nanoseconds) for + diagnostic purposes.
+ + Examples + -------- + Suppose you have a batch of files and tasks to process: + >>> files_batch = ["file1.txt", "file2.pdf"] + >>> tasks = {"split": ..., "extract_txt": ..., "store": ...} + >>> client = NvIngestClient() + >>> job_ids = client.create_jobs_for_batch(files_batch, tasks) + >>> print(job_ids) + ['job_12345', 'job_67890'] + + In this example, jobs are created and submitted for the files in `files_batch`, with the + tasks in `tasks` being added to each job specification. The returned job IDs are then + printed. + + See Also + -------- + create_job_specs_for_batch: Function that creates job specifications for a batch of files. + JobSpec : The class representing a job specification. + """ + if not isinstance(tasks, dict): + raise ValueError("`tasks` must be a dictionary of task names -> task specifications.") + + job_specs = create_job_specs_for_batch(files_batch) + + job_ids = [] + for job_spec in job_specs: + logger.debug(f"Tasks: {tasks.keys()}") + for task in tasks: + logger.debug(f"Task: {task}") + + file_type = job_spec.document_type + + seen_tasks = set() # For tracking tasks and rejecting duplicate tasks.
+ + for task_name, task_config in tasks.items(): + if task_name.lower().startswith("extract_"): + task_file_type = task_name.split("_", 1)[1] + if file_type.lower() != task_file_type.lower(): + continue + elif not is_valid_task_type(task_name.upper()): + raise ValueError(f"Invalid task type: '{task_name}'") + + if str(task_config) in seen_tasks: + raise ValueError(f"Duplicate task detected: {task_name} with config {task_config}") + + job_spec.add_task(task_config) + + seen_tasks.add(str(task_config)) + + job_id = self.add_job(job_spec) + job_ids.append(job_id) + + return job_ids diff --git a/client/src/nv_ingest_client/nv_ingest_cli.py b/client/src/nv_ingest_client/nv_ingest_cli.py index c6afe067..1c382227 100644 --- a/client/src/nv_ingest_client/nv_ingest_cli.py +++ b/client/src/nv_ingest_client/nv_ingest_cli.py @@ -17,14 +17,14 @@ from nv_ingest_client.cli.util.click import click_validate_batch_size from nv_ingest_client.cli.util.click import click_validate_file_exists from nv_ingest_client.cli.util.click import click_validate_task -from nv_ingest_client.cli.util.dataset import get_dataset_files -from nv_ingest_client.cli.util.dataset import get_dataset_statistics from nv_ingest_client.cli.util.processing import create_and_process_jobs from nv_ingest_client.cli.util.processing import report_statistics from nv_ingest_client.cli.util.system import configure_logging from nv_ingest_client.cli.util.system import ensure_directory_with_permissions from nv_ingest_client.client import NvIngestClient from nv_ingest_client.message_clients.rest.rest_client import RestClient +from nv_ingest_client.util.dataset import get_dataset_files +from nv_ingest_client.util.dataset import get_dataset_statistics from pkg_resources import DistributionNotFound from pkg_resources import VersionConflict diff --git a/client/src/nv_ingest_client/primitives/__init__.py b/client/src/nv_ingest_client/primitives/__init__.py index b4834304..e081645f 100644 --- 
a/client/src/nv_ingest_client/primitives/__init__.py +++ b/client/src/nv_ingest_client/primitives/__init__.py @@ -2,7 +2,8 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from .jobs import BatchJobSpec from .jobs import JobSpec from .tasks import Task -__all__ = ["JobSpec", "Task"] +__all__ = ["BatchJobSpec", "JobSpec", "Task"] diff --git a/client/src/nv_ingest_client/primitives/jobs/__init__.py b/client/src/nv_ingest_client/primitives/jobs/__init__.py index 7d8b481a..ecd714f9 100644 --- a/client/src/nv_ingest_client/primitives/jobs/__init__.py +++ b/client/src/nv_ingest_client/primitives/jobs/__init__.py @@ -2,8 +2,9 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from .job_spec import JobSpec -from .job_state import JobState -from .job_state import JobStateEnum +from nv_ingest_client.primitives.jobs.job_spec import BatchJobSpec +from nv_ingest_client.primitives.jobs.job_spec import JobSpec +from nv_ingest_client.primitives.jobs.job_state import JobState +from nv_ingest_client.primitives.jobs.job_state import JobStateEnum -__all__ = ["JobSpec", "JobState", "JobStateEnum"] +__all__ = ["BatchJobSpec", "JobSpec", "JobState", "JobStateEnum"] diff --git a/client/src/nv_ingest_client/primitives/jobs/job_spec.py b/client/src/nv_ingest_client/primitives/jobs/job_spec.py index 8693ecd1..aba0c759 100644 --- a/client/src/nv_ingest_client/primitives/jobs/job_spec.py +++ b/client/src/nv_ingest_client/primitives/jobs/job_spec.py @@ -4,12 +4,17 @@ import logging +from collections import defaultdict +from io import BytesIO from typing import Dict from typing import List from typing import Optional +from typing import Union from uuid import UUID from nv_ingest_client.primitives.tasks import Task +from nv_ingest_client.util.dataset import get_dataset_files +from nv_ingest_client.util.dataset import get_dataset_statistics logger = logging.getLogger(__name__) @@ -53,13 +58,13 @@ class JobSpec: """ def __init__( - self, - payload: str = None, - tasks: 
Optional[List] = None, - source_id: Optional[str] = None, - source_name: Optional[str] = None, - document_type: Optional[str] = None, - extended_options: Optional[Dict] = None, + self, + payload: str = None, + tasks: Optional[List] = None, + source_id: Optional[str] = None, + source_name: Optional[str] = None, + document_type: Optional[str] = None, + extended_options: Optional[Dict] = None, ) -> None: self._document_type = document_type or "txt" self._extended_options = extended_options or {} @@ -134,6 +139,10 @@ def source_name(self) -> str: def source_name(self, source_name: str) -> None: self._source_name = source_name + @property + def document_type(self) -> str: + return self._document_type + def add_task(self, task) -> None: """ Adds a task to the job specification. @@ -152,3 +161,210 @@ def add_task(self, task) -> None: raise ValueError("Task must derive from nv_ingest_client.primitives.Task class") self._tasks.append(task) + + +class BatchJobSpec: + """ + A class used to represent a batch of job specifications (JobSpecs). + + This class allows for batch processing of multiple jobs, either from a list of + `JobSpec` instances or from file paths. It provides methods for adding job + specifications, associating tasks with those specifications, and serializing the + batch to a dictionary format. + + Attributes + ---------- + _file_type_to_job_spec : defaultdict + A dictionary that maps document types to a list of `JobSpec` instances. + """ + + def __init__(self, job_specs_or_files: Optional[Union[List[JobSpec], List[str]]] = None) -> None: + """ + Initializes the BatchJobSpec instance. + + Parameters + ---------- + job_specs_or_files : Optional[Union[List[JobSpec], List[str]]], optional + A list of either `JobSpec` instances or file paths. If provided, the + instance will be initialized with the corresponding job specifications. 
+ """ + self._file_type_to_job_spec = defaultdict(list) + + if job_specs_or_files: + if isinstance(job_specs_or_files[0], JobSpec): + self.from_job_specs(job_specs_or_files) + elif isinstance(job_specs_or_files[0], str): + self.from_files(job_specs_or_files) + else: + raise ValueError("Invalid input type for job_specs. Must be a list of JobSpec or file paths.") + + def from_job_specs(self, job_specs: Union[JobSpec, List[JobSpec]]) -> None: + """ + Initializes the batch with a list of `JobSpec` instances. + + Parameters + ---------- + job_specs : Union[JobSpec, List[JobSpec]] + A single `JobSpec` or a list of `JobSpec` instances to add to the batch. + """ + if isinstance(job_specs, JobSpec): + job_specs = [job_specs] + + for job_spec in job_specs: + self.add_job_spec(job_spec) + + def from_files(self, files: Union[str, List[str]]) -> None: + """ + Initializes the batch by generating job specifications from file paths. + + Parameters + ---------- + files : Union[str, List[str]] + A single file path or a list of file paths to create job specifications from. + """ + from nv_ingest_client.util.util import create_job_specs_for_batch + from nv_ingest_client.util.util import generate_matching_files + + if isinstance(files, str): + files = [files] + + matching_files = list(generate_matching_files(files)) + if not matching_files: + logger.warning(f"No files found matching {files}.") + return + + job_specs = create_job_specs_for_batch(matching_files) + for job_spec in job_specs: + self.add_job_spec(job_spec) + + def _from_dataset(self, dataset: str, shuffle_dataset: bool = True) -> None: + """ + Internal method to initialize the batch from a dataset. + + Parameters + ---------- + dataset : str + The path to the dataset file. + shuffle_dataset : bool, optional + Whether to shuffle the dataset files before adding them to the batch, by default True. 
+ """ + with open(dataset, "rb") as file: + dataset_bytes = BytesIO(file.read()) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug(get_dataset_statistics(dataset_bytes)) + + dataset_files = get_dataset_files(dataset_bytes, shuffle_dataset) + + self.from_files(dataset_files) + + @classmethod + def from_dataset(cls, dataset: str, shuffle_dataset: bool = True): + """ + Class method to create a `BatchJobSpec` instance from a dataset. + + Parameters + ---------- + dataset : str + The path to the dataset file. + shuffle_dataset : bool, optional + Whether to shuffle the dataset files before adding them to the batch, by default True. + + Returns + ------- + BatchJobSpec + A new instance of `BatchJobSpec` initialized with the dataset files. + """ + batch_job_spec = cls() + batch_job_spec._from_dataset(dataset, shuffle_dataset=shuffle_dataset) + return batch_job_spec + + def add_job_spec(self, job_spec: JobSpec) -> None: + """ + Adds a `JobSpec` to the batch. + + Parameters + ---------- + job_spec : JobSpec + The job specification to add. + """ + self._file_type_to_job_spec[job_spec.document_type].append(job_spec) + + def add_task(self, task, document_type=None): + """ + Adds a task to the relevant job specifications in the batch. + + If a `document_type` is provided, the task will be added to all job specifications + matching that document type. If no `document_type` is provided, the task will be added + to all job specifications in the batch. + + Parameters + ---------- + task : Task + The task to add. Must derive from the `nv_ingest_client.primitives.Task` class. + + document_type : str, optional + The document type used to filter job specifications. If not provided, the + `document_type` is inferred from the task, or the task is applied to all job specifications. + + Raises + ------ + ValueError + If the task does not derive from the `Task` class. 
+ """ + if not isinstance(task, Task): + raise ValueError("Task must derive from nv_ingest_client.primitives.Task class") + + document_type = document_type or task.to_dict().get("document_type") + + if document_type: + target_job_specs = self._file_type_to_job_spec[document_type] + else: + target_job_specs = [] + for job_specs in self._file_type_to_job_spec.values(): + target_job_specs.extend(job_specs) + + for job_spec in target_job_specs: + job_spec.add_task(task) + + def to_dict(self) -> Dict[str, List[Dict]]: + """ + Serializes the batch of job specifications into a list of dictionaries. + + Returns + ------- + List[Dict] + A list of dictionaries representing the job specifications in the batch. + """ + return { + file_type: [j.to_dict() for j in job_specs] for file_type, job_specs in self._file_type_to_job_spec.items() + } + + def __str__(self) -> str: + """ + Returns a string representation of the batch. + + Returns + ------- + str + A string representation of the job specifications in the batch. + """ + result = "" + for file_type, job_specs in self._file_type_to_job_spec.items(): + result += f"{file_type}\n" + for job_spec in job_specs: + result += str(job_spec) + "\n" + + return result + + @property + def job_specs(self) -> Dict[str, List[str]]: + """ + A property that returns a dictionary of job specs categorized by document type. + + Returns + ------- + Dict[str, List[str]] + A dictionary mapping document types to job specifications. 
+ """ + return self._file_type_to_job_spec diff --git a/client/src/nv_ingest_client/primitives/tasks/extract.py b/client/src/nv_ingest_client/primitives/tasks/extract.py index 6c2205f1..36f43d04 100644 --- a/client/src/nv_ingest_client/primitives/tasks/extract.py +++ b/client/src/nv_ingest_client/primitives/tasks/extract.py @@ -246,3 +246,7 @@ def to_dict(self) -> Dict: } task_properties["params"].update(adobe_properties) return {"type": "extract", "task_properties": task_properties} + + @property + def document_type(self): + return self._document_type diff --git a/client/src/nv_ingest_client/primitives/tasks/task_base.py b/client/src/nv_ingest_client/primitives/tasks/task_base.py index 47cee997..aa33d452 100644 --- a/client/src/nv_ingest_client/primitives/tasks/task_base.py +++ b/client/src/nv_ingest_client/primitives/tasks/task_base.py @@ -16,6 +16,7 @@ class TaskType(Enum): CAPTION = auto() + DEDUP = auto() EMBED = auto() EXTRACT = auto() FILTER = auto() @@ -68,7 +69,6 @@ def to_dict(self) -> Dict: return {} - # class ExtractUnstructuredTask(ExtractTask): # """ # Object for document unstructured extraction task diff --git a/client/src/nv_ingest_client/primitives/tasks/task_factory.py b/client/src/nv_ingest_client/primitives/tasks/task_factory.py index faa91e3c..8a4eb64e 100644 --- a/client/src/nv_ingest_client/primitives/tasks/task_factory.py +++ b/client/src/nv_ingest_client/primitives/tasks/task_factory.py @@ -9,6 +9,7 @@ from typing import Union from .caption import CaptionTask +from .dedup import DedupTask from .embed import EmbedTask from .extract import ExtractTask from .filter import FilterTask @@ -33,6 +34,7 @@ def __init__(self, **kwargs) -> None: # Mapping of TaskType to Task classes, arranged alphabetically by task type _TASK_MAP: Dict[TaskType, Callable] = { TaskType.CAPTION: CaptionTask, + TaskType.DEDUP: DedupTask, TaskType.EMBED: EmbedTask, TaskType.EXTRACT: ExtractTask, TaskType.FILTER: FilterTask, diff --git 
a/client/src/nv_ingest_client/cli/util/dataset.py b/client/src/nv_ingest_client/util/dataset.py similarity index 100% rename from client/src/nv_ingest_client/cli/util/dataset.py rename to client/src/nv_ingest_client/util/dataset.py diff --git a/client/src/nv_ingest_client/util/processing.py b/client/src/nv_ingest_client/util/processing.py new file mode 100644 index 00000000..58e945a8 --- /dev/null +++ b/client/src/nv_ingest_client/util/processing.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import concurrent +import json +import logging +from typing import Any +from typing import Dict +from typing import Optional + +from nv_ingest_client.util.util import check_ingest_result + +logger = logging.getLogger(__name__) + + +def handle_future_result( + future: concurrent.futures.Future, + futures_dict: Dict[concurrent.futures.Future, str], + timeout: Optional[int] = None, +) -> Dict[str, Any]: + """ + Handle the result of a completed future job, process annotations, and save the result. + + This function processes the result of a future, extracts annotations (if any), logs them, + checks the validity of the ingest result, and optionally saves the result to the provided + output directory. If the result indicates a failure, a retry list of job IDs is prepared. + + Parameters + ---------- + future : concurrent.futures.Future + A future object representing an asynchronous job. The result of this job will be + processed once it completes. + + futures_dict : Dict[concurrent.futures.Future, str] + A dictionary mapping future objects to job IDs. The job ID associated with the + provided future is retrieved from this dictionary. + + Returns + ------- + Dict[str, Any] + + Raises + ------ + RuntimeError + If the job result is invalid, this exception is raised with a description of the failure. 
+ + Notes + ----- + - The `future.result()` is assumed to return a tuple where the first element is the actual + result (as a dictionary), and the second element (if present) can be ignored. + - Annotations in the result (if any) are logged for debugging purposes. + - The `check_ingest_result` function (assumed to be defined elsewhere) is used to validate + the result. If the result is invalid, a `RuntimeError` is raised. + - The function handles saving the result data to the specified output directory using the + `save_response_data` function. + + Examples + -------- + Suppose we have a future object representing a job, a dictionary of futures to job IDs, + and a directory for saving results: + + >>> future = concurrent.futures.Future() + >>> futures_dict = {future: "job_12345"} + >>> job_id_map = {"job_12345": {...}} + >>> output_directory = "/path/to/save" + >>> result, retry_job_ids = handle_future_result(future, futures_dict, job_id_map, output_directory) + + In this example, the function processes the completed job and saves the result to the + specified directory. If the job fails, it raises a `RuntimeError` and returns a list of + retry job IDs. + + See Also + -------- + check_ingest_result : Function to validate the result of the job. + save_response_data : Function to save the result to a directory. 
+ """ + + try: + result, _ = future.result(timeout=timeout)[0] + if ("annotations" in result) and result["annotations"]: + annotations = result["annotations"] + for key, value in annotations.items(): + logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}") + + failed, description = check_ingest_result(result) + + if failed: + raise RuntimeError(f"{description}") + except Exception as e: + logger.debug(f"Error processing future result: {e}") + raise e + + return result diff --git a/client/src/nv_ingest_client/util/util.py b/client/src/nv_ingest_client/util/util.py index 61ebc5e3..e5832ef1 100644 --- a/client/src/nv_ingest_client/util/util.py +++ b/client/src/nv_ingest_client/util/util.py @@ -2,22 +2,30 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 - +import glob import logging import os +import time import traceback import typing from io import BytesIO from typing import Dict +from typing import List import pypdfium2 as pdfium from docx import Document as DocxDocument +from nv_ingest_client.primitives.jobs.job_spec import JobSpec from nv_ingest_client.util.file_processing.extract import DocumentTypeEnum from nv_ingest_client.util.file_processing.extract import detect_encoding_and_read_text_file from nv_ingest_client.util.file_processing.extract import extract_file_content from nv_ingest_client.util.file_processing.extract import get_or_infer_file_type from pptx import Presentation + +logger = logging.getLogger(__name__) + + + # pylint: disable=invalid-name # pylint: disable=missing-class-docstring # pylint: disable=logging-fstring-interpolation @@ -246,17 +254,116 @@ def check_ingest_result(json_payload: Dict) -> typing.Tuple[bool, str]: logger.debug( f"Checking ingest result:\n Status: {json_payload.get('status', None)}" - f"\n Description: {json_payload.get('description', None)}") + f"\n Description: {json_payload.get('description', None)}" + ) is_failed = json_payload.get("status", "") in "failed" description = 
json_payload.get("description", "") # Look to see if we have any failure annotations to augment the description - if (is_failed and 'annotations' in json_payload): - for annot_id, value in json_payload['annotations'].items(): - if ('task_result' in value and value['task_result'] == "FAILURE"): - message = value.get('message', "Unknown") + if is_failed and "annotations" in json_payload: + for annot_id, value in json_payload["annotations"].items(): + if "task_result" in value and value["task_result"] == "FAILURE": + message = value.get("message", "Unknown") description = f"\n↪ Event that caused this failure: {annot_id} -> {message}" break return is_failed, description + + +def generate_matching_files(file_sources): + """ + Generates a list of file paths that match the given patterns specified in file_sources. + + Parameters + ---------- + file_sources : list of str + A list containing the file source patterns to match against. + + Returns + ------- + generator + A generator yielding paths to files that match the specified patterns. + + Notes + ----- + This function utilizes glob pattern matching to find files that match the specified patterns. + It yields each matching file path, allowing for efficient processing of potentially large + sets of files. + """ + files = [ + file_path + for pattern in file_sources + for file_path in glob.glob(pattern, recursive=True) + if os.path.isfile(file_path) + ] + for file_path in files: + yield file_path + + +def create_job_specs_for_batch(files_batch: List[str]) -> List[JobSpec]: + """ + Create and job specifications (JobSpecs) for a batch of files. + This function takes a batch of files, processes each file to extract its content and type, + creates a job specification (JobSpec) for each file. + + Parameters + ---------- + files_batch : List[str] + A list of file paths to be processed. Each file is assumed to be in a format compatible + with the `extract_file_content` function, which extracts the file's content and type. 
+ + Returns + ------- + List[JobSpec] + A list of JobSpecs. + + Raises + ------ + ValueError + If there is an error extracting the file content or type from any of the files, a + ValueError will be logged, and the corresponding file will be skipped. + + Notes + ----- + - The function assumes that a utility function `extract_file_content` is defined elsewhere, + which extracts the content and type from the provided file paths. + - For each file, a `JobSpec` is created with relevant metadata, including document type and + file content. + - The job specification includes tracing options with a timestamp (in nanoseconds) for + diagnostic purposes. + + Examples + -------- + Suppose you have a batch of files and tasks to process: + + >>> files_batch = ["file1.txt", "file2.pdf"] + >>> client = NvIngestClient() + >>> job_specs = create_job_specs_for_batch(files_batch) + >>> print(job_specs) + [nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, ] + + See Also + -------- + extract_file_content : Function that extracts the content and type of a file. + JobSpec : The class representing a job specification. 
+ """ + job_specs = [] + for file_name in files_batch: + try: + file_content, file_type = extract_file_content(file_name) # Assume these are defined + file_type = file_type.value + except ValueError as ve: + logger.error(f"Error extracting content from {file_name}: {ve}") + continue + + job_spec = JobSpec( + document_type=file_type, + payload=file_content, + source_id=file_name, + source_name=file_name, + extended_options={"tracing_options": {"trace": True, "ts_send": time.time_ns()}}, + ) + job_specs.append(job_spec) + + return job_specs diff --git a/tests/nv_ingest_client/cli/util/test_click.py b/tests/nv_ingest_client/cli/util/test_click.py index 2c9938fd..a3d9d531 100644 --- a/tests/nv_ingest_client/cli/util/test_click.py +++ b/tests/nv_ingest_client/cli/util/test_click.py @@ -8,7 +8,6 @@ import click import pytest -from nv_ingest_client.cli.util.click import _generate_matching_files from nv_ingest_client.cli.util.click import click_match_and_validate_files from nv_ingest_client.cli.util.click import click_validate_batch_size from nv_ingest_client.cli.util.click import click_validate_file_exists @@ -230,29 +229,13 @@ def test_empty_file_list(tmp_path): assert files == [], "Expected an empty list of files" -@pytest.mark.parametrize( - "patterns, mock_files, expected", - [ - (["*.txt"], ["test1.txt", "test2.txt"], ["test1.txt", "test2.txt"]), - (["*.txt"], [], []), - (["*.md"], ["README.md"], ["README.md"]), - (["docs/*.md"], ["docs/README.md", "docs/CHANGES.md"], ["docs/README.md", "docs/CHANGES.md"]), - ], -) -def test_generate_matching_files(patterns, mock_files, expected): - with patch( - "glob.glob", side_effect=lambda pattern, recursive: [f for f in mock_files if f.startswith(pattern[:-5])] - ), patch("os.path.isfile", return_value=True): - assert list(_generate_matching_files(patterns)) == expected - - def test_click_match_and_validate_files_found(): - with patch(f"{_MODULE_UNDER_TEST}._generate_matching_files", return_value=iter(["file1.txt", 
"file2.txt"])): + with patch(f"{_MODULE_UNDER_TEST}.generate_matching_files", return_value=iter(["file1.txt", "file2.txt"])): result = click_match_and_validate_files(None, None, ["*.txt"]) assert result == ["file1.txt", "file2.txt"] def test_click_match_and_validate_files_not_found(): - with patch(f"{_MODULE_UNDER_TEST}._generate_matching_files", return_value=iter([])): + with patch(f"{_MODULE_UNDER_TEST}.generate_matching_files", return_value=iter([])): result = click_match_and_validate_files(None, None, ["*.txt"]) assert result == [] diff --git a/tests/nv_ingest_client/client/test_client.py b/tests/nv_ingest_client/client/test_client.py index 1432f1e6..a4d447e2 100644 --- a/tests/nv_ingest_client/client/test_client.py +++ b/tests/nv_ingest_client/client/test_client.py @@ -9,15 +9,21 @@ from concurrent.futures import Future from concurrent.futures import as_completed from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import mock_open +from unittest.mock import patch import pytest from nv_ingest_client.client import NvIngestClient from nv_ingest_client.primitives.jobs import JobSpec from nv_ingest_client.primitives.jobs import JobState from nv_ingest_client.primitives.jobs import JobStateEnum +from nv_ingest_client.primitives.tasks import ExtractTask from nv_ingest_client.primitives.tasks import SplitTask from nv_ingest_client.primitives.tasks import TaskType +MODULE_UNDER_TEST = "nv_ingest_client.client.client" + class MockClient: def __init__(self, host, port): @@ -525,3 +531,79 @@ def test_futures_reflect_submission_outcome(nv_ingest_client_with_jobs, job_id): # for future in as_completed(futures.keys()): # result = future.result()[0] # assert result[0] == {"result": "success"}, f"The fetched job result for {job_id} should be successful" + + +@pytest.fixture +def mock_create_job_specs_for_batch(): + with patch(f"{MODULE_UNDER_TEST}.create_job_specs_for_batch") as mock_create: + yield mock_create + + +@pytest.fixture +def 
tasks(): + """Fixture for common tasks.""" + return { + "split": SplitTask(), + "extract_pdf": ExtractTask(document_type="pdf"), + } + + +def test_create_jobs_for_batch_success(nv_ingest_client, tasks, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec, mock_job_spec] + + files = ["file1.pdf", "file2.pdf"] + + job_ids = nv_ingest_client.create_jobs_for_batch(files, tasks) + + assert job_ids == ["0", "1"] + + +def test_create_jobs_for_batch_invalid_task(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pdf"] + invalid_tasks = { + "invalid_task": None, + } + + with pytest.raises(ValueError, match="Invalid task type: 'invalid_task'"): + nv_ingest_client.create_jobs_for_batch(files, invalid_tasks) + + +def test_create_jobs_for_batch_duplicate_task(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pdf"] + duplicate_tasks = { + "split": SplitTask(split_by="sentence"), + "store": SplitTask(split_by="sentence"), # Duplicate task + } + + with pytest.raises(ValueError, match="Duplicate task detected"): + nv_ingest_client.create_jobs_for_batch(files, duplicate_tasks) + + +def test_create_jobs_for_batch_extract_mismatch(nv_ingest_client, mock_create_job_specs_for_batch): + mock_job_spec = Mock(spec=JobSpec) + mock_job_spec.document_type = "pptx" + mock_create_job_specs_for_batch.return_value = [mock_job_spec] + + files = ["file1.pptx"] + tasks = { + "split": SplitTask(), + "extract_pdf": ExtractTask(document_type="pdf"), # Mismatch with pptx file + "extract_pptx": ExtractTask(document_type="pptx"), + } + + job_ids = nv_ingest_client.create_jobs_for_batch(files, tasks) + + assert job_ids == ["0"] + + # Check that extract_pdf was NOT called by 
inspecting the mock call list + calls = [call[0][0] for call in mock_job_spec.add_task.call_args_list] + assert tasks["split"] in calls + assert tasks["extract_pptx"] in calls + assert tasks["extract_pdf"] not in calls, "extract_pdf should not have been added" diff --git a/tests/nv_ingest_client/primitives/jobs/test_job_spec.py b/tests/nv_ingest_client/primitives/jobs/test_job_spec.py index bc9fd5ba..668fde3c 100644 --- a/tests/nv_ingest_client/primitives/jobs/test_job_spec.py +++ b/tests/nv_ingest_client/primitives/jobs/test_job_spec.py @@ -2,24 +2,33 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import json +import logging import uuid from typing import Dict +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch import pytest +from nv_ingest_client.primitives.jobs.job_spec import BatchJobSpec from nv_ingest_client.primitives.jobs.job_spec import JobSpec from nv_ingest_client.primitives.tasks import Task +MODULE_UNDER_TEST = "nv_ingest_client.primitives.jobs.job_spec" + # Assuming the Task class has a to_dict method class MockTask(Task): def to_dict(self) -> Dict: - return {"task": "mocktask"} + return {"document_type": "pdf", "task": "mocktask"} # Fixture to create a JobSpec instance @pytest.fixture def job_spec_fixture() -> JobSpec: return JobSpec( + document_type="pdf", payload={"key": "value"}, tasks=[MockTask()], source_id="source123", @@ -28,6 +37,22 @@ def job_spec_fixture() -> JobSpec: ) +def create_json_file(tmp_path, content): + file_path = tmp_path / "dataset.json" + with open(file_path, "w") as f: + json.dump(content, f) + return str(file_path) + + +@pytest.fixture +def dataset(tmp_path): + content = {"sampled_files": ["file1.txt", "file2.txt", "file3.txt"]} + file_path = tmp_path / "dataset.json" + with open(file_path, "w") as f: + json.dump(content, f) + return str(file_path) + + # Test initialization def test_job_spec_initialization(): job_spec = JobSpec( @@ -92,3 +117,120 @@ def 
test_set_properties(): job_spec.source_name = "source456.pdf" assert job_spec.source_name == "source456.pdf" + + +@pytest.fixture +def batch_job_spec_fixture(job_spec_fixture) -> BatchJobSpec: + batch_job_spec = BatchJobSpec() + batch_job_spec.add_job_spec(job_spec_fixture) + return batch_job_spec + + +def test_init_with_job_specs(job_spec_fixture): + batch_job_spec = BatchJobSpec([job_spec_fixture]) + + assert "pdf" in batch_job_spec._file_type_to_job_spec + assert job_spec_fixture in batch_job_spec._file_type_to_job_spec["pdf"] + + +def test_init_with_files(mocker, job_spec_fixture): + mocker.patch("nv_ingest_client.util.util.generate_matching_files", return_value=["file1.pdf"]) + mocker.patch("nv_ingest_client.util.util.create_job_specs_for_batch", return_value=[job_spec_fixture]) + + batch_job_spec = BatchJobSpec(["file1.pdf"]) + + # Verify that the files were processed and job specs were created + assert "pdf" in batch_job_spec._file_type_to_job_spec + assert len(batch_job_spec._file_type_to_job_spec["pdf"]) > 0 + + +def test_add_task_to_specific_document_type(batch_job_spec_fixture): + task = MockTask() + + # Add task to jobs with document_type 'pdf' + batch_job_spec_fixture.add_task(task, document_type="pdf") + + # Assert that the task was added to the JobSpec with document_type 'pdf' + for job_spec in batch_job_spec_fixture._file_type_to_job_spec["pdf"]: + assert task in job_spec._tasks + + +def test_add_task_to_inferred_document_type(batch_job_spec_fixture): + task = MockTask() + + # Add task without specifying document_type, should infer from task's to_dict + batch_job_spec_fixture.add_task(task) + + # Assert that the task was added to the JobSpec with the inferred document_type 'pdf' + for job_spec in batch_job_spec_fixture._file_type_to_job_spec["pdf"]: + assert task in job_spec._tasks + + +def test_add_task_to_all_job_specs(batch_job_spec_fixture): + # Mock a task without a document_type + task = MockTask() + task.to_dict = Mock(return_value={"task": 
"mocktask"}) # No document_type returned + + # Add task without document_type, it should add to all job specs + batch_job_spec_fixture.add_task(task) + + # Assert that the task was added to all job specs in the batch + for job_specs in batch_job_spec_fixture._file_type_to_job_spec.values(): + for job_spec in job_specs: + assert task in job_spec._tasks + + +def test_add_task_raises_value_error_for_invalid_task(batch_job_spec_fixture): + # Create an invalid task that doesn't derive from Task + invalid_task = object() + + # Expect a ValueError when adding an invalid task + with pytest.raises(ValueError, match="Task must derive from nv_ingest_client.primitives.Task class"): + batch_job_spec_fixture.add_task(invalid_task) + + +def test_batch_job_spec_to_dict(batch_job_spec_fixture): + result = batch_job_spec_fixture.to_dict() + + assert isinstance(result, dict) + assert "pdf" in result + assert len(result["pdf"]) > 0 + + +def test_batch_job_spec_str_method(batch_job_spec_fixture): + result = str(batch_job_spec_fixture) + + assert "pdf" in result + assert "source123" in result + + +@patch(f"{MODULE_UNDER_TEST}.get_dataset_files") +@patch(f"{MODULE_UNDER_TEST}.get_dataset_statistics") +@patch(f"{MODULE_UNDER_TEST}.logger") +def test__from_dataset(mock_logger, mock_get_dataset_statistics, mock_get_dataset_files, dataset): + mock_get_dataset_files.return_value = ["file1.txt", "file2.txt", "file3.txt"] + mock_get_dataset_statistics.return_value = "Statistics info" + + batch_job_spec = BatchJobSpec() + + batch_job_spec.from_files = MagicMock() + + batch_job_spec._from_dataset(dataset) + + mock_get_dataset_files.assert_called_once() + + mock_get_dataset_statistics.assert_called_once() + + batch_job_spec.from_files.assert_called_once_with(["file1.txt", "file2.txt", "file3.txt"]) + + if mock_logger.isEnabledFor(logging.DEBUG): + mock_logger.debug.assert_called_once_with("Statistics info") + + +@patch(f"{MODULE_UNDER_TEST}.BatchJobSpec._from_dataset") +def 
test_from_dataset(mock__from_dataset, dataset): + batch_job_spec = BatchJobSpec.from_dataset(dataset, shuffle_dataset=False) + + assert isinstance(batch_job_spec, BatchJobSpec) + + mock__from_dataset.assert_called_once_with(dataset, shuffle_dataset=False) diff --git a/tests/nv_ingest_client/cli/util/test_dataset.py b/tests/nv_ingest_client/util/test_dataset.py similarity index 94% rename from tests/nv_ingest_client/cli/util/test_dataset.py rename to tests/nv_ingest_client/util/test_dataset.py index 65b6e3ec..6edee02a 100644 --- a/tests/nv_ingest_client/cli/util/test_dataset.py +++ b/tests/nv_ingest_client/util/test_dataset.py @@ -6,8 +6,8 @@ from io import BytesIO import pytest -from nv_ingest_client.cli.util.dataset import get_dataset_files -from nv_ingest_client.cli.util.dataset import get_dataset_statistics +from nv_ingest_client.util.dataset import get_dataset_files +from nv_ingest_client.util.dataset import get_dataset_statistics @pytest.fixture diff --git a/tests/nv_ingest_client/util/test_util.py b/tests/nv_ingest_client/util/test_util.py index 1e1dab22..5f421fc4 100644 --- a/tests/nv_ingest_client/util/test_util.py +++ b/tests/nv_ingest_client/util/test_util.py @@ -2,6 +2,10 @@ # All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import patch + +import pytest +from nv_ingest_client.cli.util.click import generate_matching_files _MODULE_UNDER_TEST = "nv_ingest_client.util.util" @@ -66,3 +70,19 @@ # with patch(f"{_MODULE_UNDER_TEST}.os.path.splitext", return_value=("", ".pdf")): # with patch(f"{_MODULE_UNDER_TEST}.fitz.open", side_effect=Exception("Some error")): # assert estimate_page_count(file_path) == 0 + + +@pytest.mark.parametrize( + "patterns, mock_files, expected", + [ + (["*.txt"], ["test1.txt", "test2.txt"], ["test1.txt", "test2.txt"]), + (["*.txt"], [], []), + (["*.md"], ["README.md"], ["README.md"]), + (["docs/*.md"], ["docs/README.md", "docs/CHANGES.md"], ["docs/README.md", "docs/CHANGES.md"]), + ], +) +def test_generate_matching_files(patterns, mock_files, expected): + with patch( + "glob.glob", side_effect=lambda pattern, recursive: [f for f in mock_files if f.startswith(pattern[:-5])] + ), patch("os.path.isfile", return_value=True): + assert list(generate_matching_files(patterns)) == expected