diff --git a/transforms/language/gneissweb_classification/Makefile b/transforms/language/gneissweb_classification/Makefile index f3a088ff33..4e8112adf1 100644 --- a/transforms/language/gneissweb_classification/Makefile +++ b/transforms/language/gneissweb_classification/Makefile @@ -21,8 +21,9 @@ run-cli-sample: $(PYTHON) -m dpk_$(TRANSFORM_NAME).transform_python \ --data_local_config "{ 'input_folder' : 'test-data/input', 'output_folder' : 'output'}" \ --gcls_model_credential "PUT YOUR OWN HUGGINGFACE CREDENTIAL" \ - --gcls_model_file_name "model.bin" \ - --gcls_model_url "facebook/fasttext-language-identification" \ + --gcls_model_file_name "['fasttext_medical.bin']" \ + --gcls_model_url "['ibm-granite/GneissWeb.Med_classifier']"\ + --gcls_output_label_column_name "['label']" \ --gcls_content_column_name "text" run-cli-ray-sample: @@ -31,6 +32,7 @@ run-cli-ray-sample: $(PYTHON) -m dpk_$(TRANSFORM_NAME).ray.transform \ --run_locally True --data_local_config "{ 'input_folder' : 'test-data/input', 'output_folder' : 'output'}" \ --gcls_model_credential "PUT YOUR OWN HUGGINGFACE CREDENTIAL" \ - --gcls_model_file_name "model.bin" \ - --gcls_model_url "facebook/fasttext-language-identification" \ + --gcls_model_file_name "['fasttext_medical.bin']" \ + --gcls_model_url "['ibm-granite/GneissWeb.Med_classifier']"\ + --gcls_output_label_column_name "['label']" \ --gcls_content_column_name "text" diff --git a/transforms/language/gneissweb_classification/README.md b/transforms/language/gneissweb_classification/README.md index 4c22cf033a..dbe73a5cc7 100644 --- a/transforms/language/gneissweb_classification/README.md +++ b/transforms/language/gneissweb_classification/README.md @@ -3,8 +3,17 @@ The Gneissweb Classification transform serves as a simple exemplar to demonstrat of a simple 1:1 transform. 
Please see the set of [transform project conventions](../../README.md#transform-project-conventions) for details on general project conventions, transform configuration, testing and IDE set up. +## Contributors + +- Ran Iwamoto (ran.iwamoto1@ibm.com) + ## Summary -This transform will classify each text with confidence score with fasttext classification model such as [ref](https://huggingface.co/facebook/fasttext-language-identification). +This transform will classify each text with confidence score with fasttext classification model such as: +- [ibm-granite/GneissWeb.Quality_annotator](https://huggingface.co/ibm-granite/GneissWeb.Quality_annotator) +- [ibm-granite/GneissWeb.Sci_classifier](https://huggingface.co/ibm-granite/GneissWeb.Sci_classifier) +- [ibm-granite/GneissWeb.Tech_classifier](https://huggingface.co/ibm-granite/GneissWeb.Tech_classifier) +- [ibm-granite/GneissWeb.Edu_classifier](https://huggingface.co/ibm-granite/GneissWeb.Edu_classifier) +- [ibm-granite/GneissWeb.Med_classifier](https://huggingface.co/ibm-granite/GneissWeb.Med_classifier) ## Configuration and command line Options @@ -13,12 +22,13 @@ configuration for values are as follows: | Configuration Parameters | Default | Description | |------------|----------|--------------| -| gcls_model_credential | _unset_ | specifies the credential you use to get model. This will be huggingface token. [Guide to get huggingface token](https://huggingface.co/docs/hub/security-tokens) | -| gcls_model_file_name | _unset_ | specifies what filename of model you use to get model, like `model.bin` | -| gcls_model_url | _unset_ | specifies url that model locates. For fasttext, this will be repo name of the model, like `facebook/fasttext-language-identification` | +| gcls_model_credential | _unset_ | specifies the credential you use to get models. This will be huggingface token. 
[Guide to get huggingface token](https://huggingface.co/docs/hub/security-tokens) | +| gcls_model_file_name | _unset_ | specifies the file names of the models to retrieve, like [`fasttext_gneissweb_quality_annotator.bin`,`fasttext_science.bin`,`fasttext_technology_computing.bin`,`fasttext_education.bin`,`fasttext_medical.bin`] | +| gcls_model_url | _unset_ | specifies the URLs where the models are located. For fasttext, these are the repo names of the models, like [`ibm-granite/GneissWeb.Quality_annotator`,`ibm-granite/GneissWeb.Sci_classifier`,`ibm-granite/GneissWeb.Tech_classifier`,`ibm-granite/GneissWeb.Edu_classifier`,`ibm-granite/GneissWeb.Med_classifier`] | +| gcls_n_processes | 1 | number of processes. Must be a positive integer | | gcls_content_column_name | `contents` | specifies name of the column containing documents | -| gcls_output_lablel_column_name | `label` | specifies name of the output column to hold predicted classes| -| gcls_output_score_column_name | `score` | specifies name of the output column to hold score of prediction | +| gcls_output_label_column_name | [`label_quality`,`label_sci`,`label_tech`,`label_edu`,`label_med`] | specifies names of the output columns to hold predicted classes | +| gcls_output_score_column_name | [`score_quality`,`score_sci`,`score_tech`,`score_edu`,`score_med`] | specifies names of the output columns to hold scores of predictions | ## Running @@ -28,12 +38,13 @@ the options provided by the [launcher](../../../data-processing-lib/doc/launcher-options.md). The prefix gcls is short name for Gneissweb CLaSsification. ``` - --gcls_model_credential GCLS_MODEL_CREDENTIAL the credential you use to get model. This will be huggingface token. - --gcls_model_file_name GCLS_MODEL_KIND filename of model you use to get model. Currently,like `model.bin` - --gcls_model_url GCLS_MODEL_URL url that model locates. 
For fasttext, this will be repo name of the model, like `facebook/fasttext-language-identification` + --gcls_model_credential GCLS_MODEL_CREDENTIAL the credential you use to get models. This will be huggingface token. + --gcls_model_file_name GCLS_MODEL_KIND file names of the models to retrieve. Currently, like [`fasttext_gneissweb_quality_annotator.bin`,`fasttext_science.bin`,`fasttext_technology_computing.bin`,`fasttext_education.bin`,`fasttext_medical.bin`] + --gcls_model_url GCLS_MODEL_URL URLs where the models are located. For fasttext, these are the repo names of the models, like [`ibm-granite/GneissWeb.Quality_annotator`,`ibm-granite/GneissWeb.Sci_classifier`,`ibm-granite/GneissWeb.Tech_classifier`,`ibm-granite/GneissWeb.Edu_classifier`,`ibm-granite/GneissWeb.Med_classifier`] --gcls_content_column_name GCLS_CONTENT_COLUMN_NAME A name of the column containing documents - --gcls_output_lable_column_name GCLS_OUTPUT_LABEL_COLUMN_NAME Column name to store classification results - --gcls_output_score_column_name GCLS_OUTPUT_SCORE_COLUMN_NAME Column name to store the score of prediction + --gcls_output_label_column_name GCLS_OUTPUT_LABEL_COLUMN_NAME Column names to store classification results, like [`label_quality`,`label_sci`,`label_tech`,`label_edu`,`label_med`] + --gcls_output_score_column_name GCLS_OUTPUT_SCORE_COLUMN_NAME Column names to store the score of prediction, like [`score_quality`,`score_sci`,`score_tech`,`score_edu`,`score_med`] + --gcls_n_processes NUMBER_OF_PROCESSES number of processes, an integer value. A larger value gives better throughput at the cost of higher memory consumption ``` These correspond to the configuration keys described above. 
diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local.py b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local.py index c5de1a4d4a..0ae0c4cc2c 100644 --- a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local.py +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local.py @@ -15,10 +15,13 @@ from data_processing.data_access import DataAccessLocal from dpk_gneissweb_classification.transform import ( ClassificationTransform, - content_column_name_key, - model_credential_key, - model_file_name_key, - model_url_key, + content_column_name_cli_param, + model_credential_cli_param, + model_file_name_cli_param, + model_url_cli_param, + n_processes_cli_param, + output_label_column_name_cli_param, + output_score_column_name_cli_param ) @@ -26,11 +29,16 @@ input_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test-data", "input")) classification_params = { - model_credential_key: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - model_file_name_key: "model.bin", - model_url_key:"facebook/fasttext-language-identification", - content_column_name_key: "text", + model_credential_cli_param: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", + model_file_name_cli_param: ["['fasttext_medical.bin']"], + model_url_cli_param:["['ibm-granite/GneissWeb.Med_classifier']"], + output_label_column_name_cli_param:["['label_med']"], + output_score_column_name_cli_param:["['score']"], + content_column_name_cli_param: "text", + n_processes_cli_param: 1, } + + if __name__ == "__main__": # Here we show how to run outside of the runtime # Create and configure the transform. 
diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local_python.py b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local_python.py index bc2845d9ef..e7d262ac03 100644 --- a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local_python.py +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/local_python.py @@ -20,6 +20,9 @@ model_credential_cli_param, model_file_name_cli_param, model_url_cli_param, + n_processes_cli_param, + output_label_column_name_cli_param, + output_score_column_name_cli_param ) from dpk_gneissweb_classification.transform_python import ClassificationPythonTransformConfiguration @@ -41,9 +44,12 @@ "runtime_code_location": ParamsUtils.convert_to_ast(code_location), # classification params model_credential_cli_param: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - model_file_name_cli_param: "model.bin", - model_url_cli_param: "facebook/fasttext-language-identification", + model_file_name_cli_param:["fasttext_medical.bin"], + model_url_cli_param: ["ibm-granite/GneissWeb.Med_classifier"], + output_label_column_name_cli_param:["label_med"], + output_score_column_name_cli_param:["score"], content_column_name_cli_param: "text", + n_processes_cli_param: 1, } if __name__ == "__main__": # Set the simulated command line args @@ -52,3 +58,4 @@ launcher = PythonTransformLauncher(runtime_config=ClassificationPythonTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() + diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/nlp_parallel.py b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/nlp_parallel.py new file mode 100644 index 0000000000..d80b7937df --- /dev/null +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/nlp_parallel.py @@ -0,0 +1,83 @@ +# (C) Copyright IBM Corp. 2024. 
+# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Any +from functools import partial + +import pyarrow as pa +import multiprocessing + +from data_processing.utils import TransformUtils, get_logger +from dpk_gneissweb_classification.classification_models import ClassificationModel +from dpk_gneissweb_classification.classification_models import ClassificationModelFactory + +logger = get_logger(__name__) + +global_model: ClassificationModel = None + +def init_global_model(url: str, file_name: str, credential: str): + global global_model + global_model = ClassificationModelFactory.create_model(url, file_name, credential) + + +def _process(text_list): + return [global_model.detect_label(r) for r in text_list] + + +def split_lists(text_list: list[str] | tuple[str, ...], num_chunks: int) -> list[list[str]]: + num_rows = len(text_list) + chunk_size = num_rows // num_chunks + + chunks: list[list[str]] = [] + i = 0 + while i < num_chunks: + if i == num_chunks - 1: + remainder = num_rows % num_chunks + else: + remainder = 0 + chunk = text_list[i * chunk_size : i * chunk_size + chunk_size + remainder] + chunks.append(list(chunk)) + i += 1 + + return chunks + + +def get_label_ds_pa_parallel( + table: pa.table, + content_column_name: str, + output_label_column_name: str, + output_score_column_name: str, + n_processes: int = 4, + url: str = None, + file_name: str = None, + credential: str = 
None +) -> tuple[pa.table, dict[str, Any]]: + + table_chunks = split_lists(table[content_column_name].to_pylist(), n_processes) + + with multiprocessing.get_context("spawn").Pool(n_processes, initializer=init_global_model, initargs=(url, file_name, credential)) as p: + pool_results = p.map(_process, table_chunks) + classification_results = [] + for result in pool_results: + classification_results += result + labels, scores = zip(*classification_results) + detected_label = {"label": list(labels), "score": list(scores)} + + stats = pa.table([detected_label["label"]], names=["label"]).group_by("label").aggregate([("label", "count")]) + stats_dict = {} + for batch in stats.to_batches(): + d = batch.to_pydict() + for label, count in zip(d["label"], d["label_count"]): + stats_dict[label] = count + result = TransformUtils.add_column(table=table, name=output_label_column_name, content=detected_label["label"]) + result = TransformUtils.add_column(table=result, name=output_score_column_name, content=detected_label["score"]) + return result, stats_dict diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/local.py b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/local.py index a77a6bc765..ec9194f6f9 100644 --- a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/local.py +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/local.py @@ -23,6 +23,7 @@ model_url_cli_param, output_label_column_name_cli_param, output_score_column_name_cli_param, + n_processes_cli_param ) @@ -49,11 +50,13 @@ "runtime_code_location": ParamsUtils.convert_to_ast(code_location), # classification params model_credential_cli_param: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - model_file_name_cli_param: "model.bin", - model_url_cli_param:"facebook/fasttext-language-identification", + model_file_name_cli_param: ["fasttext_medical.bin"], + 
model_url_cli_param:["ibm-granite/GneissWeb.Med_classifier"], content_column_name_cli_param: "text", - output_label_column_name_cli_param: "ft_label", - output_score_column_name_cli_param: "ft_score", + output_label_column_name_cli_param: ["ft_label"], + output_score_column_name_cli_param: ["ft_score"], + n_processes_cli_param: 1, + } if __name__ == "__main__": # Set the simulated command line args @@ -62,3 +65,5 @@ launcher = RayTransformLauncher(ClassificationRayTransformConfiguration()) # Launch the ray actor(s) to process the input launcher.launch() + + diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/s3.py b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/s3.py index af91ca4c0b..b96740d9cf 100644 --- a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/s3.py +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/ray/s3.py @@ -23,6 +23,7 @@ model_url_cli_param, output_label_column_name_cli_param, output_score_column_name_cli_param, + n_processes_cli_param ) @@ -59,11 +60,12 @@ "runtime_code_location": ParamsUtils.convert_to_ast(code_location), # classification params model_credential_cli_param: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - model_file_name_cli_param: "model.bin", - model_url_cli_param:"facebook/fasttext-language-identification", - content_column_name_cli_param: "text", - output_label_column_name_cli_param: "ft_label", - output_score_column_name_cli_param: "ft_score", + model_file_name_cli_param: ["fasttext_medical.bin"], + model_url_cli_param:["ibm-granite/GneissWeb.Med_classifier"], + content_column_name_cli_param: ["text"], + output_label_column_name_cli_param: ["ft_label"], + output_score_column_name_cli_param: ["ft_score"], + n_processes_cli_param: 1, } sys.argv = ParamsUtils.dict_to_req(d=params) # for arg in sys.argv: diff --git a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/transform.py 
b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/transform.py index 4825d16cdd..0812273cf0 100644 --- a/transforms/language/gneissweb_classification/dpk_gneissweb_classification/transform.py +++ b/transforms/language/gneissweb_classification/dpk_gneissweb_classification/transform.py @@ -15,10 +15,14 @@ import pyarrow as pa +import ast + from data_processing.transform import AbstractTableTransform, TransformConfiguration from data_processing.utils import CLIArgumentProvider, TransformUtils from dpk_gneissweb_classification.classification_models import ClassificationModelFactory, ClassificationModel from dpk_gneissweb_classification.nlp import get_label_ds_pa +from dpk_gneissweb_classification.nlp_parallel import get_label_ds_pa_parallel + short_name = "gcls" @@ -29,16 +33,19 @@ content_column_name_key = "content_column_name" output_label_column_name_key = "output_label_column_name" output_score_column_name_key = "output_score_column_name" +n_processes_key = "n_processes" model_credential_cli_param = f"{cli_prefix}{model_credential_key}" model_file_name_cli_param = f"{cli_prefix}{model_file_name_key}" model_url_cli_param = f"{cli_prefix}{model_url_key}" content_column_name_cli_param = f"{cli_prefix}{content_column_name_key}" output_label_column_name_cli_param = f"{cli_prefix}{output_label_column_name_key}" output_score_column_name_cli_param = f"{cli_prefix}{output_score_column_name_key}" +n_processes_cli_param = f"{cli_prefix}{n_processes_key}" default_content_column_name = "contents" -default_output_label_column_name = "lang" -default_output_score_column_name = "score" +default_output_label_column_name = ["['lang']"] +default_output_score_column_name = ["['score']"] +default_n_processes = 1 class ClassificationTransform(AbstractTableTransform): @@ -61,27 +68,14 @@ def __init__(self, config: dict[str, Any]): # Make sure that the param name corresponds to the name used in apply_input_params method # of ClassificationTransformConfiguration 
class super().__init__(config) - self.nlp_classfication = self._get_nlp_classfication(config) + + self.model_credential = config.get(model_credential_cli_param) + self.model_file_name = ast.literal_eval(config.get(model_file_name_cli_param)[0]) + self.model_url = ast.literal_eval(config.get(model_url_cli_param)[0]) + self.n_processes = config.get(n_processes_cli_param, default_n_processes) self.content_column_name = config.get(content_column_name_cli_param, default_content_column_name) - self.output_label_column_name = config.get(output_label_column_name_cli_param, default_output_label_column_name) - self.output_score_column_name = config.get(output_score_column_name_cli_param, default_output_score_column_name) - - @staticmethod - def _get_nlp_classfication(config) -> ClassificationModel: - nlp_classfication: ClassificationModel - - model_credential = config.get(model_credential_cli_param) - model_file_name = config.get(model_file_name_cli_param) - model_url = config.get(model_url_cli_param) - - if model_credential is None or len(model_credential) == 0: - raise ValueError("model_credential_cli_param is not specified.") - elif model_file_name is None or len(model_credential) == 0: - raise ValueError("model_file_name_cli_param is not specified.") - else: - nlp_classfication = ClassificationModelFactory.create_model(url=model_url, file_name = model_file_name, credential=model_credential) - - return nlp_classfication + self.output_label_column_name = ast.literal_eval(config.get(output_label_column_name_cli_param, default_output_label_column_name)[0]) + self.output_score_column_name = ast.literal_eval(config.get(output_score_column_name_cli_param, default_output_score_column_name)[0]) def transform(self, table: pa.Table, file_name: str | None = None) -> tuple[list[pa.Table], dict[str, Any]]: # pylint:disable=unused-argument """ @@ -90,21 +84,42 @@ def transform(self, table: pa.Table, file_name: str | None = None) -> tuple[list This implementation makes no modifications 
so effectively implements a copy of the input parquet to the output folder, without modification. """ - TransformUtils.validate_columns(table, [self.content_column_name]) - if self.output_label_column_name in table.schema.names: - raise Exception(f"column to store label ({self.output_label_column_name}) already exist") - if self.output_score_column_name in table.schema.names: - raise Exception( - f"column to store score of label ({self.output_score_column_name}) already exist" - ) + + for label_column_name, score_column_name in zip(self.output_label_column_name,self.output_score_column_name): + TransformUtils.validate_columns(table, [self.content_column_name]) + if label_column_name in table.schema.names: + raise Exception(f"column to store label ({label_column_name}) already exist") + if score_column_name in table.schema.names: + raise Exception( + f"column to store score of label ({score_column_name}) already exist" + ) self.logger.debug(f"Transforming one table with {len(table)} rows") - table, stats = get_label_ds_pa( - table, - self.nlp_classfication, - self.content_column_name, - self.output_label_column_name, - self.output_score_column_name, - ) + for url, file_name, label_column_name, score_column_name in zip(self.model_url, self.model_file_name,self.output_label_column_name,self.output_score_column_name): + if self.n_processes <= 1: + nlp_classfication = ClassificationModelFactory.create_model(url=url, file_name=file_name, credential=self.model_credential) + else: + # Suppress memory consumption as the main process does not actually use this model when multiprocessing + nlp_classfication = None + if self.n_processes <= 1: + table, stats = get_label_ds_pa( + table, + nlp_classfication, + self.content_column_name, + label_column_name, + score_column_name, + ) + else: + table, stats = get_label_ds_pa_parallel( + table, + self.content_column_name, + label_column_name, + score_column_name, + self.n_processes, + url, + file_name, + self.model_credential, + ) + 
self.logger.debug(f"Transformed one table with {len(table)} rows") return [table], stats @@ -139,10 +154,17 @@ def add_input_params(self, parser: ArgumentParser) -> None: parser.add_argument( f"--{model_file_name_cli_param}", type=str, + nargs="+", default="", help="filename of model", ) - parser.add_argument(f"--{model_url_cli_param}", help="Url to model") + parser.add_argument( + f"--{model_url_cli_param}", + type=str, + nargs="+", + default="", + help="Url to model" + ) parser.add_argument( f"--{content_column_name_cli_param}", default=default_content_column_name, @@ -151,13 +173,23 @@ def add_input_params(self, parser: ArgumentParser) -> None: parser.add_argument( f"--{output_label_column_name_cli_param}", default=default_output_label_column_name, + type=str, + nargs="+", help="Column name to store label", ) parser.add_argument( f"--{output_score_column_name_cli_param}", default=default_output_score_column_name, + type=str, + nargs="+", help="Column name to store the score", ) + parser.add_argument( + f"--{n_processes_cli_param}", + type=int, + default=default_n_processes, + help="number of processes. 
Must be a positive integer.", + ) def apply_input_params(self, args: Namespace) -> bool: """ diff --git a/transforms/language/gneissweb_classification/gneissweb_classification-ray.ipynb b/transforms/language/gneissweb_classification/gneissweb_classification-ray.ipynb index a22ebae54f..ab1e8ab3ae 100644 --- a/transforms/language/gneissweb_classification/gneissweb_classification-ray.ipynb +++ b/transforms/language/gneissweb_classification/gneissweb_classification-ray.ipynb @@ -10,12 +10,13 @@ "make venv \n", "source venv/bin/activate \n", "pip install jupyterlab\n", + "venv/bin/jupyter lab\n", "```" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "4c45c3c6-e4d7-4e61-8de6-32d61f2ce695", "metadata": {}, "outputs": [], @@ -23,8 +24,7 @@ "%%capture\n", "## This is here as a reference only\n", "# Users and application developers must use the right tag for the latest from pypi\n", - "%pip install 'data-prep-toolkit[ray]'\n", - "%pip install 'data-prep-toolkit-transforms[gneissweb_classification]'" + "%pip install 'data-prep-toolkit-transforms[ray,gneissweb_classification]'" ] }, { @@ -37,12 +37,13 @@ "##### **** Configure the transform parameters. The set of dictionary keys holding DocIDTransform configuration for values are as follows: \n", "| Configuration Parameters | Default | Description |\n", "|------------|----------|--------------|\n", - "| gcls_model_credential | _unset_ | specifies the credential you use to get model. This will be huggingface token. [Guide to get huggingface token](https://huggingface.co/docs/hub/security-tokens) |\n", - "| gcls_model_file_name | _unset_ | specifies what filename of model you use to get model, like `model.bin` |\n", - "| gcls_model_url | _unset_ | specifies url that model locates. For fasttext, this will be repo nme of the model, like `facebook/fasttext-language-identification` |\n", + "| gcls_model_credential | _unset_ | specifies the credential you use to get models. 
This will be huggingface token. [Guide to get huggingface token](https://huggingface.co/docs/hub/security-tokens) |\n", + "| gcls_model_file_name | _unset_ | specifies what filename of models you use to get models, like [`fasttext_medical.bin`] |\n", + "| gcls_model_url | _unset_ | specifies url that models locate. For fasttext, this will be repo name of the models, like [`ibm-granite/GneissWeb.Med_classifier`] |\n", + "| gcls_n_processes | 1 | number of processes. Must be a positive integer |\n", "| gcls_content_column_name | `contents` | specifies name of the column containing documents |\n", - "| gcls_output_label_column_name | `label` | specifies name of the output column to hold predicted classes |\n", - "| gcls_output_score_column_name | `score` | specifies name of the output column to hold score of prediction |" + "| gcls_output_label_column_name | [`label`] | specifies name of the output columns to hold predicted classes |\n", + "| gcls_output_score_column_name | [`score`] | specifies name of the output columns to hold score of prediction |" ] }, { @@ -55,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "9669273a-8fcc-4b40-9b20-8df658e2ab58", "metadata": {}, "outputs": [], @@ -73,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "badafb96-64d2-4bb8-9f3e-b23713fd5c3f", "metadata": {}, "outputs": [ @@ -81,28 +82,35 @@ "name": "stderr", "output_type": "stream", "text": [ - "09:56:06 INFO - parameters are : {'model_credential': 'PUT YOUR OWN HUGGINGFACE CREDENTIAL', 'model_file_name': 'model.bin', 'model_url': 'facebook/fasttext-language-identification', 'content_column_name': 'text', 'output_label_column_name': 'lang', 'output_score_column_name': 'score'}\n", - "09:56:06 INFO - pipeline id pipeline_id\n", - "09:56:06 INFO - code location None\n", - "09:56:06 INFO - number of workers 1 worker options {'num_cpus': 0.8, 'max_restarts': -1}\n", - "09:56:06 INFO - actor creation 
delay 0\n", - "09:56:06 INFO - job details {'job category': 'preprocessing', 'job name': 'gcls', 'job type': 'ray', 'job id': 'job_id'}\n", - "09:56:06 INFO - data factory data_ is using local data access: input_folder - test-data/input output_folder - output\n", - "09:56:06 INFO - data factory data_ max_files -1, n_sample -1\n", - "09:56:06 INFO - data factory data_ Not using data sets, checkpointing False, max files -1, random samples -1, files to use ['.parquet'], files to checkpoint ['.parquet']\n", - "09:56:06 INFO - Running locally\n", - "2025-01-27 09:56:08,919\tINFO worker.py:1777 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:09 INFO - orchestrator started at 2025-01-27 09:56:09\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:09 INFO - Number of files is 3, source profile {'max_file_size': 0.3023223876953125, 'min_file_size': 0.037346839904785156, 'total_file_size': 0.4433746337890625}\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:09 INFO - Cluster resources: {'cpus': 10, 'gpus': 0, 'memory': 28.60002746619284, 'object_store': 2.0}\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:09 INFO - Number of workers - 1 with {'num_cpus': 0.8, 'max_restarts': -1} each\n", - "\u001b[36m(RayTransformFileProcessor pid=97047)\u001b[0m Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:12 INFO - Completed 1 files in 0.004 min\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:12 INFO - Completed 2 files in 0.006 min\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:12 INFO - Completed 2 files (66.667%) in 0.006 min. 
Waiting for completion\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:12 INFO - Completed processing 3 files in 0.008 min\n", - "\u001b[36m(orchestrate pid=97043)\u001b[0m 09:56:12 INFO - done flushing in 0.001 sec\n", - "09:56:22 INFO - Completed execution in 0.26 min, execution result 0\n" + "10:36:20 INFO - parameters are : {'gcls_model_credential': 'PUT YOUR OWN HUGGINGFACE CREDENTIAL', 'gcls_model_file_name': [\"['fasttext_medical.bin']\"], 'gcls_model_url': [\"['ibm-granite/GneissWeb.Med_classifier']\"], 'gcls_content_column_name': 'text', 'gcls_output_label_column_name': [\"['label']\"], 'gcls_output_score_column_name': [\"['score']\"], 'gcls_n_processes': 2}\n", + "10:36:20 INFO - pipeline id pipeline_id\n", + "10:36:20 INFO - code location None\n", + "10:36:20 INFO - number of workers 1 worker options {'num_cpus': 0.8, 'max_restarts': -1}\n", + "10:36:20 INFO - actor creation delay 0\n", + "10:36:20 INFO - job details {'job category': 'preprocessing', 'job name': 'gcls', 'job type': 'ray', 'job id': 'job_id'}\n", + "10:36:20 INFO - data factory data_ is using local data access: input_folder - test-data/input output_folder - output\n", + "10:36:20 INFO - data factory data_ max_files -1, n_sample -1\n", + "10:36:20 INFO - data factory data_ Not using data sets, checkpointing False, max files -1, random samples -1, files to use ['.parquet'], files to checkpoint ['.parquet']\n", + "10:36:20 INFO - Running locally\n", + "2025-02-21 10:36:22,064\tINFO worker.py:1777 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:23 INFO - orchestrator started at 2025-02-21 10:36:23\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:23 INFO - Number of files is 1, source profile {'max_file_size': 0.04273414611816406, 'min_file_size': 0.04273414611816406, 'total_file_size': 0.04273414611816406}\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:23 INFO - Cluster resources: {'cpus': 10, 'gpus': 0, 'memory': 29.46222076471895, 'object_store': 2.0}\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:23 INFO - Number of workers - 1 with {'num_cpus': 0.8, 'max_restarts': -1} each\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:24 INFO - Completed 0 files (0.0%) in 0.0 min. Waiting for completion\n", + "\u001b[36m(RayTransformFileProcessor pid=99535)\u001b[0m Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", + "\u001b[36m(RayTransformFileProcessor pid=99535)\u001b[0m Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:28 INFO - Completed processing 1 files in 0.073 min\n", + "\u001b[36m(orchestrate pid=99531)\u001b[0m 10:36:28 INFO - done flushing in 0.001 sec\n", + "10:36:38 INFO - Completed execution in 0.308 min, execution result 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 134 ms, sys: 115 ms, total: 249 ms\n", + "Wall time: 20 s\n" ] }, { @@ -111,7 +119,7 @@ "0" ] }, - "execution_count": 5, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -121,8 +129,10 @@ "Classification(input_folder= \"test-data/input\",\n", " output_folder= \"output\",\n", " gcls_model_credential= \"PUT YOUR OWN HUGGINGFACE CREDENTIAL\",\n", - " gcls_model_file_name= 
\"model.bin\",\n", - " gcls_model_url= \"facebook/fasttext-language-identification\",\n", + " gcls_model_file_name= [\"fasttext_medical.bin\"],\n", + " gcls_model_url= [\"ibm-granite/GneissWeb.Med_classifier\"],\n", + " gcls_n_processes=2,\n", + " gcls_output_label_column_name=[\"label\"],\n", " run_locally= True,\n", " gcls_content_column_name= \"text\").transform()" ] @@ -137,20 +147,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "id": "7276fe84-6512-4605-ab65-747351e13a7c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['output/test_03.parquet',\n", - " 'output/test_02.parquet',\n", - " 'output/metadata.json',\n", - " 'output/test_01.parquet']" + "['output/metadata.json', 'output/test_01.parquet']" ] }, - "execution_count": 6, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } diff --git a/transforms/language/gneissweb_classification/gneissweb_classification.ipynb b/transforms/language/gneissweb_classification/gneissweb_classification.ipynb index 17a5a2e7b0..d9b0017390 100644 --- a/transforms/language/gneissweb_classification/gneissweb_classification.ipynb +++ b/transforms/language/gneissweb_classification/gneissweb_classification.ipynb @@ -10,6 +10,7 @@ "make venv \n", "source venv/bin/activate \n", "pip install jupyterlab\n", + "venv/bin/jupyter lab\n", "```" ] }, @@ -23,8 +24,7 @@ "%%capture\n", "## This is here as a reference only\n", "# Users and application developers must use the right tag for the latest from pypi\n", - "%pip install data-prep-toolkit\n", - "%pip install 'data-prep-toolkit-transforms[gneissweb_classificationo]'\n", + "%pip install 'data-prep-toolkit-transforms[gneissweb_classification]'\n", "%pip install pandas" ] }, @@ -39,11 +39,12 @@ "| Configuration Parameters | Default | Description |\n", "|------------|----------|--------------|\n", "| gcls_model_credential | _unset_ | specifies the credential you use to get model. This will be huggingface token. 
[Guide to get huggingface token](https://huggingface.co/docs/hub/security-tokens) |\n", - "| gcls_model_file_name | _unset_ | specifies what filename of model you use to get model, like `model.bin` |\n", - "| gcls_model_url | _unset_ | specifies url that model locates. For fasttext, this will be repo nme of the model, like `facebook/fasttext-language-identification` |\n", + "| gcls_model_file_name | _unset_ | specifies what filename of models you use to get models, like [`fasttext_science.bin`] |\n", + "| gcls_model_url | _unset_ | specifies url that models locate. For fasttext, this will be repo name of the models, like [`ibm-granite/GneissWeb.Sci_classifier`] |\n", + "| gcls_n_processes | 1 | number of processes. Must be a positive integer |\n", "| gcls_content_column_name | `contents` | specifies name of the column containing documents |\n", - "| gcls_output_label_column_name | `label` | specifies name of the output column to hold predicted classes |\n", - "| gcls_output_score_column_name | `score` | specifies name of the output column to hold score of prediction |" + "| gcls_output_label_column_name | [`label`] | specifies name of the output columns to hold predicted classes |\n", + "| gcls_output_score_column_name | [`score`] | specifies name of the output columns to hold score of prediction |" ] }, { @@ -74,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "badafb96-64d2-4bb8-9f3e-b23713fd5c3f", "metadata": {}, "outputs": [ @@ -82,21 +83,30 @@ "name": "stderr", "output_type": "stream", "text": [ - "09:52:55 INFO - parameters are : {'model_credential': 'PUT YOUR OWN HUGGINGFACE CREDENTIAL', 'model_file_name': 'model.bin', 'model_url': 'facebook/fasttext-language-identification', 'content_column_name': 'text', 'output_label_column_name': 'lang', 'output_score_column_name': 'score'}\n", - "09:52:55 INFO - pipeline id pipeline_id\n", - "09:52:55 INFO - code location None\n", - "09:52:55 INFO - data factory data_ is using 
local data access: input_folder - test-data/input output_folder - output\n", - "09:52:55 INFO - data factory data_ max_files -1, n_sample -1\n", - "09:52:55 INFO - data factory data_ Not using data sets, checkpointing False, max files -1, random samples -1, files to use ['.parquet'], files to checkpoint ['.parquet']\n", - "09:52:55 INFO - orchestrator gcls started at 2025-01-27 09:52:55\n", - "09:52:55 INFO - Number of files is 3, source profile {'max_file_size': 0.3023223876953125, 'min_file_size': 0.037346839904785156, 'total_file_size': 0.4433746337890625}\n", + "10:36:03 INFO - parameters are : {'gcls_model_credential': 'PUT YOUR OWN HUGGINGFACE CREDENTIAL', 'gcls_model_file_name': [\"['fasttext_gneissweb_quality_annotator.bin', 'fasttext_medical.bin']\"], 'gcls_model_url': [\"['ibm-granite/GneissWeb.Quality_annotator', 'ibm-granite/GneissWeb.Med_classifier']\"], 'gcls_content_column_name': 'text', 'gcls_output_label_column_name': [\"['label_quality', 'label_med']\"], 'gcls_output_score_column_name': [\"['score_quality', 'score_med']\"], 'gcls_n_processes': 2}\n", + "10:36:03 INFO - pipeline id pipeline_id\n", + "10:36:03 INFO - code location None\n", + "10:36:03 INFO - data factory data_ is using local data access: input_folder - test-data/input output_folder - output\n", + "10:36:03 INFO - data factory data_ max_files -1, n_sample -1\n", + "10:36:03 INFO - data factory data_ Not using data sets, checkpointing False, max files -1, random samples -1, files to use ['.parquet'], files to checkpoint ['.parquet']\n", + "10:36:03 INFO - orchestrator gcls started at 2025-02-21 10:36:03\n", + "10:36:03 INFO - Number of files is 1, source profile {'max_file_size': 0.04273414611816406, 'min_file_size': 0.04273414611816406, 'total_file_size': 0.04273414611816406}\n", + "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", + "Warning : `load_model` does not return WordVectorModel or 
SupervisedModel any more, but a `FastText` object which is very similar.\n", + "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n", - "09:52:57 INFO - Completed 1 files (33.33%) in 0.01 min\n", - "09:52:57 INFO - Completed 2 files (66.67%) in 0.011 min\n", - "09:52:57 INFO - Completed 3 files (100.0%) in 0.014 min\n", - "09:52:57 INFO - Done processing 3 files, waiting for flush() completion.\n", - "09:52:57 INFO - done flushing in 0.0 sec\n", - "09:52:57 INFO - Completed execution in 0.029 min, execution result 0\n" + "10:36:16 INFO - Completed 1 files (100.0%) in 0.22 min\n", + "10:36:16 INFO - Done processing 1 files, waiting for flush() completion.\n", + "10:36:16 INFO - done flushing in 0.0 sec\n", + "10:36:16 INFO - Completed execution in 0.22 min, execution result 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 163 ms, sys: 68.1 ms, total: 231 ms\n", + "Wall time: 13.2 s\n" ] }, { @@ -105,7 +115,7 @@ "0" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -115,8 +125,11 @@ "Classification(input_folder= \"test-data/input\",\n", " output_folder= \"output\",\n", " gcls_model_credential= \"PUT YOUR OWN HUGGINGFACE CREDENTIAL\",\n", - " gcls_model_file_name= \"model.bin\",\n", - " gcls_model_url= \"facebook/fasttext-language-identification\",\n", + " gcls_model_file_name= [\"fasttext_gneissweb_quality_annotator.bin\",\"fasttext_medical.bin\"],\n", + " gcls_model_url= [\"ibm-granite/GneissWeb.Quality_annotator\",\"ibm-granite/GneissWeb.Med_classifier\"],\n", + " gcls_n_processes=2,\n", + " gcls_output_label_column_name=[\"label_quality\",\"label_med\"],\n", + " gcls_output_score_column_name=[\"score_quality\",\"score_med\"],\n", " 
gcls_content_column_name= \"text\").transform()" ] }, @@ -130,20 +143,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "7276fe84-6512-4605-ab65-747351e13a7c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['output/test_03.parquet',\n", - " 'output/test_02.parquet',\n", - " 'output/metadata.json',\n", - " 'output/test_01.parquet']" + "['output/metadata.json', 'output/test_01.parquet']" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -155,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "845a75cf-f4a9-467d-87fa-ccbac1c9beb8", "metadata": {}, "outputs": [ @@ -181,125 +191,535 @@ " \n", " \n", " text\n", - " count()\n", - " lang\n", - " score\n", + " id\n", + " dump\n", + " url\n", + " date\n", + " file_path\n", + " language\n", + " language_score\n", + " token_count\n", + " watsonnlp_top_category0\n", + " ...\n", + " avg_grade_level\n", + " mcalpine_eflaw_textstat\n", + " dclm_fasttext_label\n", + " dclm_fasttext_score\n", + " cosmo_10k_edu_fasttext_label\n", + " cosmo_10k_edu_fasttext_score\n", + " label_quality\n", + " score_quality\n", + " label_med\n", + " score_med\n", " \n", " \n", " \n", " \n", " 0\n", - " - Notice of name-email change.doc\n", - " 6\n", + " A staffer sells cars via livestream at a deale...\n", + " <urn:uuid:567e2e87-397a-4119-93e9-d72d59b61f90>\n", + " CC-MAIN-2023-14\n", + " https://peoplesdaily.pdnews.cn/business/vehicl...\n", + " 2023-03-27T23:11:21Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", " en\n", - " 0.858\n", + " 0.967074\n", + " 1239\n", + " automotive\n", + " ...\n", + " 9.436667\n", + " 22.1\n", + " cc\n", + " 0.002249\n", + " cc\n", + " 0.012263\n", + " cc\n", + " 0.987\n", + " cc\n", + " 0.994\n", " \n", " \n", " 1\n", - " - Nov13ENAOnly.doc\n", - " 2\n", - " de\n", - " 0.264\n", + " The May 1st submission deadline may feel like ...\n", + " 
<urn:uuid:3330ddd2-9c19-4da4-8feb-c41d0ba3b65f>\n", + " CC-MAIN-2023-14\n", + " https://performancein.com/news/2019/01/29/all-...\n", + " 2023-03-27T23:08:19Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.944369\n", + " 418\n", + " news and politics\n", + " ...\n", + " 11.670000\n", + " 30.3\n", + " cc\n", + " 0.000050\n", + " cc\n", + " 0.000067\n", + " cc\n", + " 0.999\n", + " cc\n", + " 0.997\n", " \n", " \n", " 2\n", - " - OHIO_C~1.XLS\n", - " 2\n", - " de\n", - " 0.603\n", + " Yes! Cinnamon Oil is a great way to deter mice...\n", + " <urn:uuid:e8d2ac4f-cde2-4c45-afd9-3a19cfb86d4c>\n", + " CC-MAIN-2023-14\n", + " https://peskylittlecritters.com/does-cinnamon-...\n", + " 2023-03-27T23:15:19Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.906198\n", + " 490\n", + " food & drink\n", + " ...\n", + " 8.980000\n", + " 26.2\n", + " cc\n", + " 0.009224\n", + " cc\n", + " 0.021643\n", + " cc\n", + " 0.978\n", + " cc\n", + " 0.844\n", " \n", " \n", " 3\n", - " - Oneok(5-30)final.doc\n", - " 1\n", - " vi\n", - " 0.152\n", + " Rosemary Oil can be used to deter cockroaches....\n", + " <urn:uuid:bd5c2a03-9a9b-43e2-872a-f7123213bea9>\n", + " CC-MAIN-2023-14\n", + " https://peskylittlecritters.com/does-rosemary-...\n", + " 2023-03-27T23:18:25Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.916242\n", + " 513\n", + " science\n", + " ...\n", + " 9.370000\n", + " 23.8\n", + " cc\n", + " 0.007073\n", + " cc\n", + " 0.005885\n", + " cc\n", + " 0.994\n", + " cc\n", + " 0.876\n", " \n", " \n", " 4\n", - " - OpeningBrief.doc\n", - " 6\n", - " ko-Hang\n", - " 0.365\n", + " A cat might have discovered an insect crawling...\n", + " <urn:uuid:1922dc93-9fb8-4775-b147-88c589a7bd65>\n", + " CC-MAIN-2023-14\n", + " https://petcatty.com/why-does-my-cat-stare-at-...\n", + " 2023-03-27T23:28:27Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.967236\n", + " 1172\n", + " 
pets\n", + " ...\n", + " 6.396667\n", + " 20.6\n", + " hq\n", + " 0.960727\n", + " hq\n", + " 0.881134\n", + " hq\n", + " 0.881\n", + " cc\n", + " 0.974\n", " \n", " \n", - " ...\n", + " 5\n", + " A staffer sells cars via livestream at a deale...\n", + " <urn:uuid:567e2e87-397a-4119-93e9-d72d59b61f90>\n", + " CC-MAIN-2023-14\n", + " https://peoplesdaily.pdnews.cn/business/vehicl...\n", + " 2023-03-27T23:11:21Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.967074\n", + " 1239\n", + " automotive\n", " ...\n", + " 9.436667\n", + " 22.1\n", + " cc\n", + " 0.002249\n", + " cc\n", + " 0.012263\n", + " cc\n", + " 0.987\n", + " cc\n", + " 0.994\n", + " \n", + " \n", + " 6\n", + " The May 1st submission deadline may feel like ...\n", + " <urn:uuid:3330ddd2-9c19-4da4-8feb-c41d0ba3b65f>\n", + " CC-MAIN-2023-14\n", + " https://performancein.com/news/2019/01/29/all-...\n", + " 2023-03-27T23:08:19Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.944369\n", + " 418\n", + " news and politics\n", + " ...\n", + " 11.670000\n", + " 30.3\n", + " cc\n", + " 0.000050\n", + " cc\n", + " 0.000067\n", + " cc\n", + " 0.999\n", + " cc\n", + " 0.997\n", + " \n", + " \n", + " 7\n", + " Yes! 
Cinnamon Oil is a great way to deter mice...\n", + " <urn:uuid:e8d2ac4f-cde2-4c45-afd9-3a19cfb86d4c>\n", + " CC-MAIN-2023-14\n", + " https://peskylittlecritters.com/does-cinnamon-...\n", + " 2023-03-27T23:15:19Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.906198\n", + " 490\n", + " food & drink\n", " ...\n", + " 8.980000\n", + " 26.2\n", + " cc\n", + " 0.009224\n", + " cc\n", + " 0.021643\n", + " cc\n", + " 0.978\n", + " cc\n", + " 0.844\n", + " \n", + " \n", + " 8\n", + " Rosemary Oil can be used to deter cockroaches....\n", + " <urn:uuid:bd5c2a03-9a9b-43e2-872a-f7123213bea9>\n", + " CC-MAIN-2023-14\n", + " https://peskylittlecritters.com/does-rosemary-...\n", + " 2023-03-27T23:18:25Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.916242\n", + " 513\n", + " science\n", " ...\n", + " 9.370000\n", + " 23.8\n", + " cc\n", + " 0.007073\n", + " cc\n", + " 0.005885\n", + " cc\n", + " 0.994\n", + " cc\n", + " 0.876\n", + " \n", + " \n", + " 9\n", + " A cat might have discovered an insect crawling...\n", + " <urn:uuid:1922dc93-9fb8-4775-b147-88c589a7bd65>\n", + " CC-MAIN-2023-14\n", + " https://petcatty.com/why-does-my-cat-stare-at-...\n", + " 2023-03-27T23:28:27Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.967236\n", + " 1172\n", + " pets\n", " ...\n", + " 6.396667\n", + " 20.6\n", + " hq\n", + " 0.960727\n", + " hq\n", + " 0.881134\n", + " hq\n", + " 0.881\n", + " cc\n", + " 0.974\n", " \n", " \n", - " 195\n", - " - invite.doc\n", - " 2\n", - " ro\n", - " 0.717\n", + " 10\n", + " Ham came to the Kennebec Valley Humane Society...\n", + " <urn:uuid:a7e185ac-84fb-4059-9a2c-36b914368a46>\n", + " CC-MAIN-2023-14\n", + " https://pethavenlane.org/hanks-hams-story/\n", + " 2023-03-27T23:07:21Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.976228\n", + " 400\n", + " pets\n", + " ...\n", + " 9.170000\n", + " 26.6\n", + " cc\n", + " 0.001376\n", + 
" cc\n", + " 0.056280\n", + " cc\n", + " 0.943\n", + " cc\n", + " 0.762\n", " \n", " \n", - " 196\n", - " - issues wrt portland and calgary signing shor...\n", - " 2\n", + " 11\n", + " In this post, I told you I was making a dress ...\n", + " <urn:uuid:6ad36d37-6e01-4313-97f8-e38615928efe>\n", + " CC-MAIN-2023-14\n", + " https://petitmainsauvage.blogspot.com/2010/04/...\n", + " 2023-03-27T22:21:51Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", " en\n", + " 0.973702\n", + " 818\n", + " style & fashion\n", + " ...\n", + " 7.690000\n", + " 23.3\n", + " cc\n", + " 0.006613\n", + " cc\n", + " 0.004423\n", + " cc\n", + " 0.995\n", + " cc\n", " 0.997\n", " \n", " \n", - " 197\n", - " - jan3102.XLS\n", - " 2\n", - " de\n", - " 0.399\n", + " 12\n", + " Fitted with new strimmer spool. 2 x Minor crac...\n", + " <urn:uuid:f286d586-4f92-444a-9e0b-b35244f6e03b>\n", + " CC-MAIN-2023-14\n", + " https://petrolbrushcutter.com/en/makita_em4340...\n", + " 2023-03-27T23:13:34Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.922969\n", + " 143\n", + " business and finance\n", + " ...\n", + " 6.030000\n", + " 14.4\n", + " cc\n", + " -0.000007\n", + " cc\n", + " -0.000009\n", + " cc\n", + " 1.000\n", + " cc\n", + " 0.999\n", " \n", " \n", - " 198\n", - " - job market.gif\n", - " 2\n", + " 13\n", + " Who are Amerpetrelocator.com?http://Amerpetrel...\n", + " <urn:uuid:9d03b16f-40e8-418e-996e-1db33cc175aa>\n", + " CC-MAIN-2023-14\n", + " https://petscams.com/pet-delivery-scam/amerpet...\n", + " 2023-03-27T23:39:39Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", " en\n", - " 0.791\n", + " 0.934744\n", + " 577\n", + " news and politics\n", + " ...\n", + " 10.726667\n", + " 20.0\n", + " cc\n", + " 0.000184\n", + " cc\n", + " 0.001286\n", + " cc\n", + " 0.998\n", + " cc\n", + " 0.990\n", " \n", " \n", - " 199\n", - " - kick~1.mpe\n", - " 4\n", - " eo\n", - " 0.253\n", + " 14\n", + " Who are 
Mainecoonkittens4rehoming.com?http://M...\n", + " <urn:uuid:cbbc566e-3387-4a3a-bf24-1c6c334bdebb>\n", + " CC-MAIN-2023-14\n", + " https://petscams.com/puppy-scammer-list/mainec...\n", + " 2023-03-27T22:57:03Z\n", + " s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se...\n", + " en\n", + " 0.901691\n", + " 635\n", + " pets\n", + " ...\n", + " 10.990000\n", + " 22.4\n", + " cc\n", + " 0.005841\n", + " cc\n", + " 0.000099\n", + " cc\n", + " 0.999\n", + " cc\n", + " 0.988\n", " \n", " \n", "\n", - "

200 rows × 4 columns

\n", + "

15 rows × 28 columns

\n", "" ], "text/plain": [ - " text count() lang \\\n", - "0 - Notice of name-email change.doc 6 en \n", - "1 - Nov13ENAOnly.doc 2 de \n", - "2 - OHIO_C~1.XLS 2 de \n", - "3 - Oneok(5-30)final.doc 1 vi \n", - "4 - OpeningBrief.doc 6 ko-Hang \n", - ".. ... ... ... \n", - "195 - invite.doc 2 ro \n", - "196 - issues wrt portland and calgary signing shor... 2 en \n", - "197 - jan3102.XLS 2 de \n", - "198 - job market.gif 2 en \n", - "199 - kick~1.mpe 4 eo \n", + " text \\\n", + "0 A staffer sells cars via livestream at a deale... \n", + "1 The May 1st submission deadline may feel like ... \n", + "2 Yes! Cinnamon Oil is a great way to deter mice... \n", + "3 Rosemary Oil can be used to deter cockroaches.... \n", + "4 A cat might have discovered an insect crawling... \n", + "5 A staffer sells cars via livestream at a deale... \n", + "6 The May 1st submission deadline may feel like ... \n", + "7 Yes! Cinnamon Oil is a great way to deter mice... \n", + "8 Rosemary Oil can be used to deter cockroaches.... \n", + "9 A cat might have discovered an insect crawling... \n", + "10 Ham came to the Kennebec Valley Humane Society... \n", + "11 In this post, I told you I was making a dress ... \n", + "12 Fitted with new strimmer spool. 2 x Minor crac... \n", + "13 Who are Amerpetrelocator.com?http://Amerpetrel... \n", + "14 Who are Mainecoonkittens4rehoming.com?http://M... \n", "\n", - " score \n", - "0 0.858 \n", - "1 0.264 \n", - "2 0.603 \n", - "3 0.152 \n", - "4 0.365 \n", - ".. ... 
\n", - "195 0.717 \n", - "196 0.997 \n", - "197 0.399 \n", - "198 0.791 \n", - "199 0.253 \n", + " id dump \\\n", + "0 CC-MAIN-2023-14 \n", + "1 CC-MAIN-2023-14 \n", + "2 CC-MAIN-2023-14 \n", + "3 CC-MAIN-2023-14 \n", + "4 CC-MAIN-2023-14 \n", + "5 CC-MAIN-2023-14 \n", + "6 CC-MAIN-2023-14 \n", + "7 CC-MAIN-2023-14 \n", + "8 CC-MAIN-2023-14 \n", + "9 CC-MAIN-2023-14 \n", + "10 CC-MAIN-2023-14 \n", + "11 CC-MAIN-2023-14 \n", + "12 CC-MAIN-2023-14 \n", + "13 CC-MAIN-2023-14 \n", + "14 CC-MAIN-2023-14 \n", "\n", - "[200 rows x 4 columns]" + " url date \\\n", + "0 https://peoplesdaily.pdnews.cn/business/vehicl... 2023-03-27T23:11:21Z \n", + "1 https://performancein.com/news/2019/01/29/all-... 2023-03-27T23:08:19Z \n", + "2 https://peskylittlecritters.com/does-cinnamon-... 2023-03-27T23:15:19Z \n", + "3 https://peskylittlecritters.com/does-rosemary-... 2023-03-27T23:18:25Z \n", + "4 https://petcatty.com/why-does-my-cat-stare-at-... 2023-03-27T23:28:27Z \n", + "5 https://peoplesdaily.pdnews.cn/business/vehicl... 2023-03-27T23:11:21Z \n", + "6 https://performancein.com/news/2019/01/29/all-... 2023-03-27T23:08:19Z \n", + "7 https://peskylittlecritters.com/does-cinnamon-... 2023-03-27T23:15:19Z \n", + "8 https://peskylittlecritters.com/does-rosemary-... 2023-03-27T23:18:25Z \n", + "9 https://petcatty.com/why-does-my-cat-stare-at-... 2023-03-27T23:28:27Z \n", + "10 https://pethavenlane.org/hanks-hams-story/ 2023-03-27T23:07:21Z \n", + "11 https://petitmainsauvage.blogspot.com/2010/04/... 2023-03-27T22:21:51Z \n", + "12 https://petrolbrushcutter.com/en/makita_em4340... 2023-03-27T23:13:34Z \n", + "13 https://petscams.com/pet-delivery-scam/amerpet... 2023-03-27T23:39:39Z \n", + "14 https://petscams.com/puppy-scammer-list/mainec... 2023-03-27T22:57:03Z \n", + "\n", + " file_path language \\\n", + "0 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "1 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "2 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... 
en \n", + "3 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "4 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "5 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "6 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "7 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "8 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "9 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "10 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "11 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "12 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "13 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "14 s3://commoncrawl/crawl-data/CC-MAIN-2023-14/se... en \n", + "\n", + " language_score token_count watsonnlp_top_category0 ... avg_grade_level \\\n", + "0 0.967074 1239 automotive ... 9.436667 \n", + "1 0.944369 418 news and politics ... 11.670000 \n", + "2 0.906198 490 food & drink ... 8.980000 \n", + "3 0.916242 513 science ... 9.370000 \n", + "4 0.967236 1172 pets ... 6.396667 \n", + "5 0.967074 1239 automotive ... 9.436667 \n", + "6 0.944369 418 news and politics ... 11.670000 \n", + "7 0.906198 490 food & drink ... 8.980000 \n", + "8 0.916242 513 science ... 9.370000 \n", + "9 0.967236 1172 pets ... 6.396667 \n", + "10 0.976228 400 pets ... 9.170000 \n", + "11 0.973702 818 style & fashion ... 7.690000 \n", + "12 0.922969 143 business and finance ... 6.030000 \n", + "13 0.934744 577 news and politics ... 10.726667 \n", + "14 0.901691 635 pets ... 
10.990000 \n", + "\n", + " mcalpine_eflaw_textstat dclm_fasttext_label dclm_fasttext_score \\\n", + "0 22.1 cc 0.002249 \n", + "1 30.3 cc 0.000050 \n", + "2 26.2 cc 0.009224 \n", + "3 23.8 cc 0.007073 \n", + "4 20.6 hq 0.960727 \n", + "5 22.1 cc 0.002249 \n", + "6 30.3 cc 0.000050 \n", + "7 26.2 cc 0.009224 \n", + "8 23.8 cc 0.007073 \n", + "9 20.6 hq 0.960727 \n", + "10 26.6 cc 0.001376 \n", + "11 23.3 cc 0.006613 \n", + "12 14.4 cc -0.000007 \n", + "13 20.0 cc 0.000184 \n", + "14 22.4 cc 0.005841 \n", + "\n", + " cosmo_10k_edu_fasttext_label cosmo_10k_edu_fasttext_score label_quality \\\n", + "0 cc 0.012263 cc \n", + "1 cc 0.000067 cc \n", + "2 cc 0.021643 cc \n", + "3 cc 0.005885 cc \n", + "4 hq 0.881134 hq \n", + "5 cc 0.012263 cc \n", + "6 cc 0.000067 cc \n", + "7 cc 0.021643 cc \n", + "8 cc 0.005885 cc \n", + "9 hq 0.881134 hq \n", + "10 cc 0.056280 cc \n", + "11 cc 0.004423 cc \n", + "12 cc -0.000009 cc \n", + "13 cc 0.001286 cc \n", + "14 cc 0.000099 cc \n", + "\n", + " score_quality label_med score_med \n", + "0 0.987 cc 0.994 \n", + "1 0.999 cc 0.997 \n", + "2 0.978 cc 0.844 \n", + "3 0.994 cc 0.876 \n", + "4 0.881 cc 0.974 \n", + "5 0.987 cc 0.994 \n", + "6 0.999 cc 0.997 \n", + "7 0.978 cc 0.844 \n", + "8 0.994 cc 0.876 \n", + "9 0.881 cc 0.974 \n", + "10 0.943 cc 0.762 \n", + "11 0.995 cc 0.997 \n", + "12 1.000 cc 0.999 \n", + "13 0.998 cc 0.990 \n", + "14 0.999 cc 0.988 \n", + "\n", + "[15 rows x 28 columns]" ] }, - "execution_count": 8, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/transforms/language/gneissweb_classification/test-data/expected/metadata.json b/transforms/language/gneissweb_classification/test-data/expected/metadata.json index 04bc26ccec..73919d2bd9 100644 --- a/transforms/language/gneissweb_classification/test-data/expected/metadata.json +++ b/transforms/language/gneissweb_classification/test-data/expected/metadata.json @@ -1,84 +1,66 @@ { - "pipeline": "pipeline_id", - "job details": { - "job 
category": "preprocessing", - "job name": "LangIdentification", - "job type": "ray", - "job id": "job_id", - "start_time": "2024-05-28 18:53:20", - "end_time": "2024-05-28 18:53:23", - "status": "success" - }, - "code": null, - "job_input_params": { - "lang_id_model_kind": "fasttext", - "lang_id_model_url": "facebook/fasttext-language-identification", - "lang_id_content_column_name": "text", - "checkpointing": false, - "max_files": -1, - "random_samples": -1, - "files_to_use": [".parquet"], - "number of workers": 1, - "worker options": { - "num_cpus": 0.8 - }, - "actor creation delay": 0 - }, - "execution_stats": { - "cpus": 4, - "gpus": 0, - "memory": 3.495941162109375, - "object_store": 1.7479705810546875 - }, - "job_output_stats": { - "source_files": 3, - "source_size": 464912, - "result_files": 3, - "result_size": 470832, - "processing_time": 0.44434642791748047, - "en": 357, - "de": 38, - "vi": 16, - "ko-Hang": 49, - "lb": 2, - "ca": 5, - "rm": 2, - "lt": 1, - "yue-Hant": 58, - "hu": 1, - "sv": 3, - "it": 2, - "vec-Latn": 2, - "azb-Arab": 7, - "tr": 1, - "fr": 4, - "ro": 4, - "pl": 12, - "cs": 7, - "es": 3, - "ast-Latn": 4, - "eo": 2, - "oc-Latn": 3, - "lmo-Latn": 1, - "da": 1, - "eu": 1, - "nl": 4, - "source_doc_count": 600, - "result_doc_count": 600, - "sk": 1, - "lvs-Latn": 1, - "li-Latn": 1, - "nn": 1, - "bo-Tibt": 4, - "af": 1, - "nb": 1 - }, - "source": { - "name": "/home/kind/data-prep-kit-inner/transforms/language/language_id/test-data/input", - "type": "path" - }, - "target": { - "name": "/tmp/LangIdentificationp6jsp6zh", - "type": "path" - } -} + "pipeline": "pipeline_id", + "job details": { + "job category": "preprocessing", + "job name": "gcls", + "job type": "pure python", + "job id": "job_id", + "start_time": "2025-02-17 11:05:56", + "end_time": "2025-02-17 11:05:59", + "status": "success" + }, + "code": { + "github": "github", + "commit_hash": "12345", + "path": "path" + }, + "job_input_params": { + "gcls_model_credential": "PUT YOUR OWN 
HUGGINGFACE CREDENTIAL", + "gcls_model_file_name": [ + "['fasttext_medical.bin']" + ], + "gcls_model_url": [ + "['ibm-granite/GneissWeb.Med_classifier']" + ], + "gcls_content_column_name": "text", + "gcls_output_label_column_name": [ + "['label_med']" + ], + "gcls_output_score_column_name": [ + "['score']" + ], + "gcls_n_processes": 1, + "checkpointing": false, + "max_files": -1, + "random_samples": -1, + "files_to_use": [ + ".parquet" + ], + "num_processors": 0 + }, + "execution_stats": { + "cpus": 35.6, + "gpus": 0, + "memory": 26.88, + "object_store": 0, + "execution time, min": 0.049 + }, + "job_output_stats": { + "source_files": 1, + "source_size": 44810, + "result_files": 1, + "result_size": 40631, + "processing_time": 2.938, + "cc": 15, + "source_doc_count": 15, + "result_doc_count": 15 + }, + "source": { + "name": "/home/kind/data-prep-kit/transforms/language/gneissweb_classification/test-data/input", + "type": "path" + }, + "target": { + "name": "/tmp/gneissweb_classification/output", + "type": "path" + } +} \ No newline at end of file diff --git a/transforms/language/gneissweb_classification/test-data/expected/test_01.parquet b/transforms/language/gneissweb_classification/test-data/expected/test_01.parquet index 6006954576..46eb4bd43d 100644 Binary files a/transforms/language/gneissweb_classification/test-data/expected/test_01.parquet and b/transforms/language/gneissweb_classification/test-data/expected/test_01.parquet differ diff --git a/transforms/language/gneissweb_classification/test-data/expected/test_02.parquet b/transforms/language/gneissweb_classification/test-data/expected/test_02.parquet deleted file mode 100644 index 710a508483..0000000000 Binary files a/transforms/language/gneissweb_classification/test-data/expected/test_02.parquet and /dev/null differ diff --git a/transforms/language/gneissweb_classification/test-data/expected/test_03.parquet b/transforms/language/gneissweb_classification/test-data/expected/test_03.parquet deleted file mode 
100644 index 0942231a04..0000000000 Binary files a/transforms/language/gneissweb_classification/test-data/expected/test_03.parquet and /dev/null differ diff --git a/transforms/language/gneissweb_classification/test-data/input/test_01.parquet b/transforms/language/gneissweb_classification/test-data/input/test_01.parquet index ea7714a375..450a4ac5eb 100644 Binary files a/transforms/language/gneissweb_classification/test-data/input/test_01.parquet and b/transforms/language/gneissweb_classification/test-data/input/test_01.parquet differ diff --git a/transforms/language/gneissweb_classification/test-data/input/test_02.parquet b/transforms/language/gneissweb_classification/test-data/input/test_02.parquet deleted file mode 100644 index 2162b66412..0000000000 Binary files a/transforms/language/gneissweb_classification/test-data/input/test_02.parquet and /dev/null differ diff --git a/transforms/language/gneissweb_classification/test-data/input/test_03.parquet b/transforms/language/gneissweb_classification/test-data/input/test_03.parquet deleted file mode 100644 index 9d78e3ee4a..0000000000 Binary files a/transforms/language/gneissweb_classification/test-data/input/test_03.parquet and /dev/null differ diff --git a/transforms/language/gneissweb_classification/test/test_gneissweb_classification.py b/transforms/language/gneissweb_classification/test/test_gneissweb_classification.py index 905ff1ab5c..e7e515d320 100644 --- a/transforms/language/gneissweb_classification/test/test_gneissweb_classification.py +++ b/transforms/language/gneissweb_classification/test/test_gneissweb_classification.py @@ -26,14 +26,13 @@ class TestLangIdentificationTransform(AbstractTableTransformTest): def get_test_transform_fixtures(self) -> list[tuple]: config = { "gcls_model_credential": "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - "gcls_model_file_name": "model.bin", - "gcls_model_url": "facebook/fasttext-language-identification", + "gcls_model_file_name": ["['fasttext_medical.bin']"], + 
"gcls_model_url": ["['ibm-granite/GneissWeb.Med_classifier']"], "gcls_content_column_name": "contents", - "gcls_output_label_column_name": "l", - "gcls_output_score_column_name": "s", + "gcls_output_label_column_name": ["['l']"], + "gcls_output_score_column_name": ["['s']"], } - table = pa.Table.from_arrays( [ pa.array( @@ -69,14 +68,14 @@ def get_test_transform_fixtures(self) -> list[tuple]: "hija de Forbante y nieta de Lápites. ", ] ), - pa.array(["de", "pt", "ja", "fr", "es"]), + pa.array(["cc", "cc", "cc", "cc", "cc"]), pa.array( [ - 0.998, - 1.000, - 0.930, - 0.998, - 0.998, + 0.966, + 0.988, + 1, + 0.996, + 0.892, ] ), ], @@ -102,7 +101,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: ), pa.array( [ - "en", + "cc", ] ), pa.array([1.000]), @@ -118,7 +117,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: ), pa.array( [ - "en", + "cc", ] ), pa.array([1.000]), @@ -136,7 +135,7 @@ def get_test_transform_fixtures(self) -> list[tuple]: invalid_output_score_column_name_table, ], [expected_table], - [{"de": 1, "es": 1, "fr": 1, "ja": 1, "pt": 1}, {}], + [{"cc": 5}, {}], ) ] diff --git a/transforms/language/gneissweb_classification/test/test_gneissweb_classification_python.py b/transforms/language/gneissweb_classification/test/test_gneissweb_classification_python.py index eccf847eb6..85c1a84229 100644 --- a/transforms/language/gneissweb_classification/test/test_gneissweb_classification_python.py +++ b/transforms/language/gneissweb_classification/test/test_gneissweb_classification_python.py @@ -28,11 +28,11 @@ class TestPythonClassificationTransform(AbstractTransformLauncherTest): def get_test_transform_fixtures(self) -> list[tuple]: cli_params = { "gcls_model_credential": "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - "gcls_model_file_name": "model.bin", - "gcls_model_url":"facebook/fasttext-language-identification", + "gcls_model_file_name": ["fasttext_medical.bin"], + "gcls_model_url":["ibm-granite/GneissWeb.Med_classifier"], "gcls_content_column_name": 
"text", - "gcls_output_label_column_name": "ft_lang", - "gcls_output_score_column_name": "ft_score", + "gcls_output_label_column_name": ["label_med"], + "gcls_output_score_column_name": ["score"], } @@ -41,3 +41,4 @@ def get_test_transform_fixtures(self) -> list[tuple]: launcher = PythonTransformLauncher(ClassificationPythonTransformConfiguration()) fixtures.append((launcher, cli_params, basedir + "/input", basedir + "/expected")) return fixtures + diff --git a/transforms/language/gneissweb_classification/test/test_gneissweb_classification_ray.py b/transforms/language/gneissweb_classification/test/test_gneissweb_classification_ray.py index 2009c15244..cf1f126d8b 100644 --- a/transforms/language/gneissweb_classification/test/test_gneissweb_classification_ray.py +++ b/transforms/language/gneissweb_classification/test/test_gneissweb_classification_ray.py @@ -36,11 +36,11 @@ def get_test_transform_fixtures(self) -> list[tuple]: basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), basedir)) config = { model_credential_cli_param: "PUT YOUR OWN HUGGINGFACE CREDENTIAL", - model_file_name_cli_param: "model.bin", - model_url_cli_param:"facebook/fasttext-language-identification", + model_file_name_cli_param: ["fasttext_medical.bin"], + model_url_cli_param:["ibm-granite/GneissWeb.Med_classifier"], content_column_name_cli_param: "text", - output_label_column_name_cli_param: "ft_lang", - output_score_column_name_cli_param: "ft_score", + output_label_column_name_cli_param: ["label_med"], + output_score_column_name_cli_param: ["score"], "run_locally": True, } @@ -52,3 +52,4 @@ def get_test_transform_fixtures(self) -> list[tuple]: basedir + "/expected", ) ] + diff --git a/transforms/language/gneissweb_classification/test/test_nlp.py b/transforms/language/gneissweb_classification/test/test_nlp.py index f563cb982a..c20b3926b1 100644 --- a/transforms/language/gneissweb_classification/test/test_nlp.py +++ b/transforms/language/gneissweb_classification/test/test_nlp.py @@ 
-17,7 +17,7 @@ def test_classification(): nlp_langid = ClassificationModelFactory.create_model( - "facebook/fasttext-language-identification", "model.bin","YOUR_HUGGINGFACE_ACCESS_TOKEN" + "ibm-granite/GneissWeb.Med_classifier", "fasttext_medical.bin","YOUR_HUGGINGFACE_ACCESS_TOKEN" ) documents = pa.array( @@ -36,7 +36,7 @@ def test_classification(): ) table = pa.Table.from_arrays([documents], names=["contents"]) table, stats = get_label_ds_pa(table, nlp_langid, "contents", "label", "score") - assert table["label"].to_pylist() == ["de", "pt", "ja", "fr", "es"] + assert table["label"].to_pylist() == ["cc", "cc", "cc", "cc", "cc"] assert len(table["score"].to_pylist()) == len(table["label"].to_pylist()) - assert "ft_lang" not in table.column_names + assert "ft_label" not in table.column_names assert "ft_score" not in table.column_names