Skip to content

Commit

Permalink
audio integration (#324)
Browse files Browse the repository at this point in the history
Co-authored-by: Devin Robison <[email protected]>
Co-authored-by: edknv <[email protected]>
Co-authored-by: Edward Kim <[email protected]>
  • Loading branch information
4 people authored Mar 3, 2025
1 parent 959598f commit a89cbbd
Show file tree
Hide file tree
Showing 23 changed files with 1,224 additions and 8 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ RUN apt-get update && apt-get install -y \
bzip2 \
ca-certificates \
curl \
ffmpeg \
libgl1-mesa-glx \
software-properties-common \
wget \
Expand Down
22 changes: 18 additions & 4 deletions client/src/nv_ingest_client/client/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import glob
import logging
import os
import shutil
import tempfile
from tqdm import tqdm
from concurrent.futures import Future
from functools import wraps
from typing import Any, Union
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import fsspec
from nv_ingest_client.client.client import NvIngestClient
Expand All @@ -28,8 +29,11 @@
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.util.util import filter_function_kwargs
from nv_ingest_client.util.milvus import MilvusOperator
from nv_ingest_client.util.util import filter_function_kwargs
from tqdm import tqdm

logger = logging.getLogger(__name__)

DEFAULT_JOB_QUEUE_ID = "morpheus_task_queue"

Expand Down Expand Up @@ -397,7 +401,17 @@ def extract(self, **kwargs: Any) -> "Ingestor":
# Users have to set to True if infographic extraction is required.
extract_infographics = kwargs.pop("extract_infographics", False)

for document_type in self._job_specs.file_types:
for file_type in self._job_specs.file_types:
# Let user override document_type if user explicitly sets document_type.
if "document_type" in kwargs:
document_type = kwargs.pop("document_type")
if document_type != file_type:
logger.warning(
f"User-specified document_type '{document_type}' overrides the inferred type '{file_type}'.",
)
else:
document_type = file_type

extract_task = ExtractTask(
document_type,
extract_tables=extract_tables,
Expand Down
4 changes: 4 additions & 0 deletions client/src/nv_ingest_client/primitives/jobs/job_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from nv_ingest_client.primitives.tasks import Task
from nv_ingest_client.primitives.tasks import ExtractTask
from nv_ingest_client.primitives.tasks.extract import _DEFAULT_EXTRACTOR_MAP
from nv_ingest_client.primitives.tasks.audio_extraction import AudioExtractionTask
from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask
from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask
from nv_ingest_client.primitives.tasks.infographic_extraction import InfographicExtractionTask
Expand Down Expand Up @@ -172,6 +174,8 @@ def add_task(self, task) -> None:
self._tasks.append(ChartExtractionTask())
if isinstance(task, ExtractTask) and (task._extract_infographics is True):
self._tasks.append(InfographicExtractionTask())
if isinstance(task, ExtractTask) and (_DEFAULT_EXTRACTOR_MAP[self._document_type] == "audio"):
self._tasks.append(AudioExtractionTask())


class BatchJobSpec:
Expand Down
90 changes: 90 additions & 0 deletions client/src/nv_ingest_client/primitives/tasks/audio_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
from typing import Dict
from typing import Optional

from pydantic import ConfigDict, BaseModel

from .task_base import Task

logger = logging.getLogger(__name__)


class AudioExtractionSchema(BaseModel):
    """
    Pydantic model listing the properties an audio extraction task may carry.

    Every field is optional and defaults to ``None``. Keys other than the
    fields declared here are rejected at validation time.
    """

    # extra="forbid": unknown keys raise a validation error.
    # protected_namespaces=(): no attribute-name prefixes are reserved.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    # Credentials and connection settings for the audio inference service.
    auth_token: Optional[str] = None
    grpc_endpoint: Optional[str] = None
    http_endpoint: Optional[str] = None
    infer_protocol: Optional[str] = None
    use_ssl: Optional[bool] = None
    ssl_cert: Optional[str] = None


class AudioExtractionTask(Task):
    """
    Client-side task describing an ``audio_data_extract`` request.

    All settings are optional; only the ones that were explicitly provided are
    included in the submitted task properties.
    """

    def __init__(
        self,
        auth_token: Optional[str] = None,
        grpc_endpoint: Optional[str] = None,
        infer_protocol: Optional[str] = None,
        use_ssl: Optional[bool] = None,
        ssl_cert: Optional[str] = None,
        http_endpoint: Optional[str] = None,
    ) -> None:
        """
        Parameters
        ----------
        auth_token : Optional[str]
            Authentication token forwarded with the task.
        grpc_endpoint : Optional[str]
            gRPC endpoint forwarded with the task.
        infer_protocol : Optional[str]
            Inference protocol forwarded with the task.
        use_ssl : Optional[bool]
            SSL flag forwarded with the task. ``None`` means "not specified";
            an explicit ``False`` is forwarded as ``False``.
        ssl_cert : Optional[str]
            SSL certificate forwarded with the task.
        http_endpoint : Optional[str]
            HTTP endpoint forwarded with the task. Declared last so existing
            positional callers keep working.
        """
        super().__init__()

        self._auth_token = auth_token
        self._grpc_endpoint = grpc_endpoint
        self._http_endpoint = http_endpoint
        self._infer_protocol = infer_protocol
        self._use_ssl = use_ssl
        self._ssl_cert = ssl_cert

    def __str__(self) -> str:
        """
        Returns a string with the object's config and run time state
        """
        lines = ["Audio Extraction Task:"]

        # Secrets are never echoed; only their presence is reported.
        if self._auth_token:
            lines.append("  auth_token: [redacted]")
        if self._grpc_endpoint:
            lines.append(f"  grpc_endpoint: {self._grpc_endpoint}")
        if self._http_endpoint:
            lines.append(f"  http_endpoint: {self._http_endpoint}")
        if self._infer_protocol:
            lines.append(f"  infer_protocol: {self._infer_protocol}")
        # Compare against None so an explicit False is still shown.
        if self._use_ssl is not None:
            lines.append(f"  use_ssl: {self._use_ssl}")
        if self._ssl_cert:
            lines.append("  ssl_cert: [redacted]")

        return "\n".join(lines) + "\n"

    def to_dict(self) -> Dict:
        """
        Convert to a dict for submission to redis

        Returns
        -------
        Dict
            ``{"type": "audio_data_extract", "task_properties": {...}}`` where
            ``task_properties`` holds only the settings that were provided.
        """
        task_properties = {}

        if self._auth_token:
            task_properties["auth_token"] = self._auth_token

        if self._grpc_endpoint:
            task_properties["grpc_endpoint"] = self._grpc_endpoint

        if self._http_endpoint:
            task_properties["http_endpoint"] = self._http_endpoint

        if self._infer_protocol:
            task_properties["infer_protocol"] = self._infer_protocol

        # A truthiness test would drop an explicit use_ssl=False, making it
        # indistinguishable from "not specified"; only None means unset.
        if self._use_ssl is not None:
            task_properties["use_ssl"] = self._use_ssl

        if self._ssl_cert:
            task_properties["ssl_cert"] = self._ssl_cert

        return {"type": "audio_data_extract", "task_properties": task_properties}
16 changes: 16 additions & 0 deletions client/src/nv_ingest_client/primitives/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"svg": "image",
"tiff": "image",
"xml": "lxml",
"mp3": "audio",
"wav": "audio",
}

_Type_Extract_Method_PDF = Literal[
Expand All @@ -63,6 +65,8 @@

_Type_Extract_Method_Image = Literal["image"]

_Type_Extract_Method_Audio = Literal["audio"]

_Type_Extract_Method_Map = {
"docx": get_args(_Type_Extract_Method_DOCX),
"jpeg": get_args(_Type_Extract_Method_Image),
Expand All @@ -72,6 +76,8 @@
"pptx": get_args(_Type_Extract_Method_PPTX),
"svg": get_args(_Type_Extract_Method_Image),
"tiff": get_args(_Type_Extract_Method_Image),
"mp3": get_args(_Type_Extract_Method_Audio),
"wav": get_args(_Type_Extract_Method_Audio),
}

_Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium", "nemoretriever_parse"]
Expand Down Expand Up @@ -181,6 +187,7 @@ def __init__(
extract_images: bool = False,
extract_tables: bool = False,
extract_charts: Optional[bool] = None,
extract_audio_params: Optional[Dict[str, Any]] = None,
extract_images_method: _Type_Extract_Images_Method = "group",
extract_images_params: Optional[Dict[str, Any]] = None,
extract_tables_method: _Type_Extract_Tables_Method_PDF = "yolox",
Expand All @@ -194,6 +201,7 @@ def __init__(
super().__init__()

self._document_type = document_type
self._extract_audio_params = extract_audio_params
self._extract_images = extract_images
self._extract_method = extract_method
self._extract_tables = extract_tables
Expand Down Expand Up @@ -230,6 +238,8 @@ def __str__(self) -> str:

if self._extract_images_params:
info += f" extract images params: {self._extract_images_params}\n"
if self._extract_audio_params:
info += f" extract audio params: {self._extract_audio_params}\n"
return info

def to_dict(self) -> Dict:
Expand All @@ -253,6 +263,12 @@ def to_dict(self) -> Dict:
"extract_images_params": self._extract_images_params,
}
)
if self._extract_audio_params:
extract_params.update(
{
"extract_audio_params": self._extract_audio_params,
}
)

task_properties = {
"method": self._extract_method,
Expand Down
4 changes: 4 additions & 0 deletions client/src/nv_ingest_client/util/file_processing/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class DocumentTypeEnum(str, Enum):
svg = "svg"
tiff = "tiff"
txt = "text"
mp3 = "mp3"
wav = "wav"


# Maps MIME types to DocumentTypeEnum
Expand Down Expand Up @@ -64,6 +66,8 @@ class DocumentTypeEnum(str, Enum):
"svg": DocumentTypeEnum.svg,
"tiff": DocumentTypeEnum.tiff,
"txt": DocumentTypeEnum.txt,
"mp3": DocumentTypeEnum.mp3,
"wav": DocumentTypeEnum.wav,
# Add more as needed
}

Expand Down
2 changes: 2 additions & 0 deletions conda/environments/nv_ingest_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- click>=8.1.7
- fastapi>=0.115.6
- fastparquet>=2024.11.0
- ffmpeg-python>=0.2.0
- fsspec>=2024.10.0
- httpx>=0.28.1
- isodate>=0.7.2
Expand Down Expand Up @@ -46,6 +47,7 @@ dependencies:
- pip
- pip:
- llama-index-embeddings-nvidia
- nvidia-riva-client
- opencv-python # conda can't solve our requirement set with py-opencv, so we install it via pip instead
- pymilvus>=2.5.0
- pymilvus[bulk_writer, model]
27 changes: 27 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ services:
profiles:
- vlm

# Riva speech service. Only started when the "audio" compose profile is enabled
# (see `profiles` at the end of this service).
audio:
image: nvcr.io/nvidia/riva/riva-speech:2.18.0
shm_size: 2gb
ports:
# host:container port mappings
- "8021:50051" # grpc
- "8022:50000" # http
user: root
environment:
- MODEL_DEPLOY_KEY=tlt_encode
# The NGC API key is read from RIVA_NGC_API_KEY in the compose environment.
- NGC_CLI_API_KEY=${RIVA_NGC_API_KEY}
- NGC_CLI_ORG=nvidia
- NGC_CLI_TEAM=riva
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
# NOTE(review): reserves the GPU with device id "1" while CUDA_VISIBLE_DEVICES
# above is 0 - confirm the two values are intended to differ.
device_ids: ["1"]
capabilities: [gpu]
runtime: nvidia
# Downloads and deploys the model named below, then starts Riva; because of `&&`,
# start-riva runs only if the download/deploy step succeeds.
command: bash -c "download_and_deploy_ngc_models nvidia/riva/rmir_asr_conformer_en_us_ofl:2.18.0 && start-riva"
profiles:
- audio

nv-ingest-ms-runtime:
image: nvcr.io/nvidia/nemo-microservices/nv-ingest:24.12
build:
Expand All @@ -215,6 +240,8 @@ services:
cap_add:
- sys_nice
environment:
- AUDIO_GRPC_ENDPOINT=audio:50051
- AUDIO_INFER_PROTOCOL=grpc
- CUDA_VISIBLE_DEVICES=-1
- MAX_INGEST_PROCESS_WORKERS=${MAX_PROCESS_WORKERS:-16}
- EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2}
Expand Down
Loading

0 comments on commit a89cbbd

Please sign in to comment.