diff --git a/.gitignore b/.gitignore index 5417be0f..89dda433 100644 --- a/.gitignore +++ b/.gitignore @@ -58,7 +58,4 @@ run_sonar.py setup_ucloud.sh # paper -src/scripts/paper/figure.r -src/scripts/paper/speed_filtered.csv -src/scripts/paper/speed.csv -src/scripts/paper/tmp.csv +src/scripts/paper/* diff --git a/src/seb/cache/sentence-transformers__all-MiniLM-L6-v2/LCC.json b/src/seb/cache/sentence-transformers__all-MiniLM-L6-v2/LCC.json index db25875d..79c2c00b 100644 --- a/src/seb/cache/sentence-transformers__all-MiniLM-L6-v2/LCC.json +++ b/src/seb/cache/sentence-transformers__all-MiniLM-L6-v2/LCC.json @@ -1 +1 @@ -{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-02-11T14:09:11.351629","scores":{"da":{"accuracy":0.3846666666666666,"f1":0.3650136884557438,"accuracy_stderr":0.03664241622309678,"f1_stderr":0.03540233062350939,"main_score":0.3846666666666666}},"main_score":"accuracy"} \ No newline at end of file +{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-02-15T10:45:54.680745","scores":{"da":{"accuracy":0.3846666666666666,"f1":0.3650136884557438,"accuracy_stderr":0.03664241622309678,"f1_stderr":0.03540233062350939,"main_score":0.3846666666666666}},"main_score":"accuracy"} \ No newline at end of file diff --git a/src/seb/interfaces/mteb_task.py b/src/seb/interfaces/mteb_task.py index e30b3ffe..6df486cb 100644 --- a/src/seb/interfaces/mteb_task.py +++ b/src/seb/interfaces/mteb_task.py @@ -48,7 +48,7 @@ def load_data(self) -> DatasetDict: return DatasetDict(ds) - def get_descriptive_stats(self) -> DescriptiveDatasetStats: + def get_documents(self) -> list[str]: ds: DatasetDict = self.load_data() texts = [] splits = self.mteb_task.description["eval_splits"] @@ -67,6 +67,10 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats: for text_column in self._text_columns: texts += 
ds[split][text_column] + return texts + + def get_descriptive_stats(self) -> DescriptiveDatasetStats: + texts = self.get_documents() document_lengths = np.array([len(text) for text in texts]) mean = float(np.mean(document_lengths)) diff --git a/src/seb/interfaces/task.py b/src/seb/interfaces/task.py index a5de49e9..4b16cab7 100644 --- a/src/seb/interfaces/task.py +++ b/src/seb/interfaces/task.py @@ -1,6 +1,6 @@ from typing import Literal, Protocol, TypedDict, runtime_checkable -from attr import dataclass +import numpy as np from seb.interfaces.language import Language @@ -70,9 +70,27 @@ def evaluate(self, model: Encoder) -> TaskResult: """ ... - def get_descriptive_stats(self) -> DescriptiveDatasetStats: + def get_documents(self) -> list[str]: + """ + Get the documents for the task. + + Returns: + A list of strings. + """ ... + def get_descriptive_stats(self) -> DescriptiveDatasetStats: + texts = self.get_documents() + document_lengths = np.array([len(text) for text in texts]) + + mean = float(np.mean(document_lengths)) + std = float(np.std(document_lengths)) + return DescriptiveDatasetStats( + mean_document_length=mean, + std_document_length=std, + num_documents=len(document_lengths), + ) + def name_to_path(self) -> str: """ Convert a name to a path. 
diff --git a/src/seb/registered_tasks/multilingual.py b/src/seb/registered_tasks/multilingual.py index 72d0a2e4..b08d1531 100644 --- a/src/seb/registered_tasks/multilingual.py +++ b/src/seb/registered_tasks/multilingual.py @@ -71,7 +71,7 @@ def load_data(self) -> DatasetDict: return DatasetDict(ds) - def get_descriptive_stats(self) -> DescriptiveDatasetStats: + def get_documents(self) -> list[str]: ds = self.load_data() texts = [] splits = self.get_splits() @@ -80,7 +80,10 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats: for split in splits: for text_column in self._text_columns: texts += ds[split][text_column] + return texts + def get_descriptive_stats(self) -> DescriptiveDatasetStats: + texts = self.get_documents() document_lengths = np.array([len(text) for text in texts]) mean = np.mean(document_lengths)