Skip to content

Commit

Permalink
fix: Added get_documents to task interface
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Feb 15, 2024
1 parent 24644a7 commit c4fb354
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
5 changes: 1 addition & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,4 @@ run_sonar.py
setup_ucloud.sh

# paper
src/scripts/paper/figure.r
src/scripts/paper/speed_filtered.csv
src/scripts/paper/speed.csv
src/scripts/paper/tmp.csv
src/scripts/paper/*
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-02-11T14:09:11.351629","scores":{"da":{"accuracy":0.3846666666666666,"f1":0.3650136884557438,"accuracy_stderr":0.03664241622309678,"f1_stderr":0.03540233062350939,"main_score":0.3846666666666666}},"main_score":"accuracy"}
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-02-15T10:45:54.680745","scores":{"da":{"accuracy":0.3846666666666666,"f1":0.3650136884557438,"accuracy_stderr":0.03664241622309678,"f1_stderr":0.03540233062350939,"main_score":0.3846666666666666}},"main_score":"accuracy"}
6 changes: 5 additions & 1 deletion src/seb/interfaces/mteb_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def load_data(self) -> DatasetDict:

return DatasetDict(ds)

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
def get_documents(self) -> list[str]:
ds: DatasetDict = self.load_data()
texts = []
splits = self.mteb_task.description["eval_splits"]
Expand All @@ -67,6 +67,10 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats:
for text_column in self._text_columns:
texts += ds[split][text_column]

return texts

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
texts = self.get_documents()
document_lengths = np.array([len(text) for text in texts])

mean = float(np.mean(document_lengths))
Expand Down
22 changes: 20 additions & 2 deletions src/seb/interfaces/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Protocol, TypedDict, runtime_checkable

import numpy as np
from attr import dataclass

from seb.interfaces.language import Language

Expand Down Expand Up @@ -70,9 +70,27 @@ def evaluate(self, model: Encoder) -> TaskResult:
"""
...

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
def get_documents(self) -> list[str]:
    """Get the documents for the task.

    Implementations should return the raw text documents that the task
    evaluates on (used e.g. for computing descriptive statistics).

    Returns:
        A list of document strings.
    """
    ...

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
    """Compute descriptive statistics over the task's documents.

    Document length is measured in characters.

    Returns:
        DescriptiveDatasetStats with the mean and standard deviation of
        document lengths and the total number of documents.
    """
    texts = self.get_documents()
    document_lengths = np.array([len(text) for text in texts])

    # Guard: np.mean/np.std of an empty array return NaN and emit a
    # RuntimeWarning; report zeros for a task with no documents instead.
    if document_lengths.size == 0:
        return DescriptiveDatasetStats(
            mean_document_length=0.0,
            std_document_length=0.0,
            num_documents=0,
        )

    mean = float(np.mean(document_lengths))
    std = float(np.std(document_lengths))
    return DescriptiveDatasetStats(
        mean_document_length=mean,
        std_document_length=std,
        num_documents=len(document_lengths),
    )

def name_to_path(self) -> str:
"""
Convert a name to a path.
Expand Down
5 changes: 4 additions & 1 deletion src/seb/registered_tasks/multilingual.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def load_data(self) -> DatasetDict:

return DatasetDict(ds)

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
def get_documents(self) -> list[str]:
ds = self.load_data()
texts = []
splits = self.get_splits()
Expand All @@ -80,7 +80,10 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats:
for split in splits:
for text_column in self._text_columns:
texts += ds[split][text_column]
return texts

def get_descriptive_stats(self) -> DescriptiveDatasetStats:
texts = self.get_documents()
document_lengths = np.array([len(text) for text in texts])

mean = np.mean(document_lengths)
Expand Down

0 comments on commit c4fb354

Please sign in to comment.