From 4202333d43536e3329cd44810f39b2b5763e6d24 Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 25 Oct 2024 14:51:37 +0200
Subject: [PATCH 1/3] drop python 3.8

---
 .github/workflows/ci.yml | 4 +-
 .github/workflows/publish-docs.yml | 2 +-
 CONTRIBUTING.md | 4 +-
 README.md | 2 +-
 docs/contributing/local_development.md | 4 +-
 docs/tutorial/intro.md | 2 +-
 flair/class_utils.py | 11 +-
 flair/data.py | 167 ++++++-------
 flair/datasets/base.py | 12 +-
 flair/datasets/biomedical.py | 230 +++++++++---------
 flair/datasets/document_classification.py | 132 +++++-----
 flair/datasets/entity_linking.py | 109 +++++----
 flair/datasets/ocr.py | 8 +-
 flair/datasets/relation_extraction.py | 30 +--
 flair/datasets/sequence_labeling.py | 207 ++++++++--------
 flair/datasets/text_image.py | 11 +-
 flair/datasets/text_text.py | 48 ++--
 flair/datasets/treebanks.py | 10 +-
 flair/embeddings/base.py | 24 +-
 flair/embeddings/document.py | 62 ++---
 flair/embeddings/image.py | 16 +-
 flair/embeddings/legacy.py | 22 +-
 flair/embeddings/token.py | 67 +++--
 flair/embeddings/transformer.py | 70 +++---
 flair/file_utils.py | 11 +-
 flair/inference_utils.py | 4 +-
 flair/models/entity_linker_model.py | 27 +-
 flair/models/entity_mention_linking.py | 95 ++++----
 flair/models/language_model.py | 14 +-
 flair/models/lemmatizer_model.py | 36 +--
 flair/models/multitask_model.py | 32 +--
 flair/models/pairwise_classification_model.py | 5 +-
 flair/models/pairwise_regression_model.py | 29 +--
 flair/models/prefixed_tagger.py | 36 +--
 flair/models/regexp_tagger.py | 22 +-
 flair/models/relation_classifier_model.py | 67 +++--
 flair/models/relation_extractor_model.py | 10 +-
 flair/models/sequence_tagger_model.py | 32 +--
 flair/models/sequence_tagger_utils/viterbi.py | 12 +-
 flair/models/tars_model.py | 44 ++--
 flair/models/text_classification_model.py | 6 +-
 flair/models/text_regression_model.py | 28 +--
 flair/models/triple_classification_model.py | 5 +-
 flair/models/word_tagger_model.py | 8 +-
 flair/nn/decoder.py | 6 +-
 flair/nn/model.py | 70 +++---
 flair/nn/multitask.py | 17 +-
 flair/samplers.py | 3 +-
 flair/splitter.py | 18 +-
 flair/tokenization.py | 34 +--
 flair/trainers/language_model_trainer.py | 17 +-
 flair/trainers/plugins/base.py | 19 +-
 .../plugins/functional/anneal_on_plateau.py | 4 +-
 .../plugins/functional/checkpoints.py | 4 +-
 .../plugins/functional/linear_scheduler.py | 4 +-
 .../functional/reduce_transformer_vocab.py | 3 +-
 .../plugins/functional/weight_extractor.py | 4 +-
 flair/trainers/plugins/loggers/log_file.py | 4 +-
 flair/trainers/plugins/loggers/loss_file.py | 8 +-
 .../plugins/loggers/metric_history.py | 7 +-
 flair/trainers/plugins/loggers/tensorboard.py | 4 +-
 flair/trainers/plugins/loggers/wandb.py | 4 +-
 flair/trainers/plugins/metric_records.py | 5 +-
 flair/trainers/trainer.py | 32 +--
 flair/training_utils.py | 18 +-
 flair/visual/ner_html.py | 4 +-
 flair/visual/training_curves.py | 8 +-
 flair/visual/tree_printer.py | 6 +-
 pyproject.toml | 4 +-
 requirements-dev.txt | 2 +-
 resources/docs/EXPERIMENTS.md | 12 +-
 resources/docs/HUNFLAIR2.md | 2 +-
 .../KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md | 2 +-
 .../KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md | 2 +-
 setup.py | 2 +-
 tests/embedding_test_utils.py | 16 +-
 ...test_document_transform_word_embeddings.py | 8 +-
 tests/embeddings/test_word_embeddings.py | 4 +-
 tests/model_test_utils.py | 10 +-
 tests/models/test_relation_classifier.py | 12 +-
 tests/test_datasets_biomedical.py | 4 +-
 tests/test_labels.py | 38 ++-
 tests/test_tokenize_sentence.py | 4 +-
 83 files changed, 1084
insertions(+), 1118 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b3633d93e..5633ac850f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,11 +12,11 @@ jobs: FLAIR_CACHE_ROOT: ./cache/flair steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 id: setup-python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Install Torch cpu run: pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Install Flair dependencies diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 24a424adba..f752b1324c 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -10,7 +10,7 @@ jobs: name: Build the docs using Sphinx and push to gh-pages runs-on: ubuntu-latest env: - python-version: 3.8 + python-version: 3.9 steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b40ddfe77e..b44927f17a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,8 +24,8 @@ the code should hopefully be easy. ### Setup -Flair requires python-3.8 or higher. To make sure your code also runs on the oldest supported -python version, it is recommended to use python-3.8.x for flair development. +Flair requires python-3.9 or higher. To make sure your code also runs on the oldest supported +python version, it is recommended to use python-3.9.x for flair development. Create a python environment of your preference and run: ```bash diff --git a/README.md b/README.md index 92502c972b..fdf4130124 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ In your favorite virtual environment, simply do: pip install flair ``` -Flair requires Python 3.8+. +Flair requires Python 3.9+. ### Example 1: Tag Entities in Text diff --git a/docs/contributing/local_development.md b/docs/contributing/local_development.md index 87439439f2..9c7413703e 100644 --- a/docs/contributing/local_development.md +++ b/docs/contributing/local_development.md @@ -6,8 +6,8 @@ the code should hopefully be easy. ## Setup -Flair requires python-3.8 or higher. To make sure our code also runs on the oldest supported -python version, it is recommended to use python-3.8.x for flair development. +Flair requires python-3.9 or higher. To make sure our code also runs on the oldest supported +python version, it is recommended to use python-3.9.x for flair development. Create a python environment of your preference and run: ```bash diff --git a/docs/tutorial/intro.md b/docs/tutorial/intro.md index e652583f76..b8af9b5667 100644 --- a/docs/tutorial/intro.md +++ b/docs/tutorial/intro.md @@ -16,7 +16,7 @@ In your favorite virtual environment, simply do: pip install flair ``` -Flair requires Python 3.8+. +Flair requires Python 3.9+. 
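The CI, README and tutorial hunks above all raise the documented floor to Python 3.9+. A quick interpreter check along these lines (illustrative only, not part of the patch) is enough to confirm an environment meets the new requirement before running `pip install flair`:

```python
# Illustrative pre-install check, not part of the patch: verify the interpreter
# meets the new minimum before running `pip install flair`.
import sys

if sys.version_info < (3, 9):
    raise SystemExit(f"Flair now requires Python 3.9+, found {sys.version.split()[0]}")
```
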
## Example 1: Tag Entities in Text diff --git a/flair/class_utils.py b/flair/class_utils.py index 9aa95cd1ee..7e01f4ff42 100644 --- a/flair/class_utils.py +++ b/flair/class_utils.py @@ -1,12 +1,13 @@ import importlib import inspect +from collections.abc import Iterable from types import ModuleType -from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Optional, TypeVar, Union, overload T = TypeVar("T") -def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: +def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]: for subclass in cls.__subclasses__(): yield from get_non_abstract_subclasses(subclass) if inspect.isabstract(subclass): @@ -14,7 +15,7 @@ def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: yield subclass -def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]: +def get_state_subclass_by_name(cls: type[T], cls_name: Optional[str]) -> type[T]: for sub_cls in get_non_abstract_subclasses(cls): if sub_cls.__name__ == cls_name: return sub_cls @@ -26,12 +27,12 @@ def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ... @overload -def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ... +def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> list[Any]: ... def lazy_import( group: str, module: str, first_symbol: Optional[str] = None, *symbols: str -) -> Union[List[Any], ModuleType]: +) -> Union[list[Any], ModuleType]: try: imported_module = importlib.import_module(module) except ImportError: diff --git a/flair/data.py b/flair/data.py index 69d85baf92..56622b249c 100644 --- a/flair/data.py +++ b/flair/data.py @@ -4,10 +4,11 @@ import typing from abc import ABC, abstractmethod from collections import Counter, defaultdict +from collections.abc import Iterable from operator import itemgetter from os import PathLike from pathlib import Path -from typing import Any, DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast +from typing import Any, NamedTuple, Optional, Union, cast import torch from deprecated.sphinx import deprecated @@ -52,8 +53,8 @@ class Dictionary: def __init__(self, add_unk: bool = True) -> None: # init dictionaries - self.item2idx: Dict[bytes, int] = {} - self.idx2item: List[bytes] = [] + self.item2idx: dict[bytes, int] = {} + self.idx2item: list[bytes] = [] self.add_unk = add_unk self.multi_label = False self.span_labels = False @@ -101,7 +102,7 @@ def get_idx_for_item(self, item: str) -> int: ) raise IndexError - def get_idx_for_items(self, items: List[str]) -> List[int]: + def get_idx_for_items(self, items: list[str]) -> list[int]: """Returns the IDs for each item of the list of string, otherwise 0 if not found. 
Args: @@ -120,7 +121,7 @@ def get_idx_for_items(self, items: List[str]) -> List[int]: return [results] return list(results) - def get_items(self) -> List[str]: + def get_items(self) -> list[str]: items = [] for item in self.idx2item: items.append(item.decode("UTF-8")) @@ -151,7 +152,7 @@ def save(self, savefile: PathLike): mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx} pickle.dump(mappings, f) - def __setstate__(self, d: Dict) -> None: + def __setstate__(self, d: dict) -> None: self.__dict__ = d # set 'add_unk' if the dictionary was created with a version of Flair older than 0.9 if "add_unk" not in self.__dict__: @@ -281,9 +282,9 @@ class DataPoint: """ def __init__(self) -> None: - self.annotation_layers: Dict[str, List[Label]] = {} - self._embeddings: Dict[str, torch.Tensor] = {} - self._metadata: Dict[str, Any] = {} + self.annotation_layers: dict[str, list[Label]] = {} + self._embeddings: dict[str, torch.Tensor] = {} + self._metadata: dict[str, Any] = {} @property @abstractmethod @@ -293,7 +294,7 @@ def embedding(self) -> torch.Tensor: def set_embedding(self, name: str, vector: torch.Tensor): self._embeddings[name] = vector - def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor: + def get_embedding(self, names: Optional[list[str]] = None) -> torch.Tensor: # if one embedding name, directly return it if names and len(names) == 1: if names[0] in self._embeddings: @@ -308,7 +309,7 @@ def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor: else: return torch.tensor([], device=flair.device) - def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> List[torch.Tensor]: + def get_each_embedding(self, embedding_names: Optional[list[str]] = None) -> list[torch.Tensor]: embeddings = [] for embed_name in sorted(self._embeddings.keys()): if embedding_names and embed_name not in embedding_names: @@ -325,7 +326,7 @@ def to(self, device: str, pin_memory: bool = False) -> None: else: self._embeddings[name] = vector.to(device, non_blocking=True) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None) -> None: + def clear_embeddings(self, embedding_names: Optional[list[str]] = None) -> None: if embedding_names is None: self._embeddings = {} else: @@ -368,14 +369,14 @@ def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O") return Label(self, zero_tag_value) return self.get_labels(label_type)[0] - def get_labels(self, typename: Optional[str] = None) -> List[Label]: + def get_labels(self, typename: Optional[str] = None) -> list[Label]: if typename is None: return self.labels return self.annotation_layers.get(typename, []) @property - def labels(self) -> List[Label]: + def labels(self) -> list[Label]: all_labels = [] for key in self.annotation_layers: all_labels.extend(self.annotation_layers[key]) @@ -447,8 +448,8 @@ def __init__( concept_id: str, concept_name: str, database_name: str, - additional_ids: Optional[List[str]] = None, - synonyms: Optional[List[str]] = None, + additional_ids: Optional[list[str]] = None, + synonyms: Optional[list[str]] = None, description: Optional[str] = None, ): """A Concept as part of a knowledgebase or ontology. 
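The flair/data.py hunks follow one mechanical pattern: `typing.List`, `typing.Dict` and `typing.Tuple` become the built-in generics of PEP 585, and `Iterable` is imported from `collections.abc` instead of `typing`, both of which the new Python 3.9 floor guarantees. A minimal sketch of the resulting style, using a made-up helper rather than Flair's own code:

```python
from collections.abc import Iterable
from typing import Optional

# Hypothetical helper, written only to show the annotation style the patch adopts:
# built-in generics (PEP 585) instead of typing.List/Dict/Tuple, and Iterable
# imported from collections.abc rather than typing.
def count_labels(labels: Iterable[str], keep: Optional[set[str]] = None) -> dict[str, int]:
    counts: dict[str, int] = {}
    for label in labels:
        if keep is None or label in keep:
            counts[label] = counts.get(label, 0) + 1
    return counts
```
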
@@ -483,7 +484,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "concept_id": self.concept_id, "concept_name": self.concept_name, @@ -550,8 +551,8 @@ def __init__( self._start_position = start_position - self._embeddings: Dict = {} - self.tags_proba_dist: Dict[str, List[Label]] = {} + self._embeddings: dict[str, torch.Tensor] = {} + self.tags_proba_dist: dict[str, list[Label]] = {} @property def idx(self) -> int: @@ -568,10 +569,10 @@ def text(self) -> str: def unlabeled_identifier(self) -> str: return f'Token[{self.idx - 1}]: "{self.text}"' - def add_tags_proba_dist(self, tag_type: str, tags: List[Label]) -> None: + def add_tags_proba_dist(self, tag_type: str, tags: list[Label]) -> None: self.tags_proba_dist[tag_type] = tags - def get_tags_proba_dist(self, tag_type: str) -> List[Label]: + def get_tags_proba_dist(self, tag_type: str) -> list[Label]: if tag_type in self.tags_proba_dist: return self.tags_proba_dist[tag_type] return [] @@ -617,7 +618,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata): else: DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata) - def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: + def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]: return { "text": self.text, "start_pos": self.start_position, @@ -629,7 +630,7 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: class Span(_PartOfSentence): """This class represents one textual span consisting of Tokens.""" - def __new__(self, tokens: List[Token]): + def __new__(self, tokens: list[Token]): # check if the span already exists. If so, return it unlabeled_identifier = self._make_unlabeled_identifier(tokens) if unlabeled_identifier in tokens[0].sentence._known_spans: @@ -643,7 +644,7 @@ def __new__(self, tokens: List[Token]): tokens[0].sentence._known_spans[unlabeled_identifier] = span return span - def __init__(self, tokens: List[Token]) -> None: + def __init__(self, tokens: list[Token]) -> None: if not self.initialized: super().__init__(tokens[0].sentence) self.tokens = tokens @@ -662,7 +663,7 @@ def text(self) -> str: return "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip() @staticmethod - def _make_unlabeled_identifier(tokens: List[Token]): + def _make_unlabeled_identifier(tokens: list[Token]): text = "".join([t.text + t.whitespace_after * " " for t in tokens]).strip() return f'Span[{tokens[0].idx - 1}:{tokens[-1].idx}]: "{text}"' @@ -769,7 +770,7 @@ class Sentence(DataPoint): def __init__( self, - text: Union[str, List[str], List[Token]], + text: Union[str, list[str], list[Token]], use_tokenizer: Union[bool, Tokenizer] = True, language_code: Optional[str] = None, start_position: int = 0, @@ -790,10 +791,10 @@ def __init__( """ super().__init__() - self.tokens: List[Token] = [] + self.tokens: list[Token] = [] # private field for all known spans - self._known_spans: Dict[str, _PartOfSentence] = {} + self._known_spans: dict[str, _PartOfSentence] = {} self.language_code: Optional[str] = language_code @@ -818,7 +819,7 @@ def __init__( self._previous_sentence: Optional[Sentence] = None self._has_context: bool = False self._next_sentence: Optional[Sentence] = None - self._position_in_dataset: Optional[typing.Tuple[Dataset, int]] = None + self._position_in_dataset: Optional[tuple[Dataset, int]] = None # if text is passed, instantiate sentence with tokens (words) if 
isinstance(text, str): @@ -830,7 +831,7 @@ def __init__( self.tokens[-1].whitespace_after = 0 return else: - words = cast(List[str], text) + words = cast(list[str], text) text = " ".join(words) # determine token positions and whitespace_after flag @@ -861,15 +862,15 @@ def __init__( def unlabeled_identifier(self): return f'Sentence[{len(self)}]: "{self.text}"' - def get_relations(self, label_type: Optional[str] = None) -> List[Relation]: - relations: List[Relation] = [] + def get_relations(self, label_type: Optional[str] = None) -> list[Relation]: + relations: list[Relation] = [] for label in self.get_labels(label_type): if isinstance(label.data_point, Relation): relations.append(label.data_point) return relations - def get_spans(self, label_type: Optional[str] = None) -> List[Span]: - spans: List[Span] = [] + def get_spans(self, label_type: Optional[str] = None) -> list[Span]: + spans: list[Span] = [] for potential_span in self._known_spans.values(): if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)): spans.append(potential_span) @@ -922,16 +923,16 @@ def to(self, device: str, pin_memory: bool = False): for token in self: token.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = None): super().clear_embeddings(embedding_names) # clear token embeddings for token in self: token.clear_embeddings(embedding_names) - def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]: + def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]: sentence = self - left_context: List[Token] = [] + left_context: list[Token] = [] while len(left_context) < context_length: sentence = sentence.previous_sentence() if sentence is None: @@ -943,9 +944,9 @@ def left_context(self, context_length: int, respect_document_boundaries: bool = left_context = sentence.tokens + left_context return left_context[-context_length:] - def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> List[Token]: + def right_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]: sentence = self - right_context: List[Token] = [] + right_context: list[Token] = [] while len(right_context) < context_length: sentence = sentence.next_sentence() if sentence is None: @@ -1037,7 +1038,7 @@ def to_original_text(self) -> str: [t.text + t.whitespace_after * " " for t in self.tokens] ).strip() - def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]: + def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]: return { "text": self.to_original_text(), "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self], @@ -1180,7 +1181,7 @@ def copy_context_from_sentence(self, sentence: "Sentence") -> None: self._position_in_dataset = sentence._position_in_dataset @classmethod - def set_context_for_sentences(cls, sentences: List["Sentence"]) -> None: + def set_context_for_sentences(cls, sentences: list["Sentence"]) -> None: previous_sentence = None for sentence in sentences: if sentence.is_context_set(): @@ -1231,7 +1232,7 @@ def to(self, device: str, pin_memory: bool = False): self.first.to(device, pin_memory) self.second.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = 
None): self.first.clear_embeddings(embedding_names) self.second.clear_embeddings(embedding_names) if self.concatenated_data is not None: @@ -1276,7 +1277,7 @@ def to(self, device: str, pin_memory: bool = False): self.second.to(device, pin_memory) self.third.to(device, pin_memory) - def clear_embeddings(self, embedding_names: Optional[List[str]] = None): + def clear_embeddings(self, embedding_names: Optional[list[str]] = None): self.first.clear_embeddings(embedding_names) self.second.clear_embeddings(embedding_names) self.third.clear_embeddings(embedding_names) @@ -1313,7 +1314,7 @@ def __init__(self, data=None, imageURL=None): super().__init__() self.data = data - self._embeddings: Dict = {} + self._embeddings: dict[str, torch.Tensor] = {} self.imageURL = imageURL @property @@ -1497,17 +1498,17 @@ def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dict return vocab_dictionary - def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> List[str]: + def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> list[str]: tokens_and_frequencies = Counter(self._get_all_tokens()) - tokens: List[str] = [] + tokens: list[str] = [] for token, freq in tokens_and_frequencies.most_common(): if (min_freq != -1 and freq < min_freq) or (max_tokens != -1 and len(tokens) == max_tokens): break tokens.append(token) return tokens - def _get_all_tokens(self) -> List[str]: + def _get_all_tokens(self) -> list[str]: assert self.train tokens = [s.tokens for s in _iter_dataset(self.train)] tokens = [token for sublist in tokens for token in sublist] @@ -1544,13 +1545,8 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict: tags_to_count = Corpus._count_token_labels(sentences, tag_type) tokens_per_sentence = Corpus._get_tokens_per_sentence(sentences) - label_size_dict = {} - for label, c in classes_to_count.items(): - label_size_dict[label] = c - - tag_size_dict = {} - for tag, c in tags_to_count.items(): - tag_size_dict[tag] = c + label_size_dict = dict(classes_to_count) + tag_size_dict = dict(tags_to_count) return { "dataset": name, @@ -1566,20 +1562,20 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict: } @staticmethod - def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> List[int]: + def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> list[int]: return [len(x.tokens) for x in sentences] @staticmethod - def _count_sentence_labels(sentences: Iterable[Sentence]) -> DefaultDict[str, int]: - label_count: DefaultDict[str, int] = defaultdict(lambda: 0) + def _count_sentence_labels(sentences: Iterable[Sentence]) -> defaultdict[str, int]: + label_count: defaultdict[str, int] = defaultdict(lambda: 0) for sent in sentences: for label in sent.labels: label_count[label.value] += 1 return label_count @staticmethod - def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> DefaultDict[str, int]: - label_count: DefaultDict[str, int] = defaultdict(lambda: 0) + def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> defaultdict[str, int]: + label_count: defaultdict[str, int] = defaultdict(lambda: 0) for sent in sentences: for token in sent.tokens: if label_type in token.annotation_layers: @@ -1623,7 +1619,7 @@ def make_label_dictionary( sentence_label_type_counter: typing.Counter[str] = Counter() label_value_counter: typing.Counter[str] = Counter() - all_sentence_labels: List[str] = [] + all_sentence_labels: list[str] = [] # first, determine the datapoint type by going through dataset until first label is 
found datapoint_type = None @@ -1687,10 +1683,10 @@ def make_label_dictionary( def add_label_noise( self, label_type: str, - labels: List[str], + labels: list[str], noise_share: float = 0.2, split: str = "train", - noise_transition_matrix: Optional[Dict[str, List[float]]] = None, + noise_transition_matrix: Optional[dict[str, list[float]]] = None, ): """Generates uniform label noise distribution in the chosen dataset split. @@ -1817,12 +1813,12 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary: class MultiCorpus(Corpus): def __init__( self, - corpora: List[Corpus], - task_ids: Optional[List[str]] = None, + corpora: list[Corpus], + task_ids: Optional[list[str]] = None, name: str = "multicorpus", **corpusargs, ) -> None: - self.corpora: List[Corpus] = corpora + self.corpora: list[Corpus] = corpora ids = task_ids if task_ids else [f"Task_{i}" for i in range(len(corpora))] @@ -1871,8 +1867,8 @@ class ConcatFlairDataset(Dataset): datasets (sequence): List of datasets to be concatenated """ - datasets: List[Dataset] - cumulative_sizes: List[int] + datasets: list[Dataset] + cumulative_sizes: list[int] @staticmethod def cumsum(sequence): @@ -1907,36 +1903,13 @@ def __getitem__(self, idx: int) -> Sentence: return sentence @property - def cummulative_sizes(self) -> List[int]: + def cummulative_sizes(self) -> list[int]: return self.cumulative_sizes -def iob2(tags: List) -> bool: - """Converts the tags to the IOB2 format. - - Check that tags have a valid IOB format. - Tags in IOB1 format are converted to IOB2. - """ - for i, tag in enumerate(tags): - if tag.value == "O": - continue - split = tag.value.split("-") - if len(split) != 2 or split[0] not in ["I", "B"]: - return False - if split[0] == "B": - continue - elif i == 0 or tags[i - 1].value == "O": # conversion IOB1 to IOB2 - tags[i].value = "B" + tag.value[1:] - elif tags[i - 1].value[1:] == tag.value[1:]: - continue - else: # conversion IOB1 to IOB2 - tags[i].value = "B" + tag.value[1:] - return True - - def randomly_split_into_two_datasets( dataset: Dataset, length_of_first: int, random_seed: Optional[int] = None -) -> Tuple[Subset, Subset]: +) -> tuple[Subset, Subset]: """Shuffles a dataset and splits into two subsets. The length of the first is specified and the remaining samples go into the second subset. 
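The `randomly_split_into_two_datasets` signature shown above keeps its behaviour and only tightens the return annotation to `tuple[Subset, Subset]`. A short usage sketch of the helper as documented there (illustrative; the sentences and seed are made up):

```python
# Illustrative usage of the split helper documented above.
from flair.data import Sentence, randomly_split_into_two_datasets
from flair.datasets.base import FlairDatapointDataset

dataset = FlairDatapointDataset([Sentence(t) for t in ("one", "two", "three", "four")])
first, second = randomly_split_into_two_datasets(dataset, length_of_first=3, random_seed=42)
print(len(first), len(second))  # 3 1
```
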
@@ -1959,17 +1932,17 @@ def randomly_split_into_two_datasets( def get_spans_from_bio( - bioes_tags: List[str], bioes_scores: Optional[List[float]] = None -) -> List[typing.Tuple[List[int], float, str]]: + bioes_tags: list[str], bioes_scores: Optional[list[float]] = None +) -> list[tuple[list[int], float, str]]: # add a dummy "O" to close final prediction bioes_tags.append("O") # return complex list found_spans = [] # internal variables - current_tag_weights: Dict[str, float] = {} + current_tag_weights: dict[str, float] = {} previous_tag = "O-" - current_span: List[int] = [] - current_span_scores: List[float] = [] + current_span: list[int] = [] + current_span_scores: list[float] = [] for idx, bioes_tag in enumerate(bioes_tags): # non-set tags are OUT tags if bioes_tag == "" or bioes_tag == "O" or bioes_tag == "_": diff --git a/flair/datasets/base.py b/flair/datasets/base.py index a38d0b1321..0737b4660d 100644 --- a/flair/datasets/base.py +++ b/flair/datasets/base.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from pathlib import Path -from typing import Generic, List, Optional, Union +from typing import Generic, Optional, Union import torch.utils.data.dataloader from deprecated.sphinx import deprecated @@ -41,7 +41,7 @@ def __init__( class FlairDatapointDataset(FlairDataset, Generic[DT]): """A simple Dataset object to wrap a List of Datapoints, for example Sentences.""" - def __init__(self, datapoints: Union[DT, List[DT]]) -> None: + def __init__(self, datapoints: Union[DT, list[DT]]) -> None: """Instantiate FlairDatapointDataset. Args: @@ -64,7 +64,7 @@ def __getitem__(self, index: int = 0) -> DT: class SentenceDataset(FlairDatapointDataset): @deprecated(version="0.11", reason="The 'SentenceDataset' class was renamed to 'FlairDatapointDataset'") - def __init__(self, sentences: Union[Sentence, List[Sentence]]) -> None: + def __init__(self, sentences: Union[Sentence, list[Sentence]]) -> None: super().__init__(sentences) @@ -73,7 +73,7 @@ class StringDataset(FlairDataset): def __init__( self, - texts: Union[str, List[str]], + texts: Union[str, list[str]], use_tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(), ) -> None: """Instantiate StringDataset. @@ -111,7 +111,7 @@ def __init__( database: str, collection: str, text_field: str, - categories_field: Optional[List[str]] = None, + categories_field: Optional[list[str]] = None, max_tokens_per_doc: int = -1, max_chars_per_doc: int = -1, tokenizer: Tokenizer = SegtokTokenizer(), @@ -195,7 +195,7 @@ def __init__( def _parse_document_to_sentence( self, text: str, - labels: List[str], + labels: list[str], tokenizer: Union[bool, Tokenizer], ): if self.max_chars_per_doc > 0: diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py index 28f4aca98b..e99a71ccf7 100644 --- a/flair/datasets/biomedical.py +++ b/flair/datasets/biomedical.py @@ -7,6 +7,7 @@ import sys from abc import ABC, abstractmethod from collections import defaultdict, deque +from collections.abc import Iterable from copy import copy from operator import attrgetter from pathlib import Path @@ -18,7 +19,7 @@ StreamError, TarError, ) -from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import NamedTuple, Optional, Union from zipfile import BadZipFile, LargeZipFile import ftfy @@ -56,7 +57,7 @@ class Entity: text as well as the type of entity (e.g. Chemical, Gene, and so on). 
""" - def __init__(self, char_span: Tuple[int, int], entity_type: str) -> None: + def __init__(self, char_span: tuple[int, int], entity_type: str) -> None: assert char_span[0] < char_span[1] self.char_span = range(*char_span) self.type = entity_type @@ -98,9 +99,9 @@ class InternalBioNerDataset: def __init__( self, - documents: Dict[str, str], - entities_per_document: Dict[str, List[Entity]], - entity_types: List[str] = [], + documents: dict[str, str], + entities_per_document: dict[str, list[Entity]], + entity_types: list[str] = [], ): self.documents = documents self.entities_per_document = entities_per_document @@ -134,7 +135,7 @@ def merge_datasets(data_sets: Iterable[InternalBioNerDataset]): def filter_and_map_entities( - dataset: InternalBioNerDataset, entity_type_to_canonical: Dict[str, str] + dataset: InternalBioNerDataset, entity_type_to_canonical: dict[str, str] ) -> InternalBioNerDataset: mapped_entities_per_document = {} entity_types = list(entity_type_to_canonical.values()) @@ -223,7 +224,7 @@ def bioc_to_internal(bioc_file: Path): for document in Tqdm.tqdm(documents, desc="Converting to internal"): document_id = document.xpath("./id")[0].text - texts: List[str] = [] + texts: list[str] = [] entities = [] for passage in document.xpath("passage"): @@ -358,7 +359,7 @@ def __init__( """ self.sentence_splitter = sentence_splitter - def process_dataset(self, datasets: Dict[str, InternalBioNerDataset], out_dir: Path): + def process_dataset(self, datasets: dict[str, InternalBioNerDataset], out_dir: Path): if "train" in datasets: self.write_to_conll(datasets["train"], out_dir / (self.sentence_splitter.name + "_train.conll")) if "dev" in datasets: @@ -450,7 +451,7 @@ def to_internal(self, data_folder: Path) -> InternalBioNerDataset: @staticmethod @abstractmethod - def split_url() -> Union[str, List[str]]: + def split_url() -> Union[str, list[str]]: raise NotImplementedError def get_corpus_sentence_splitter(self) -> Optional[SentenceSplitter]: @@ -596,8 +597,8 @@ def download_dataset(cls, data_dir: Path) -> Path: @classmethod def parse_dataset(cls, original_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} tree = etree.parse(str(original_file)) sentence_elems = tree.xpath("//sentence") @@ -647,7 +648,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -726,14 +727,14 @@ def download_and_prepare_test(cls, data_folder: Path, sentence_tag: str) -> Inte @classmethod def read_file(cls, input_iob_file: Path, sentence_tag: str) -> InternalBioNerDataset: - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = defaultdict(list) + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = defaultdict(list) with open(str(input_iob_file), encoding="utf8") as file_reader: document_id: Optional[str] = None document_text: Optional[str] = None - entities: List[Entity] = [] + entities: list[Entity] = [] entity_type: Optional[str] = None entity_start = 0 @@ -818,7 +819,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return 
self.entity_type_mapping @@ -994,7 +995,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_cellline", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/cellfinder_species", @@ -1009,7 +1010,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1176,7 +1177,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1566,7 +1567,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1747,8 +1748,8 @@ def download_dataset(data_dir: Path): @classmethod def parse_dataset(cls, original_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} tree = etree.parse(str(original_file)) document_elems = tree.xpath("//document") @@ -1905,7 +1906,7 @@ def split_url() -> str: def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return LINNEAUS.download_and_parse_dataset(data_dir) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -1995,7 +1996,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2021,7 +2022,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2033,7 +2034,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRDisease", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/CDRChem", @@ -2052,7 +2053,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2167,7 +2168,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2190,7 +2191,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2213,7 +2214,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: 
return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2230,7 +2231,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_gene", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/variome_disease", @@ -2247,7 +2248,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return all_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2343,7 +2344,7 @@ def parse_input_file(input_file: Path): with open(str(input_file), encoding="utf8") as file: document_id = "" document_text = "" - entities: List[Entity] = [] + entities: list[Entity] = [] c = 1 for line in file: @@ -2406,7 +2407,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, dev_data, test_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2455,13 +2456,13 @@ def download_corpus(self, data_folder: Path) -> Path: @staticmethod def parse_input_file(input_file: Path): - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} with open(str(input_file), encoding="iso-8859-1") as file: document_id = None document_text = "" - entities: List[Entity] = [] + entities: list[Entity] = [] entity_type = None entity_start = 0 @@ -2584,7 +2585,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2605,7 +2606,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2628,7 +2629,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def split_url() -> List[str]: + def split_url() -> list[str]: split_urls = [ "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_chemicals", "https://raw.githubusercontent.com/hu-ner/huner/master/ner_scripts/splits/scai_disease", @@ -2641,7 +2642,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2763,7 +2764,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -2863,7 +2864,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return 
self.entity_type_mapping @@ -2945,7 +2946,7 @@ def download_dev_corpus(cls, data_dir) -> Path: @staticmethod def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} document_title_length = {} @@ -3010,7 +3011,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return merge_datasets([train_data, dev_data]) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3071,8 +3072,8 @@ def download_corpus(cls, data_dir: Path) -> Path: @staticmethod def parse_corpus(text_dir: Path, gold_file: Path) -> InternalBioNerDataset: - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} text_files = [file for file in os.listdir(str(text_dir)) if not file.startswith(".")] @@ -3122,7 +3123,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return DECA.parse_corpus(text_dir, gold_file) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3221,7 +3222,7 @@ def parse_corpus(corpus_dir: Path, sentence_separator: str) -> InternalBioNerDat akt_pos += len(words[i]) + 1 sentences += [tmp_sentence] - pre_entities: List[List[Tuple[int, int, str]]] = [[] for _ in sentences] + pre_entities: list[list[tuple[int, int, str]]] = [[] for _ in sentences] for protein in protein_tree: for span in protein.get("span").split(","): start = word_to_id[span.split("..")[0]] @@ -3287,7 +3288,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3450,8 +3451,8 @@ def parse_dataset(data_dir: Path) -> InternalBioNerDataset: ] text_files = sorted(text_files) - documents: Dict[str, str] = {} - entities_per_document: Dict[str, List[Entity]] = {} + documents: dict[str, str] = {} + entities_per_document: dict[str, list[Entity]] = {} for text_file in sorted(text_files): document_id = os.path.basename(text_file).split("_")[0] @@ -3590,7 +3591,7 @@ def parse_test_dataset(cls, data_folder: Path) -> InternalBioNerDataset: @staticmethod def parse_dataset(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} with open(str(text_file), encoding="utf8") as text_file_reader: for line in text_file_reader: @@ -3733,7 +3734,7 @@ def download_dev_corpus(cls, data_dir) -> Path: @staticmethod def parse_input_file(text_file: Path, ann_file: Path) -> InternalBioNerDataset: documents = {} - entities_per_document: Dict[str, List[Entity]] = {} + entities_per_document: dict[str, list[Entity]] = {} document_abstract_length = {} with open(str(text_file), encoding="utf8") as text_reader: @@ -3806,7 +3807,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: dataset = merge_datasets([train_data, dev_data]) return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -3945,7 +3946,7 @@ 
def to_internal(self, data_dir: Path, annotator: int = 0) -> InternalBioNerDatas dataset = CHEBI.parse_dataset(corpus_dir, annotator=annotator) return filter_and_map_entities(dataset, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -4038,7 +4039,7 @@ def __init__( @staticmethod @abstractmethod - def download_corpus(data_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(data_folder: Path) -> tuple[Path, Path, Path]: pass @staticmethod @@ -4083,7 +4084,7 @@ class BIONLP2013_PC(BioNLPCorpus): """ @staticmethod - def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]: train_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_training_data.tar.gz" dev_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_development_data.tar.gz" test_url = "http://2013.bionlp-st.org/tasks/BioNLP-ST_2013_PC_test_data.tar.gz" @@ -4125,7 +4126,7 @@ class BIONLP2013_CG(BioNLPCorpus): """ @staticmethod - def download_corpus(download_folder: Path) -> Tuple[Path, Path, Path]: + def download_corpus(download_folder: Path) -> tuple[Path, Path, Path]: url = "https://github.com/openbiocorpora/bionlp-st-2013-cg/archive/refs/heads/master.zip" cached_path(url, download_folder) @@ -4292,9 +4293,10 @@ def download_corpora(download_dir: Path): @staticmethod def convert_and_write(download_folder, data_folder, tag_type): data_folder.mkdir(parents=True, exist_ok=True) - with (download_folder / "train.tsv").open(encoding="utf8") as f_in, (data_folder / "train.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "train.tsv").open(encoding="utf8") as f_in, + (data_folder / "train.conll").open("w", encoding="utf8") as f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4305,9 +4307,10 @@ def convert_and_write(download_folder, data_folder, tag_type): tag = tag + "-" + tag_type f_out.write(f"{token} {tag}\n") - with (download_folder / "devel.tsv").open(encoding="utf8") as f_in, (data_folder / "dev.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "devel.tsv").open(encoding="utf8") as f_in, + (data_folder / "dev.conll").open("w", encoding="utf8") as f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4317,9 +4320,10 @@ def convert_and_write(download_folder, data_folder, tag_type): tag = tag + "-" + tag_type f_out.write(f"{token} {tag}\n") - with (download_folder / "test.tsv").open(encoding="utf8") as f_in, (data_folder / "test.conll").open( - "w", encoding="utf8" - ) as f_out: + with ( + (download_folder / "test.tsv").open(encoding="utf8") as f_in, + (data_folder / "test.conll").open("w", encoding="utf8") as f_out, + ): for line in f_in: if not line.strip(): f_out.write("\n") @@ -4638,7 +4642,7 @@ def download_corpus(cls, data_dir: Path) -> Path: @staticmethod def prepare_splits( data_dir: Path, corpus: InternalBioNerDataset - ) -> Tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]: + ) -> tuple[InternalBioNerDataset, InternalBioNerDataset, InternalBioNerDataset]: splits_dir = data_dir / "splits" os.makedirs(str(splits_dir), exist_ok=True) @@ -4734,7 +4738,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> 
Optional[dict]: return self.entity_type_mapping @@ -4792,7 +4796,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return filter_and_map_entities(corpus, self.entity_type_mapping) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -4896,7 +4900,7 @@ def parse_corpus(input_file: Path) -> InternalBioNerDataset: prev_sentence_id: Optional[str] = None document_text: Optional[str] = None - entities: List[Entity] = [] + entities: list[Entity] = [] offset: Optional[int] = None for line in azdz_reader: @@ -5014,7 +5018,7 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset: return corpus_data - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return self.entity_type_mapping @@ -5221,7 +5225,7 @@ def __init__( sample_missing_splits=True, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: """Return the mapping of entity type given in the dataset to canonical types. Note, if a entity type is not present in the map it is discarded. @@ -5279,8 +5283,8 @@ def build_corpus_directory_name(self, dataset_name: str) -> str: def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset: """Converts a dataset given in hugging datasets format to our internal corpus representation.""" - id_to_text: Dict[str, str] = {} - id_to_entities: Dict[str, list] = {} + id_to_text: dict[str, str] = {} + id_to_entities: dict[str, list] = {} entity_type_set = set() for document in dataset[split]: document_id = document["document_id"] @@ -5331,10 +5335,10 @@ def to_internal_dataset(self, dataset, split: str) -> InternalBioNerDataset: def bin_search_passage( self, - passages: List[Tuple[str, List[Tuple[int, int]]]], + passages: list[tuple[str, list[tuple[int, int]]]], low: int, high: int, - entity: Dict, + entity: dict, ): """Helper methods to find the passage to a given entity mention (incl. offset). 
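A few hunks earlier, `convert_and_write` rewrites its stacked file handles into the parenthesised multi-item `with` statement (the same rewrite appears again in document_classification.py below). That form is accepted by CPython 3.9's PEG parser and becomes official grammar in 3.10, so dropping 3.8 is what makes it safe here. A sketch of the pattern with a hypothetical helper, not Flair's own code:

```python
from pathlib import Path

# Hypothetical helper showing the grouped context-manager style the patch adopts.
def copy_first_column(src: Path, dst: Path) -> None:
    with (
        src.open(encoding="utf8") as f_in,
        dst.open("w", encoding="utf8") as f_out,
    ):
        for line in f_in:
            f_out.write(line.split("\t")[0] + "\n")
```
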
@@ -5381,7 +5385,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Gene": GENE_TAG, "GENERIF": GENE_TAG, @@ -5414,7 +5418,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5441,7 +5445,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"CHEMICAL": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5452,7 +5456,7 @@ class HUNER_ALL_DRUGPROT(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="drugprot", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GENE-N": GENE_TAG, "GENE-Y": GENE_TAG, "CHEMICAL": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5479,7 +5483,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"GeneOrGeneProduct": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5506,7 +5510,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"ChemicalEntity": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5533,7 +5537,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"DiseaseOrPhenotypicFeature": DISEASE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5560,7 +5564,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"OrganismTaxon": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5587,7 +5591,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"CellLine": CELL_LINE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5598,7 +5602,7 @@ class HUNER_ALL_BIORED(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="biored", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "GeneOrGeneProduct": GENE_TAG, "ChemicalEntity": CHEMICAL_TAG, @@ -5631,7 +5635,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5658,7 +5662,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"compound": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5669,7 
+5673,7 @@ class HUNER_ALL_CPI(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="cpi", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG, "compound": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5696,7 +5700,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene_or_gene_product": GENE_TAG, "Complex": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5723,7 +5727,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Simple_chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5734,7 +5738,7 @@ class HUNER_ALL_BIONLP_ST_2013_PC(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bionlp_st_2013_pc", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Gene_or_gene_product": GENE_TAG, "Complex": GENE_TAG, @@ -5765,7 +5769,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5792,7 +5796,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5819,7 +5823,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5846,7 +5850,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5873,7 +5877,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Organism": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5884,7 +5888,7 @@ class HUNER_ALL_BIONLP_ST_2011_ID(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bionlp_st_2011_id", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return { "Protein": GENE_TAG, "Chemical": CHEMICAL_TAG, @@ -5915,7 +5919,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5942,7 +5946,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Protein": GENE_TAG} def 
build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5969,7 +5973,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Microorganism": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -5996,7 +6000,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"gene": GENE_TAG, "protein": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6023,7 +6027,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6050,7 +6054,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"species": SPECIES_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6077,7 +6081,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: # TODO whether cell or cell line is the correct tag return {"cellline": CELL_LINE_TAG} @@ -6089,7 +6093,7 @@ class HUNER_ALL_BIOID(BIGBIO_NER_CORPUS): def __init__(self, *args, **kwargs): super().__init__(*args, dataset_name="bioid", **kwargs) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: # TODO whether cell or cell line is the correct tag return { "gene": GENE_TAG, @@ -6123,7 +6127,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG, "FamilyName": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6155,7 +6159,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"progene_text": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6182,7 +6186,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Chemical": CHEMICAL_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6209,7 +6213,7 @@ def __init__( test_split_name=test_split_name, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: @@ -6224,7 +6228,7 @@ def __init__(self, *args, **kwargs): **kwargs, ) - def get_entity_type_mapping(self) -> Optional[Dict]: + def get_entity_type_mapping(self) -> Optional[dict]: return {"Gene": GENE_TAG} def build_corpus_directory_name(self, dataset_name: str) -> str: diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py index 0bbc471818..363c84e561 100644 --- a/flair/datasets/document_classification.py +++ b/flair/datasets/document_classification.py @@ -2,8 +2,9 @@ import json import logging import os +import tarfile from pathlib import Path -from typing import 
Dict, List, Optional, Union +from typing import Optional, Union import flair from flair.data import ( @@ -36,8 +37,8 @@ def __init__( filter_if_longer_than: int = -1, tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(), memory_mode: str = "partial", - label_name_map: Optional[Dict[str, str]] = None, - skip_labels: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + skip_labels: Optional[list[str]] = None, allow_examples_without_labels=False, sample_missing_splits: bool = True, encoding: str = "utf-8", @@ -131,8 +132,8 @@ def __init__( filter_if_longer_than: int = -1, tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(), memory_mode: str = "partial", - label_name_map: Optional[Dict[str, str]] = None, - skip_labels: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + skip_labels: Optional[list[str]] = None, allow_examples_without_labels=False, encoding: str = "utf-8", ) -> None: @@ -277,11 +278,7 @@ def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union return None def is_in_memory(self) -> bool: - if self.memory_mode == "disk": - return False - if self.memory_mode == "partial": - return False - return True + return self.memory_mode not in ["disk", "partial"] def __len__(self) -> int: return self.total_sentence_count @@ -309,7 +306,7 @@ class CSVClassificationCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], label_type: str, name: str = "csv_corpus", train_file=None, @@ -404,7 +401,7 @@ class CSVClassificationDataset(FlairDataset): def __init__( self, path_to_file: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], label_type: str, max_tokens_per_doc: int = -1, max_chars_per_doc: int = -1, @@ -453,8 +450,8 @@ def __init__( self.total_sentence_count: int = 0 # most data sets have the token text in the first column, if not, pass 'text' as column - self.text_columns: List[int] = [] - self.pair_columns: List[int] = [] + self.text_columns: list[int] = [] + self.pair_columns: list[int] = [] for column in column_name_map: if column_name_map[column] == "text": self.text_columns.append(column) @@ -567,7 +564,7 @@ class AMAZON_REVIEWS(ClassificationCorpus): def __init__( self, split_max: int = 30000, - label_name_map: Dict[str, str] = { + label_name_map: dict[str, str] = { "1.0": "NEGATIVE", "2.0": "NEGATIVE", "3.0": "NEGATIVE", @@ -955,9 +952,10 @@ def __init__( original_filenames = original_filenames[:-1] if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open( - data_folder / new_filename, "w", encoding="utf-8" - ) as write_fp: + with ( + open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, + open(data_folder / new_filename, "w", encoding="utf-8") as write_fp, + ): csv_reader = csv.reader( open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True ) @@ -1048,9 +1046,10 @@ def __init__( label_list.append(labels[int(line) - 1]) # handle data file - with (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp, ( - data_folder / "train.txt" - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_path / "original" / "title_StackOverflow.txt").open(encoding="latin1") as open_fp, + (data_folder / "train.txt").open("w", encoding="utf-8") as write_fp, + ): for idx, line 
in enumerate(open_fp): line = line.rstrip() @@ -1104,9 +1103,10 @@ def __init__( os.makedirs(data_folder) # create train.txt file from CSV - with open(data_folder / "train.txt", "w") as train_file, open( - senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1" - ) as csv_train: + with ( + open(data_folder / "train.txt", "w") as train_file, + open(senteval_folder / "training.1600000.processed.noemoticon.csv", encoding="latin-1") as csv_train, + ): csv_reader = csv.reader(csv_train) for row in csv_reader: @@ -1115,9 +1115,10 @@ def __init__( train_file.write(f"__label__{label} {text}\n") # create test.txt file from CSV - with (data_folder / "test.txt").open("w", encoding="utf-8") as train_file, ( - senteval_folder / "testdata.manual.2009.06.14.csv" - ).open(encoding="latin-1") as csv_train: + with ( + (data_folder / "test.txt").open("w", encoding="utf-8") as train_file, + (senteval_folder / "testdata.manual.2009.06.14.csv").open(encoding="latin-1") as csv_train, + ): csv_reader = csv.reader(csv_train) for row in csv_reader: @@ -1384,9 +1385,10 @@ def __init__( # create train dev and test files in fasttext format for new_filename, original_filename in zip(new_filenames, original_filenames): - with open(data_folder / new_filename, "a") as out_file, open( - data_folder / "raw" / original_filename - ) as in_file: + with ( + open(data_folder / new_filename, "a") as out_file, + open(data_folder / "raw" / original_filename) as in_file, + ): for line in in_file: fields = line.split("\t") label = "POSITIVE" if fields[1].rstrip() == "1" else "NEGATIVE" @@ -1437,9 +1439,10 @@ def __init__( # convert to FastText format for split in ["train", "dev", "test"]: - with (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file, ( - data_folder / "raw" / f"stsa.fine.{split}" - ).open(encoding="latin1") as file: + with ( + (data_folder / f"{split}.txt").open("w", encoding="utf-8") as train_file, + (data_folder / "raw" / f"stsa.fine.{split}").open(encoding="latin1") as file, + ): for line in file: train_file.write(f"__label__{line[0]} {line[2:]}") @@ -1496,9 +1499,10 @@ def __init__( # create train and dev splits in fasttext format for split in ["train", "dev"]: - with open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file, open( - data_folder / "CoLA" / "original" / (split + ".tsv") - ) as in_file: + with ( + open(data_folder / "CoLA" / (split + ".txt"), "a") as out_file, + open(data_folder / "CoLA" / "original" / (split + ".tsv")) as in_file, + ): for line in in_file: fields = line.rstrip().split("\t") label = int(fields[1]) @@ -1506,9 +1510,10 @@ def __init__( out_file.write(f"__label__{label_map[label]} {sentence}\n") # create eval_dataset file with no labels - with open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file, open( - data_folder / "CoLA" / "original" / "test.tsv" - ) as in_file: + with ( + open(data_folder / "CoLA" / "eval_dataset.txt", "a") as out_file, + open(data_folder / "CoLA" / "original" / "test.tsv") as in_file, + ): for line in in_file: fields = line.rstrip().split("\t") sentence = fields[1] @@ -1702,9 +1707,10 @@ def __init__( data_path = flair.cache_root / "datasets" / dataset_name / "raw" # create correctly formated txt files for name in ["train", "test", "dev"]: - with (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file, ( - data_path / (name + ".tsv") - ).open(encoding="utf-8") as tsv_file: + with ( + (data_folder / (name + ".txt")).open("w", encoding="utf-8") as txt_file, + (data_path / (name + 
".tsv")).open(encoding="utf-8") as tsv_file, + ): lines = tsv_file.readlines() for line in lines: row = line.split("\t") @@ -1764,9 +1770,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, ( - data_folder / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, + (data_folder / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split() @@ -1820,9 +1827,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, ( - data_folder / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="latin1") as open_fp, + (data_folder / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split() @@ -1887,21 +1895,20 @@ def __init__( if not (data_folder / "train.txt").is_file(): cached_path(url, original) - import tarfile - - tar = tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz") - members = [] + with tarfile.open(original / "yahoo_answers_csv.tgz", "r:gz") as tar: + members = [] - for member in tar.getmembers(): - if "test.csv" in member.name or "train.csv" in member.name: - members.append(member) + for member in tar.getmembers(): + if "test.csv" in member.name or "train.csv" in member.name: + members.append(member) - tar.extractall(original, members=members) + tar.extractall(original, members=members) for name in ["train", "test"]: - with (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file, ( - data_folder / (name + ".txt") - ).open("w", encoding="utf-8") as writer: + with ( + (original / "yahoo_answers_csv" / (name + ".csv")).open(encoding="utf-8") as file, + (data_folder / (name + ".txt")).open("w", encoding="utf-8") as writer, + ): reader = csv.reader(file) for row in reader: writer.write("__label__" + label_map[row[0]] + " " + row[1] + "\n") @@ -1963,9 +1970,10 @@ def __init__( if not data_file.is_file(): for original_filename, new_filename in zip(original_filenames, new_filenames): - with (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp, ( - data_folder / task_setting / new_filename - ).open("w", encoding="utf-8") as write_fp: + with ( + (data_folder / "original" / original_filename).open(encoding="utf-8") as open_fp, + (data_folder / task_setting / new_filename).open("w", encoding="utf-8") as write_fp, + ): for line in open_fp: line = line.rstrip() fields = line.split("\t") diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 20f2caefdd..f74bad7092 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -4,8 +4,9 @@ import logging import os import re +from collections.abc import Iterable, Iterator from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +from typing import Any, Optional, Union import requests from bioc import biocxml, pubtator @@ -47,7 +48,7 @@ def __init__( self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} # one name can map to multiple concepts - 
self._text_to_index: Dict[str, List[str]] = {} + self._text_to_index: dict[str, list[str]] = {} for candidate in candidates: for text in [candidate.concept_name, *candidate.synonyms]: if text not in self._text_to_index: @@ -60,11 +61,11 @@ def database_name(self) -> str: return self._dataset_name @property - def text_to_index(self) -> Dict[str, List[str]]: + def text_to_index(self) -> dict[str, list[str]]: return self._text_to_index @property - def candidates(self) -> List[EntityCandidate]: + def candidates(self) -> list[EntityCandidate]: return list(self._idx_to_candidates.values()) def __getitem__(self, item: str) -> EntityCandidate: @@ -80,18 +81,18 @@ def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": # NOTE: EntityLinkingDictionary are lazy-loaded from a preprocessed file. # Use this class to load into memory all candidates class InMemoryEntityLinkingDictionary(EntityLinkingDictionary): - def __init__(self, candidates: List[EntityCandidate], dataset_name: str): + def __init__(self, candidates: list[EntityCandidate], dataset_name: str): self._dataset_name = dataset_name super().__init__(candidates, dataset_name=dataset_name) - def to_state(self) -> Dict[str, Any]: + def to_state(self) -> dict[str, Any]: return { "dataset_name": self._dataset_name, "candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()], } @classmethod - def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary": + def from_state(cls, state: dict[str, Any]) -> "InMemoryEntityLinkingDictionary": return cls( dataset_name=state["dataset_name"], candidates=[EntityCandidate(**candidate) for candidate in state["candidates"]], @@ -488,7 +489,7 @@ def __init__( to point to a different folder but typically this should not be necessary. in_memory: bool If True, keeps dataset in memory giving speedups in training. - column_format: Dict[int, str] + column_format: dict[int, str] The column-format to specify which columns correspond to the text or label types. 
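The typing edits in this file follow the two conventions used across the whole patch: concrete containers are annotated with the subscriptable builtins from PEP 585 (`dict[str, list[str]]` rather than `typing.Dict[str, List[str]]`), and abstract collection types such as `Iterable`, `Iterator` and `Sequence` are imported from `collections.abc` instead of `typing`, whose aliases are deprecated since Python 3.9. A small sketch of both moves together; the function name and data are illustrative, not part of the Flair API:

    from collections.abc import Iterator

    def build_text_to_index(pairs: Iterator[tuple[str, str]]) -> dict[str, list[str]]:
        # Builtin generics (PEP 585) replace typing.Dict / typing.List / typing.Tuple.
        index: dict[str, list[str]] = {}
        for text, concept_id in pairs:
            index.setdefault(text, []).append(concept_id)
        return index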
""" base_path = flair.cache_root / "datasets" if not base_path else Path(base_path) @@ -776,9 +777,10 @@ def __init__( wiki_language + "_dev.tsv", ], ): - with open(doc_path, encoding="utf-8") as read, open( - data_folder / file_name, "w", encoding="utf-8" - ) as write: + with ( + open(doc_path, encoding="utf-8") as read, + open(data_folder / file_name, "w", encoding="utf-8") as write, + ): # ignore first line read.readline() line = read.readline() @@ -1208,9 +1210,10 @@ def __init__( if not parsed_dataset.exists(): original_file_path = cached_path(f"{tweeki_gold_el_path}", Path("datasets") / dataset_name) - with open(original_file_path, encoding="utf-8") as read, open( - parsed_dataset, "w", encoding="utf-8" - ) as write: + with ( + open(original_file_path, encoding="utf-8") as read, + open(parsed_dataset, "w", encoding="utf-8") as write, + ): line = read.readline() while line: if line.startswith("#"): @@ -1274,9 +1277,10 @@ def __init__( with open(data_folder / corpus_file_name, "w", encoding="utf-8") as txtout: # First parse the post titles - with open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1, open( - data_folder / "gold_post_annotations.tsv", encoding="utf-8" - ) as tsvin2: + with ( + open(data_folder / "posts.tsv", encoding="utf-8") as tsvin1, + open(data_folder / "gold_post_annotations.tsv", encoding="utf-8") as tsvin2, + ): posts = csv.reader(tsvin1, delimiter="\t") self.post_annotations = csv.reader(tsvin2, delimiter="\t") self.curr_annot = next(self.post_annotations) @@ -1312,13 +1316,14 @@ def __init__( ) # Then parse the comments - with open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3, open( - data_folder / "gold_comment_annotations.tsv", encoding="utf-8" - ) as tsvin4: + with ( + open(data_folder / "comments.tsv", encoding="utf-8") as tsvin3, + open(data_folder / "gold_comment_annotations.tsv", encoding="utf-8") as tsvin4, + ): self.comments = csv.reader(tsvin3, delimiter="\t") self.comment_annotations = csv.reader(tsvin4, delimiter="\t") self.curr_annot = next(self.comment_annotations) - self.curr_row: Optional[List[str]] = next(self.comments) + self.curr_row: Optional[list[str]] = next(self.comments) self.stop_iter = False # Iterate over the comments.tsv file, until the end is reached @@ -1545,7 +1550,7 @@ def make_line(word, begin_or_inside, attributes): return line - def split_span(word_fields: List[str], datasetname: str): + def split_span(word_fields: list[str], datasetname: str): """Function that splits a word if necessary, i.e. if it is a multiple-word-span. 
Parameters @@ -1646,12 +1651,12 @@ def determine_tsv_file(filename: str, data_folder: Path, cut_multisense: bool = class WSD_UFSAC(MultiCorpus): def __init__( self, - filenames: Union[str, List[str]] = ["masc", "semcor"], + filenames: Union[str, list[str]] = ["masc", "semcor"], base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, cut_multisense: bool = True, columns={0: "text", 3: "sense"}, - banned_sentences: Optional[List[str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits_in_multicorpus: Union[bool, str] = True, sample_missing_splits_in_each_corpus: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = False, @@ -1713,7 +1718,7 @@ def __init__( if isinstance(filenames, str): filenames = [filenames] - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] log.info("Transforming data into column format and creating corpora...") @@ -1784,8 +1789,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: bool = True, cut_multisense: bool = True, ) -> None: @@ -1847,8 +1852,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -1922,8 +1927,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = False, ) -> None: @@ -1994,8 +1999,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -2070,8 +2075,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, cut_multisense: bool = True, use_raganato_ALL_as_test_data: bool = False, @@ -2147,8 +2152,8 @@ def __init__( base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, columns={0: "text", 3: "sense"}, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, sample_missing_splits: Union[bool, str] = True, use_raganato_ALL_as_test_data: bool = 
False, ) -> None: @@ -2230,7 +2235,7 @@ def __init__( self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el", - norm_keys: List[str] = ["db_name", "db_id"], + norm_keys: list[str] = ["db_name", "db_id"], **kwargs, ) -> None: self.label_type = label_type @@ -2250,14 +2255,14 @@ def __init__( ) @abc.abstractmethod - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: pass @abc.abstractmethod - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: pass - def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]: + def _dict_to_sentences(self, entry: dict[str, Any]) -> list[Sentence]: entities = [entity for entity in entry["entities"] if entity["normalized"]] tokenized_passages = [ @@ -2326,7 +2331,7 @@ def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]: sent_s[start_token_idx : end_token_idx + 1].add_label(self.label_type, mention_id) return passage_sentences - def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDataset: + def _files_to_dataset(self, paths: Union[Path, list[Path]]) -> FlairDatapointDataset: if isinstance(paths, Path): paths = [paths] all_sentences = [] @@ -2347,7 +2352,7 @@ class BIGBIO_EL_NCBI_DISEASE(BigBioEntityLinkingCorpus): def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-diseases", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: download_urls = { "train": ( "NCBItrainset_corpus.txt", @@ -2362,7 +2367,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat "https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBItestset_corpus.zip", ), } - results_files: Dict[str, Union[Path, List[Path]]] = {} + results_files: dict[str, Union[Path, list[Path]]] = {} for split, (filename, url) in download_urls.items(): result_path = data_folder / filename @@ -2376,7 +2381,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat return results_files - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: with open(filepath) as f: for doc in pubtator.iterparse(f): unified_example = { @@ -2449,7 +2454,7 @@ class BIGBIO_EL_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus): def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-chemical", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: url = "https://huggingface.co/datasets/bigbio/bc5cdr/resolve/main/CDR_Data.zip" path = cached_path(url, data_folder) @@ -2458,7 +2463,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat unpack_file(path, data_folder) assert data_folder.exists() - results_files: Dict[str, Union[Path, List[Path]]] = { + results_files: dict[str, Union[Path, list[Path]]] = { "train": data_path / "CDR_TrainingSet.BioC.xml", "dev": data_path / "CDR_DevelopmentSet.BioC.xml", "test": data_path / "CDR_TestSet.BioC.xml", @@ 
-2497,7 +2502,7 @@ def _get_bioc_entity(self, span, db_id_key="MESH"): "normalized": normalized, } - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: reader = biocxml.BioCXMLDocumentReader(str(filepath)) for i, xdoc in enumerate(reader): @@ -2542,7 +2547,7 @@ def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str self._re_tax_id = re.compile(r"(?P<gene_id>\d+)\([tT]ax:(?P<tax_id>\d+)\)") super().__init__(base_path, label_type, norm_keys=["db_id"], **kwargs) - def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + def _download_dataset(self, data_folder: Path) -> dict[str, Union[Path, list[Path]]]: url = "https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/tmTools/download/GNormPlus/GNormPlusCorpus.zip" path = cached_path(url, data_folder) @@ -2551,7 +2556,7 @@ def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Pat unpack_file(path, data_folder) assert data_folder.exists() - results_files: Dict[str, Union[Path, List[Path]]] = { + results_files: dict[str, Union[Path, list[Path]]] = { "train": [data_path / "BC2GNtrain.BioC.xml", data_path / "NLMIAT.BioC.xml"], "test": data_path / "BC2GNtest.BioC.xml", } @@ -2595,7 +2600,7 @@ def _parse_bioc_entity(self, span, db_id_key="NCBIGene", insert_tax_id=False): "normalized": normalized, } - def _adjust_entity_offsets(self, text: str, entities: List[Dict]): + def _adjust_entity_offsets(self, text: str, entities: list[dict]): for entity in entities: start, end = entity["offsets"][0] entity_mention = entity["text"][0] @@ -2605,7 +2610,7 @@ def _adjust_entity_offsets(self, text: str, entities: List[Dict]): elif text[start : end - 1] == entity_mention: entity["offsets"] = [(start, end - 1)] - def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + def _file_to_dicts(self, filepath: Path) -> Iterator[dict[str, Any]]: with filepath.open("r") as f: collection = biocxml.load(f) diff --git a/flair/datasets/ocr.py b/flair/datasets/ocr.py index bf60b2b0d6..4a58e4e7d3 100644 --- a/flair/datasets/ocr.py +++ b/flair/datasets/ocr.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union import gdown.download_folder import PIL @@ -20,7 +20,7 @@ def __init__( encoding: str = "utf-8", load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, ) -> None: """Instantiates a Dataset from a OCR-Json format. @@ -132,7 +132,7 @@ def __init__( in_memory: bool = True, load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, **corpusargs, ) -> None: """Instantiates a Corpus from a OCR-Json format. @@ -205,7 +205,7 @@ def __init__( in_memory: bool = True, load_images: bool = False, normalize_coords_to_thousands: bool = True, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, **corpusargs, ) -> None: """Instantiates the SROIE corpus with perfect ocr boxes.
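The `_download_dataset` implementations above share one shape: fetch an archive with `cached_path`, unpack it with `unpack_file`, and return a mapping from split name to either a single file or a list of files, now annotated as `dict[str, Union[Path, list[Path]]]`. A sketch of that shape under the new annotations; the URL, file names and the standalone function are placeholders for illustration, not a real corpus implementation:

    from pathlib import Path
    from typing import Union

    from flair.file_utils import cached_path, unpack_file

    def download_dataset(data_folder: Path) -> dict[str, Union[Path, list[Path]]]:
        # Download once (cached on disk), then unpack into the dataset folder.
        archive = cached_path("https://example.org/corpus.zip", data_folder)  # placeholder URL
        unpack_file(archive, data_folder)

        # A split may consist of one file or several.
        return {
            "train": [data_folder / "corpus_part1.xml", data_folder / "corpus_part2.xml"],
            "test": data_folder / "corpus_test.xml",
        }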
diff --git a/flair/datasets/relation_extraction.py b/flair/datasets/relation_extraction.py index 30709a14c4..871811abc2 100644 --- a/flair/datasets/relation_extraction.py +++ b/flair/datasets/relation_extraction.py @@ -5,8 +5,9 @@ import os import re from collections import defaultdict +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import conllu import gdown @@ -279,7 +280,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder): token_list = self._tacred_example_to_token_list(example) target_file.write(token_list.serialize()) - def _tacred_example_to_token_list(self, example: Dict[str, Any]) -> conllu.TokenList: + def _tacred_example_to_token_list(self, example: dict[str, Any]) -> conllu.TokenList: id_ = example["id"] tokens = example["token"] ner = example["stanford_ner"] @@ -379,7 +380,7 @@ def _parse_incr(self, source_file) -> Iterable[conllu.TokenList]: } metadata_parsers = {"__fallback__": lambda k, v: tuple(k.split())} - lines: List[str] = [] + lines: list[str] = [] for index, line in enumerate(source_file): if index > 0 and line.startswith("#"): source_str = "".join(lines) @@ -416,9 +417,10 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder): ] for source_filename, target_filename in zip(source_filenames, target_filenames): - with (source_data_folder / source_filename).open(encoding="utf-8") as source_file, ( - data_folder / target_filename - ).open("w", encoding="utf-8") as target_file: + with ( + (source_data_folder / source_filename).open(encoding="utf-8") as source_file, + (data_folder / target_filename).open("w", encoding="utf-8") as target_file, + ): # write CoNLL-U Plus header target_file.write("# global.columns = id form ner\n") @@ -426,7 +428,7 @@ def convert_to_conllu(self, source_data_folder: Path, data_folder): token_list = self._src_token_list_to_token_list(src_token_list) target_file.write(token_list.serialize()) - def _bio_tags_to_spans(self, tags: List[str]) -> List[Tuple[int, int]]: + def _bio_tags_to_spans(self, tags: list[str]) -> list[tuple[int, int]]: spans = [] span_start = 0 span_end = 0 @@ -590,7 +592,7 @@ def extract_and_convert_to_conllu(self, data_file, data_folder): ent2 = arg2.split(":")[1] pmid_to_relations[pmid].add((rel_type, ent1, ent2)) - tokenlists: List[conllu.TokenList] = [] + tokenlists: list[conllu.TokenList] = [] with zip_file.open( f"drugprot-gs-training-development/{split}/drugprot_{split}_abstracs.tsv" ) as abstracts_file: @@ -652,13 +654,13 @@ def has_overlap(self, a, b): def drugprot_document_to_tokenlists( self, pmid: str, - title_sentences: List[Sentence], - abstract_sentences: List[Sentence], + title_sentences: list[Sentence], + abstract_sentences: list[Sentence], abstract_offset: int, - entities: Dict[str, Tuple[str, int, int, str]], - relations: Set[Tuple[str, str, str]], - ) -> List[conllu.TokenList]: - tokenlists: List[conllu.TokenList] = [] + entities: dict[str, tuple[str, int, int, str]], + relations: set[tuple[str, str, str]], + ) -> list[conllu.TokenList]: + tokenlists: list[conllu.TokenList] = [] sentence_id = 1 for offset, sents in [ (0, title_sentences), diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 6699548498..55e50723d1 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4,17 +4,13 @@ import os import re import shutil +import tarfile from collections import defaultdict +from 
collections.abc import Iterable, Iterator from pathlib import Path from typing import ( Any, - DefaultDict, - Dict, - Iterable, - Iterator, - List, Optional, - Tuple, Union, cast, ) @@ -224,7 +220,7 @@ def __init__( self.label_type = label_type self.path_to_json_file = path_to_json_file - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] with path_to_json_file.open(encoding=encoding) as jsonl_fp: for line in jsonl_fp: current_line = json.loads(line) @@ -238,7 +234,7 @@ def __init__( self.sentences.append(current_sentence) - def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: List[List[Any]]): + def _add_labels_to_sentence(self, raw_text: str, sentence: Sentence, labels: list[list[Any]]): # Add tags for each annotated span for label in labels: self._add_label_to_sentence(raw_text, sentence, label[0], label[1], label[2]) @@ -288,7 +284,7 @@ def _add_label_to_sentence(self, text: str, sentence: Sentence, start: int, end: sentence[start_idx : end_idx + 1].add_label(self.label_type, label) - def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: List[Tuple[str, str]]): + def _add_metadatas_to_sentence(self, sentence: Sentence, metadatas: list[tuple[str, str]]): # Add metadatas for sentence for metadata in metadatas: self._add_metadata_to_sentence(sentence, metadata[0], metadata[1]) @@ -313,7 +309,7 @@ def __getitem__(self, index: int) -> Sentence: class MultiFileColumnCorpus(Corpus): def __init__( self, - column_format: Dict[int, str], + column_format: dict[int, str], train_files=None, test_files=None, dev_files=None, @@ -323,8 +319,8 @@ def __init__( document_separator_token: Optional[str] = None, skip_first_line: bool = False, in_memory: bool = True, - label_name_map: Optional[Dict[str, str]] = None, - banned_sentences: Optional[List[str]] = None, + label_name_map: Optional[dict[str, str]] = None, + banned_sentences: Optional[list[str]] = None, default_whitespace_after: int = 1, **corpusargs, ) -> None: @@ -424,7 +420,7 @@ class ColumnCorpus(MultiFileColumnCorpus): def __init__( self, data_folder: Union[str, Path], - column_format: Dict[int, str], + column_format: dict[int, str], train_file=None, test_file=None, dev_file=None, @@ -475,15 +471,15 @@ class ColumnDataset(FlairDataset): def __init__( self, path_to_column_file: Union[str, Path], - column_name_map: Dict[int, str], + column_name_map: dict[int, str], column_delimiter: str = r"\s+", comment_symbol: Optional[str] = None, - banned_sentences: Optional[List[str]] = None, + banned_sentences: Optional[list[str]] = None, in_memory: bool = True, document_separator_token: Optional[str] = None, encoding: str = "utf-8", skip_first_line: bool = False, - label_name_map: Optional[Dict[str, str]] = None, + label_name_map: Optional[dict[str, str]] = None, default_whitespace_after: int = 1, ) -> None: r"""Instantiates a column dataset. 
@@ -537,7 +533,7 @@ def __init__( # option 1: keep Sentence objects in memory if self.in_memory: - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] # pointer to previous previous_sentence = None @@ -579,7 +575,7 @@ def __init__( # option 2: keep source data in memory if not self.in_memory: - self.sentences_raw: List[List[str]] = [] + self.sentences_raw: list[list[str]] = [] while True: # read lines for next sentence, but don't parse @@ -679,10 +675,10 @@ def _read_next_sentence(self, file): return lines def _convert_lines_to_sentence( - self, lines, word_level_tag_columns: Dict[int, str], span_level_tag_columns: Optional[Dict[int, str]] = None + self, lines, word_level_tag_columns: dict[int, str], span_level_tag_columns: Optional[dict[int, str]] = None ): token: Optional[Token] = None - tokens: List[Token] = [] + tokens: list[Token] = [] filtered_lines = [] comments = [] for line in lines: @@ -749,9 +745,9 @@ def _convert_lines_to_sentence( return sentence return None - def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: Optional[Token] = None) -> Token: + def _parse_token(self, line: str, column_name_map: dict[int, str], last_token: Optional[Token] = None) -> Token: # get fields from line - fields: List[str] = self.column_delimiter.split(line.rstrip()) + fields: list[str] = self.column_delimiter.split(line.rstrip()) field_count = len(fields) # get head_id if exists (only in dependency parses) head_id = int(fields[self.head_id_column]) if self.head_id_column else None @@ -855,7 +851,7 @@ def __init__( base_path: Optional[Union[str, Path]] = None, version: str = "v4", language: str = "english", - domain: Union[None, str, List[str], Dict[str, Union[None, str, List[str]]]] = None, + domain: Union[None, str, list[str], dict[str, Union[None, str, list[str]]]] = None, in_memory: bool = True, **corpusargs, ) -> None: @@ -893,7 +889,7 @@ def get_available_domains( version: str = "v4", language: str = "english", split: str = "train", - ) -> List[str]: + ) -> list[str]: processed_data_path = cls._ensure_data_processed(base_path=base_path, language=language, version=version) processed_split_path = processed_data_path / "splits" / version / language / split @@ -907,7 +903,7 @@ def _get_processed_file_paths( split: str = "train", version: str = "v4", language: str = "english", - domain: Optional[Union[str, List[str], Dict[str, Union[None, str, List[str]]]]] = None, + domain: Optional[Union[str, list[str], dict[str, Union[None, str, list[str]]]]] = None, ) -> Iterable[Path]: processed_split_path = processed_data_path / "splits" / version / language / split @@ -1009,8 +1005,8 @@ def _process_coref_span_annotations_for_word( cls, label: str, word_index: int, - clusters: DefaultDict[int, List[Tuple[int, int]]], - coref_stacks: DefaultDict[int, List[int]], + clusters: defaultdict[int, list[tuple[int, int]]], + coref_stacks: defaultdict[int, list[int]], ) -> None: """For a given coref label, add it to a currently open span(s), complete a span(s) or ignore it, if it is outside of all spans. 
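The coreference annotations above also show that `typing.DefaultDict` is gone: from Python 3.9 the `collections` classes are subscriptable themselves, so the runtime type doubles as the annotation. A minimal sketch with made-up cluster ids and token indices:

    from collections import defaultdict

    # The class itself is generic at runtime on Python 3.9+, no typing.DefaultDict needed.
    clusters: defaultdict[int, list[tuple[int, int]]] = defaultdict(list)
    coref_stacks: defaultdict[int, list[int]] = defaultdict(list)

    coref_stacks[3].append(7)                        # a span for cluster 3 opens at token 7
    clusters[3].append((coref_stacks[3].pop(), 9))   # ... and closes at token 9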
@@ -1048,9 +1044,9 @@ def _process_coref_span_annotations_for_word( @classmethod def _process_span_annotations_for_word( cls, - annotations: List[str], - span_labels: List[List[str]], - current_span_labels: List[Optional[str]], + annotations: list[str], + span_labels: list[list[str]], + current_span_labels: list[Optional[str]], ) -> None: for annotation_index, annotation in enumerate(annotations): # strip all bracketing information to @@ -1076,33 +1072,33 @@ def _process_span_annotations_for_word( current_span_labels[annotation_index] = None @classmethod - def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: + def _conll_rows_to_sentence(cls, conll_rows: list[str]) -> dict: document_id: str sentence_id: int # The words in the sentence. - sentence: List[str] = [] + sentence: list[str] = [] # The pos tags of the words in the sentence. - pos_tags: List[str] = [] + pos_tags: list[str] = [] # the pieces of the parse tree. - parse_pieces: List[Optional[str]] = [] + parse_pieces: list[Optional[str]] = [] # The lemmatised form of the words in the sentence which # have SRL or word sense information. - predicate_lemmas: List[Optional[str]] = [] + predicate_lemmas: list[Optional[str]] = [] # The FrameNet ID of the predicate. - predicate_framenet_ids: List[Optional[str]] = [] + predicate_framenet_ids: list[Optional[str]] = [] # The sense of the word, if available. - word_senses: List[Optional[float]] = [] + word_senses: list[Optional[float]] = [] # The current speaker, if available. - speakers: List[Optional[str]] = [] + speakers: list[Optional[str]] = [] - verbal_predicates: List[str] = [] - span_labels: List[List[str]] = [] - current_span_labels: List[Optional[str]] = [] + verbal_predicates: list[str] = [] + span_labels: list[list[str]] = [] + current_span_labels: list[Optional[str]] = [] # Cluster id -> List of (start_index, end_index) spans. - clusters: DefaultDict[int, List[Tuple[int, int]]] = defaultdict(list) + clusters: defaultdict[int, list[tuple[int, int]]] = defaultdict(list) # Cluster id -> List of start_indices which are open for this id. - coref_stacks: DefaultDict[int, List[int]] = defaultdict(list) + coref_stacks: defaultdict[int, list[int]] = defaultdict(list) for index, row in enumerate(conll_rows): conll_components = row.split() @@ -1178,7 +1174,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: srl_frames = list(zip(verbal_predicates, span_labels[1:])) # this would not be reached if parse_pieces contained None, hence the cast - parse_tree = "".join(cast(List[str], parse_pieces)) if all(parse_pieces) else None + parse_tree = "".join(cast(list[str], parse_pieces)) if all(parse_pieces) else None coref_span_tuples = {(cluster_id, span) for cluster_id, span_list in clusters.items() for span in span_list} return { @@ -1197,7 +1193,7 @@ def _conll_rows_to_sentence(cls, conll_rows: List[str]) -> Dict: } @classmethod - def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List]: + def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[list[dict]]: """An iterator over CONLL formatted files which yields documents, regardless of the number of document annotations in a particular file. 
This is useful for conll data which has been preprocessed, such @@ -1206,7 +1202,7 @@ def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List """ with open(file_path, encoding="utf8") as open_file: conll_rows = [] - document: List = [] + document: list[dict] = [] for line in open_file: line = line.strip() if line != "" and not line.startswith("#"): @@ -1456,17 +1452,22 @@ def __init__( cached_path(f"{conll_2000_path}train.txt.gz", Path("datasets") / dataset_name) cached_path(f"{conll_2000_path}test.txt.gz", Path("datasets") / dataset_name) import gzip - import shutil - with gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in, open( - flair.cache_root / "datasets" / dataset_name / "train.txt", - "wb", - ) as f_out: + with ( + gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in, + open( + flair.cache_root / "datasets" / dataset_name / "train.txt", + "wb", + ) as f_out, + ): shutil.copyfileobj(f_in, f_out) - with gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in, open( - flair.cache_root / "datasets" / dataset_name / "test.txt", - "wb", - ) as f_out: + with ( + gzip.open(flair.cache_root / "datasets" / dataset_name / "test.txt.gz", "rb") as f_in, + open( + flair.cache_root / "datasets" / dataset_name / "test.txt", + "wb", + ) as f_out, + ): shutil.copyfileobj(f_in, f_out) super().__init__( @@ -1735,8 +1736,6 @@ def __init__( data_file = data_path / "named_ent_eu.train" if not data_file.is_file(): cached_path(f"{ner_basque_path}/eiec_v1.0.tgz", Path("datasets") / dataset_name) - import shutil - import tarfile with tarfile.open( flair.cache_root / "datasets" / dataset_name / "eiec_v1.0.tgz", @@ -2247,15 +2246,13 @@ def __init__( if not base_path: base_path = Path(flair.cache_root) / "datasets" data_folder = base_path / dataset_name - import tarfile if not os.path.isfile(data_folder / "webpages_ner.txt"): # # download zip tar_file = "https://cogcomp.seas.upenn.edu/Data/NERWebpagesColumns.tgz" webpages_ner_path = cached_path(tar_file, Path("datasets") / dataset_name) - tf = tarfile.open(webpages_ner_path) - tf.extractall(data_folder) - tf.close() + with tarfile.open(webpages_ner_path) as tf: + tf.extractall(data_folder) outputfile = os.path.abspath(data_folder) # merge the files in one as the zip is containing multiples files @@ -2538,7 +2535,7 @@ def _add_IOB_tags(self, data_file: Union[str, Path], encoding: str = "utf8", ner Specifies the ner-tagged column. The default is 1 (the second column). 
""" - def add_I_prefix(current_line: List[str], ner: int, tag: str): + def add_I_prefix(current_line: list[str], ner: int, tag: str): for i in range(len(current_line)): if i == 0: f.write(line_list[i]) @@ -2779,9 +2776,11 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path): train_len = round(num_lines * 0.8) test_len = round(num_lines * 0.1) - with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open( - "w", encoding="utf-8" - ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev: + with ( + (data_folder / "train.txt").open("w", encoding="utf-8") as train, + (data_folder / "test.txt").open("w", encoding="utf-8") as test, + (data_folder / "dev.txt").open("w", encoding="utf-8") as dev, + ): for k, line in enumerate(file.readlines(), start=1): if k <= train_len: train.write(line) @@ -2972,7 +2971,7 @@ def __prepare_jap_wikinews_corpus(file_in: Union[str, Path], file_out: Union[str class NER_MASAKHANE(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "luo", + languages: Union[str, list[str]] = "luo", version: str = "v2", base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, @@ -3056,7 +3055,7 @@ def __init__( if languages == ["all"]: languages = list(language_to_code.values()) - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: if language in language_to_code: language = language_to_code[language] @@ -3239,7 +3238,7 @@ def __init__( class NER_MULTI_WIKIANN(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3251,7 +3250,7 @@ def __init__( Parameters ---------- - languages : Union[str, List[str]] + languages : Union[str, list[str]] Should be an abbreviation of a language ("en", "de",..) or a list of abbreviations. The datasets of all passed languages will be saved in one MultiCorpus. (Note that, even though listed on https://elisa-ie.github.io/wikiann/ some datasets are empty. 
@@ -3282,7 +3281,7 @@ def __init__( # this list is handed to the multicorpus # list that contains the columncopora - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] google_drive_path = "https://drive.google.com/uc?id=" # download data if necessary @@ -3294,8 +3293,6 @@ def __init__( # if language not downloaded yet, download it if not language_folder.exists(): if first: - import tarfile - import gdown first = False @@ -3310,10 +3307,8 @@ def __init__( # unzip log.info("Extracting data...") - tar = tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz") - # tar.extractall(language_folder,members=[tar.getmember(file_name)]) - tar.extract(file_name, str(language_folder)) - tar.close() + with tarfile.open(str(language_folder / language) + ".tar.gz", "r:gz") as tar: + tar.extract(file_name, str(language_folder)) log.info("...done.") # transform data into required format @@ -3342,9 +3337,10 @@ def __init__( ) def _silver_standard_to_simple_ner_annotation(self, data_file: Union[str, Path]): - with open(data_file, encoding="utf-8") as f_read, open( - str(data_file) + "_new", "w+", encoding="utf-8" - ) as f_write: + with ( + open(data_file, encoding="utf-8") as f_read, + open(str(data_file) + "_new", "w+", encoding="utf-8") as f_write, + ): while True: line = f_read.readline() if line: @@ -3660,7 +3656,7 @@ def _google_drive_id_from_language_name(self, language): class NER_MULTI_XTREME(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3672,7 +3668,7 @@ def __init__( Parameters ---------- - languages : Union[str, List[str]], optional + languages : Union[str, list[str]], optional Specify the languages you want to load. Provide an empty list or string to select all languages. base_path : Union[str, Path], optional Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this to point to a different folder but typically this should not be necessary. 
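This tarfile change repeats a pattern applied elsewhere in the patch (the Yahoo Answers and web-pages corpora earlier, and the XTREME hunk below): the archive is opened in a `with` block so the handle is released even when extraction fails, instead of relying on an explicit `tar.close()`. A short sketch with placeholder paths and member names:

    import tarfile
    from pathlib import Path

    def extract_splits(archive: Path, target: Path) -> None:
        # The archive handle is closed automatically, also if extract() raises.
        with tarfile.open(archive, "r:gz") as tar:
            for part in ["train", "dev", "test"]:
                tar.extract(part, str(target))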
@@ -3743,7 +3739,7 @@ def __init__( # This list is handed to the multicorpus # list that contains the columncopora - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] hu_path = "https://nlp.informatik.hu-berlin.de/resources/datasets/panx_dataset" @@ -3765,12 +3761,10 @@ def __init__( # unzip log.info("Extracting data...") - import tarfile - tar = tarfile.open(str(temp_file), "r:gz") - for part in ["train", "test", "dev"]: - tar.extract(part, str(language_folder)) - tar.close() + with tarfile.open(str(temp_file), "r:gz") as tar: + for part in ["train", "test", "dev"]: + tar.extract(part, str(language_folder)) log.info("...done.") # transform data into required format @@ -3809,7 +3803,7 @@ def _xtreme_to_simple_ner_annotation(self, data_file: Union[str, Path]): class NER_MULTI_WIKINER(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "en", + languages: Union[str, list[str]] = "en", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -3828,7 +3822,7 @@ def __init__( data_folder = base_path / dataset_name - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: language_folder = data_folder / language @@ -3868,11 +3862,14 @@ def _download_wikiner(self, language_code: str, dataset_name: str): flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.bz2", "rb", ) - with bz_file as f, open( - flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train", - "w", - encoding="utf-8", - ) as out: + with ( + bz_file as f, + open( + flair.cache_root / "datasets" / dataset_name / f"aij-wikiner-{lc}-wp3.train", + "w", + encoding="utf-8", + ) as out, + ): for lineb in f: line = lineb.decode("utf-8") words = line.split(" ") @@ -4740,7 +4737,7 @@ def __init__( class NER_NERMUD(MultiCorpus): def __init__( self, - domains: Union[str, List[str]] = "all", + domains: Union[str, list[str]] = "all", base_path: Optional[Union[str, Path]] = None, in_memory: bool = False, **corpusargs, @@ -4779,7 +4776,7 @@ def __init__( data_folder = base_path / dataset_name - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] github_path = "https://raw.githubusercontent.com/dhfbk/KIND/main/evalita-2023" @@ -4923,7 +4920,7 @@ def _set_path(cls, base_path) -> Path: return base_path @classmethod - def _load_features(cls, base_path) -> List[List[str]]: + def _load_features(cls, base_path) -> list[list[str]]: print(base_path) unpack_file(cached_path(cls.data_url, base_path), base_path, "zip", False) with open(f"{base_path}/estner.cnll", encoding="utf-8") as in_file: @@ -4932,17 +4929,17 @@ def _load_features(cls, base_path) -> List[List[str]]: return features @classmethod - def _process_clean_labels(cls, features) -> List[List[str]]: + def _process_clean_labels(cls, features) -> list[list[str]]: preinstances = [[instance[0], instance[len(instance) - 1]] for instance in features] return preinstances @classmethod - def _rmv_clean_labels(cls, features) -> List[str]: + def _rmv_clean_labels(cls, features) -> list[str]: rdcd_features = [feature[:-1] for feature in features] return rdcd_features @classmethod - def _load_noisy_labels(cls, version, base_path) -> List[str]: + def _load_noisy_labels(cls, version, base_path) -> list[str]: file_name = f"NoisyNER_labelset{version}.labels" cached_path(f"{cls.label_url}/{file_name}", base_path) with open(f"{base_path}/{file_name}", encoding="utf-8") as in_file: @@ -4950,7 +4947,7 @@ def _load_noisy_labels(cls, version, base_path) -> List[str]: return labels @classmethod - def 
_process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]: + def _process_noisy_labels(cls, rdcd_features, labels) -> list[list[str]]: instances = [] label_idx = 0 for feature in rdcd_features: @@ -4965,7 +4962,7 @@ def _process_noisy_labels(cls, rdcd_features, labels) -> List[List[str]]: return instances @classmethod - def _delete_empty_labels(cls, version, preinstances) -> List[str]: + def _delete_empty_labels(cls, version, preinstances) -> list[str]: instances = [] if version == 0: for instance in preinstances: @@ -4978,7 +4975,7 @@ def _delete_empty_labels(cls, version, preinstances) -> List[str]: return instances @classmethod - def _split_data(cls, instances) -> Tuple[List[str], List[str], List[str]]: + def _split_data(cls, instances) -> tuple[list[str], list[str], list[str]]: train = instances[:185708] dev = instances[185708:208922] test = instances[208922:] @@ -4996,7 +4993,7 @@ def _write_instances(cls, version, base_path, split, data): class MASAKHA_POS(MultiCorpus): def __init__( self, - languages: Union[str, List[str]] = "bam", + languages: Union[str, list[str]] = "bam", version: str = "v1", base_path: Optional[Union[str, Path]] = None, in_memory: bool = True, @@ -5063,7 +5060,7 @@ def __init__( if languages == ["all"]: languages = supported_languages - corpora: List[Corpus] = [] + corpora: list[Corpus] = [] for language in languages: if language not in supported_languages: log.error(f"Language '{language}' is not in list of supported languages!") diff --git a/flair/datasets/text_image.py b/flair/datasets/text_image.py index f7baf72be9..676b078d7f 100644 --- a/flair/datasets/text_image.py +++ b/flair/datasets/text_image.py @@ -3,7 +3,6 @@ import os import urllib from pathlib import Path -from typing import List import numpy as np import torch.utils.data.dataloader @@ -40,13 +39,13 @@ def __init__(self, **kwargs) -> None: feidegger_dataset: Dataset = FeideggerDataset(dataset_info, **kwargs) - train_indices = list(np.where(np.in1d(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined] + train_indices = list(np.where(np.isin(feidegger_dataset.split, list(range(8))))[0]) # type: ignore[attr-defined] train = torch.utils.data.dataset.Subset(feidegger_dataset, train_indices) - dev_indices = list(np.where(np.in1d(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined] + dev_indices = list(np.where(np.isin(feidegger_dataset.split, [8]))[0]) # type: ignore[attr-defined] dev = torch.utils.data.dataset.Subset(feidegger_dataset, dev_indices) - test_indices = list(np.where(np.in1d(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined] + test_indices = list(np.where(np.isin(feidegger_dataset.split, [9]))[0]) # type: ignore[attr-defined] test = torch.utils.data.dataset.Subset(feidegger_dataset, test_indices) super().__init__(train, dev, test, name="feidegger") @@ -56,8 +55,8 @@ class FeideggerDataset(FlairDataset): def __init__(self, dataset_info, **kwargs) -> None: super().__init__() - self.data_points: List[DataPair] = [] - self.split: List[int] = [] + self.data_points: list[DataPair] = [] + self.split: list[int] = [] def identity(x): return x diff --git a/flair/datasets/text_text.py b/flair/datasets/text_text.py index 0bf0e91020..58a40d62c9 100644 --- a/flair/datasets/text_text.py +++ b/flair/datasets/text_text.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import flair from flair.data import ( @@ -144,14 +144,15 @@ def __init__( 
self.total_sentence_count: int = 0 if self.in_memory: - self.bi_sentences: List[DataPair] = [] + self.bi_sentences: list[DataPair] = [] else: - self.source_lines: List[str] = [] - self.target_lines: List[str] = [] + self.source_lines: list[str] = [] + self.target_lines: list[str] = [] - with open(str(path_to_source), encoding="utf-8") as source_file, open( - str(path_to_target), encoding="utf-8" - ) as target_file: + with ( + open(str(path_to_source), encoding="utf-8") as source_file, + open(str(path_to_target), encoding="utf-8") as target_file, + ): source_line = source_file.readline() target_line = target_file.readline() @@ -204,7 +205,7 @@ class DataPairCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - columns: List[int] = [0, 1, 2], + columns: list[int] = [0, 1, 2], train_file=None, test_file=None, dev_file=None, @@ -318,7 +319,7 @@ class DataPairDataset(FlairDataset): def __init__( self, path_to_data: Union[str, Path], - columns: List[int] = [0, 1, 2], + columns: list[int] = [0, 1, 2], max_tokens_per_doc=-1, max_chars_per_doc=-1, use_tokenizer=True, @@ -368,11 +369,11 @@ def __init__( self.total_data_count: int = 0 if self.in_memory: - self.data_pairs: List[DataPair] = [] + self.data_pairs: list[DataPair] = [] else: - self.first_elements: List[str] = [] - self.second_elements: List[str] = [] - self.labels: List[Optional[str]] = [] + self.first_elements: list[str] = [] + self.second_elements: list[str] = [] + self.labels: list[Optional[str]] = [] with open(str(path_to_data), encoding=encoding) as source_file: source_line = source_file.readline() @@ -448,7 +449,7 @@ class DataTripleCorpus(Corpus): def __init__( self, data_folder: Union[str, Path], - columns: List[int] = [0, 1, 2, 3], + columns: list[int] = [0, 1, 2, 3], train_file=None, test_file=None, dev_file=None, @@ -563,7 +564,7 @@ class DataTripleDataset(FlairDataset): def __init__( self, path_to_data: Union[str, Path], - columns: List[int] = [0, 1, 2, 3], + columns: list[int] = [0, 1, 2, 3], max_tokens_per_doc=-1, max_chars_per_doc=-1, use_tokenizer=True, @@ -614,12 +615,12 @@ def __init__( self.total_data_count: int = 0 if self.in_memory: - self.data_triples: List[DataTriple] = [] + self.data_triples: list[DataTriple] = [] else: - self.first_elements: List[str] = [] - self.second_elements: List[str] = [] - self.third_elements: List[str] = [] - self.labels: List[Optional[str]] = [] + self.first_elements: list[str] = [] + self.second_elements: list[str] = [] + self.third_elements: list[str] = [] + self.labels: list[Optional[str]] = [] with open(str(path_to_data), encoding=encoding) as source_file: source_line = source_file.readline() @@ -828,9 +829,10 @@ def __init__( str(data_folder / "MNLI" / temp_file), ) - with open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file, open( - data_folder / "MNLI" / temp_file, encoding="utf-8" - ) as in_file: + with ( + open(data_folder / "MNLI" / dev_filename, "a", encoding="utf-8") as out_file, + open(data_folder / "MNLI" / temp_file, encoding="utf-8") as in_file, + ): for line in in_file: fields = line.split("\t") reordered_columns = "\t".join(fields[column_id] for column_id in range(11)) diff --git a/flair/datasets/treebanks.py b/flair/datasets/treebanks.py index ed0f0135cd..21ae327691 100644 --- a/flair/datasets/treebanks.py +++ b/flair/datasets/treebanks.py @@ -1,7 +1,7 @@ import logging import re from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import flair from flair.data import Corpus, 
FlairDataset, Sentence, Token @@ -82,7 +82,7 @@ def __init__( with open(str(self.path_to_conll_file), encoding="utf-8") as file: # option 1: read only sentence boundaries as offset positions if not self.in_memory: - self.indices: List[int] = [] + self.indices: list[int] = [] line = file.readline() position = 0 @@ -97,7 +97,7 @@ def __init__( # option 2: keep everything in memory if self.in_memory: - self.sentences: List[Sentence] = [] + self.sentences: list[Sentence] = [] while True: sentence = self._read_next_sentence(file) @@ -129,7 +129,7 @@ def __getitem__(self, index: int = 0) -> Sentence: def _read_next_sentence(self, file) -> Optional[Sentence]: line = file.readline() - tokens: List[Token] = [] + tokens: list[Token] = [] # current token ID token_idx = 0 @@ -143,7 +143,7 @@ def _read_next_sentence(self, file) -> Optional[Sentence]: newline_reached = False while line: line = line.strip() - fields: List[str] = re.split("\t+", line) + fields: list[str] = re.split("\t+", line) # end of sentence if line == "": diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 154f2600be..294b41fac8 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -1,7 +1,8 @@ import inspect import logging from abc import abstractmethod -from typing import Any, Dict, Generic, List, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Generic, Union import torch from torch.nn import Parameter, ParameterList @@ -37,7 +38,7 @@ def embedding_length(self) -> int: def embedding_type(self) -> str: raise NotImplementedError - def embed(self, data_points: Union[DT, List[DT]]) -> List[DT]: + def embed(self, data_points: Union[DT, list[DT]]) -> list[DT]: """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings are non-static. @@ -55,10 +56,10 @@ def _everything_embedded(self, data_points: Sequence[DT]) -> bool: return all(self.name in data_point._embeddings for data_point in data_points) @abstractmethod - def _add_embeddings_internal(self, sentences: List[DT]): + def _add_embeddings_internal(self, sentences: list[DT]): """Private method for adding embeddings to all words in a list of sentences.""" - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of @@ -67,9 +68,6 @@ def get_names(self) -> List[str]: """ return [self.name] - def get_named_embeddings_dict(self) -> Dict: - return {self.name: self} - @staticmethod def get_instance_parameters(locals: dict) -> dict: class_definition = locals.get("__class__") @@ -84,14 +82,14 @@ def get_instance_parameters(locals: dict) -> dict: return instance_parameters @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": raise NotImplementedError - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: raise NotImplementedError @classmethod - def load_embedding(cls, params: Dict[str, Any]): + def load_embedding(cls, params: dict[str, Any]): state_dict = params.pop("state_dict", None) embedding = cls.from_params(params) @@ -155,7 +153,7 @@ def __init__(self, mixture_size: int, trainable: bool = False) -> None: requires_grad=trainable, ) - def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: + def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor: """Forward pass of scalar mix. 
Computes a weighted average of the ``tensors``. The input tensors an be any shape @@ -203,7 +201,7 @@ def _everything_embedded(self, data_points: Sequence[Sentence]) -> bool: return True -EMBEDDING_CLASSES: Dict[str, Type[Embeddings]] = {} +EMBEDDING_CLASSES: dict[str, type[Embeddings]] = {} def register_embeddings(*args): @@ -225,7 +223,7 @@ def _register(cls): return _register -def load_embeddings(params: Dict[str, Any]) -> Embeddings: +def load_embeddings(params: dict[str, Any]) -> Embeddings: cls_name = params.pop("__cls__") cls = EMBEDDING_CLASSES[cls_name] return cls.load_embedding(params) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index c1e73442e6..28867d889a 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast import torch from sklearn.feature_extraction.text import TfidfVectorizer @@ -67,7 +67,7 @@ def create_from_state(cls, **state): class DocumentPoolEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: Union[TokenEmbeddings, List[TokenEmbeddings]], + embeddings: Union[TokenEmbeddings, list[TokenEmbeddings]], fine_tune_mode: str = "none", pooling: str = "mean", ) -> None: @@ -114,7 +114,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates only if embeddings are non-static. @@ -146,18 +146,18 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence.set_embedding(self.name, pooled_embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass def extra_repr(self): return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}" @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentPoolEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentPoolEmbeddings": embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings return cls(embeddings=embeddings, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "pooling": self.pooling, "fine_tune_mode": self.fine_tune_mode, @@ -169,7 +169,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentTFIDFEmbeddings(DocumentEmbeddings): def __init__( self, - train_dataset: List[Sentence], + train_dataset: list[Sentence], vectorizer: Optional[TfidfVectorizer] = None, **vectorizer_params, ) -> None: @@ -203,7 +203,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences.""" # if only one sentence is passed, convert to list of sentence if isinstance(sentences, Sentence): @@ -215,14 +215,14 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): for sentence_id, sentence in enumerate(sentences): sentence.set_embedding(self.name, tfidf_vectors[sentence_id]) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass @classmethod - def from_params(cls, params: Dict[str, Any]) -> 
"DocumentTFIDFEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentTFIDFEmbeddings": return cls(train_dataset=[], vectorizer=params["vectorizer"]) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "vectorizer": self.vectorizer, } @@ -232,7 +232,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentRNNEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], hidden_size=128, rnn_layers=1, reproject_words: bool = True, @@ -317,7 +317,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update only if embeddings are non-static. @@ -332,7 +332,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): # embed words in the sentence self.embeddings.embed(sentences) - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] longest_token_sequence_in_batch: int = max(lengths) pre_allocated_zero_tensor = torch.zeros( @@ -341,7 +341,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): device=flair.device, ) - all_embs: List[torch.Tensor] = [] + all_embs: list[torch.Tensor] = [] for sentence in sentences: all_embs += [emb for token in sentence for emb in token.get_each_embedding()] nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) @@ -436,7 +436,7 @@ def to_params(self): return model_state @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentRNNEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentRNNEmbeddings": stacked_embeddings = load_embeddings(params["embeddings"]) assert isinstance(stacked_embeddings, StackedEmbeddings) return cls( @@ -484,7 +484,7 @@ def __setstate__(self, d): @register_embeddings class DocumentLMEmbeddings(DocumentEmbeddings): - def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None: + def __init__(self, flair_embeddings: list[FlairEmbeddings]) -> None: super().__init__() self.embeddings = flair_embeddings @@ -503,7 +503,7 @@ def __init__(self, flair_embeddings: List[FlairEmbeddings]) -> None: def embedding_length(self) -> int: return self._embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): for embedding in self.embeddings: embedding.embed(sentences) @@ -520,17 +520,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): return sentences - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: if "__names" not in self.__dict__: self.__names = [name for embedding in self.embeddings for name in embedding.get_names()] return self.__names - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return {"flair_embeddings": [embedding.save_embeddings(False) for embedding in self.embeddings]} @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentLMEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentLMEmbeddings": return cls([cast(FlairEmbeddings, load_embeddings(embedding)) for embedding in params["flair_embeddings"]]) @@ -566,7 +566,7 @@ def __init__( self.static_embeddings = True self.eval() - def _add_embeddings_internal(self, 
sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: sentence_batches = [ sentences[i * self.batch_size : (i + 1) * self.batch_size] for i in range((len(sentences) + self.batch_size - 1) // self.batch_size) @@ -577,7 +577,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: return sentences - def _add_embeddings_to_sentences(self, sentences: List[Sentence]): + def _add_embeddings_to_sentences(self, sentences: list[Sentence]): # convert to plain strings, embedded in a list for the encode function sentences_plain_text = [sentence.to_plain_string() for sentence in sentences] @@ -591,10 +591,10 @@ def embedding_length(self) -> int: return self.model.get_sentence_embedding_dimension() @classmethod - def from_params(cls, params: Dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "SentenceTransformerDocumentEmbeddings": return cls(**params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "model": self.model_name, "batch_size": self.batch_size, @@ -605,7 +605,7 @@ def to_params(self) -> Dict[str, Any]: class DocumentCNNEmbeddings(DocumentEmbeddings): def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], kernels=((100, 3), (100, 4), (100, 5)), reproject_words: bool = True, reproject_words_dimension: Optional[int] = None, @@ -673,7 +673,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update only if embeddings are non-static. 
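An illustrative sketch of the annotation style the hunks around this point standardize on: PEP 585 makes the built-in `list`, `dict`, `tuple`, `set` and `type` subscriptable from Python 3.9, so the `typing.List` / `typing.Dict` imports can be dropped. The helper names below are made up for the example and are not Flair code:

```python
from collections.abc import Sequence  # ABCs are imported from collections.abc; typing.Sequence is deprecated since 3.9
from typing import Optional           # Optional/Union still come from typing on 3.9


def count_tokens(sentences: Sequence[str]) -> dict[str, int]:
    # Built-in generics work directly in annotations on Python 3.9+,
    # no `from typing import List, Dict` required.
    counts: dict[str, int] = {}
    for sentence in sentences:
        for token in sentence.split():
            counts[token] = counts.get(token, 0) + 1
    return counts


def first_and_last(items: list[int]) -> Optional[tuple[int, int]]:
    # Returns None for an empty list instead of raising an IndexError.
    return (items[0], items[-1]) if items else None


if __name__ == "__main__":
    print(count_tokens(["a b a", "b c"]))  # {'a': 2, 'b': 2, 'c': 1}
    print(first_and_last([3, 1, 4]))       # (3, 4)
```

`Optional` and `Union` remain imported from `typing` throughout the diff because the `X | Y` union syntax only becomes available at runtime in Python 3.10 (PEP 604).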
@@ -689,7 +689,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): # embed words in the sentence self.embeddings.embed(sentences) - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] padding_length: int = max(max(lengths), self.min_sequence_length) pre_allocated_zero_tensor = torch.zeros( @@ -698,7 +698,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): device=flair.device, ) - all_embs: List[torch.Tensor] = [] + all_embs: list[torch.Tensor] = [] for sentence in sentences: all_embs += [emb for token in sentence for emb in token.get_each_embedding()] nb_padding_tokens = padding_length - len(sentence) @@ -757,11 +757,11 @@ def _apply(self, fn): child_module._apply(fn) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "DocumentCNNEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "DocumentCNNEmbeddings": embeddings = cast(StackedEmbeddings, load_embeddings(params.pop("embeddings"))).embeddings return cls(embeddings=embeddings, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "embeddings": self.embeddings.save_embeddings(False), "kernels": self.kernels, diff --git a/flair/embeddings/image.py b/flair/embeddings/image.py index df6d1fadd9..5d79a04390 100644 --- a/flair/embeddings/image.py +++ b/flair/embeddings/image.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch import torch.nn.functional as F @@ -29,12 +29,12 @@ class ImageEmbeddings(Embeddings[Image]): def embedding_type(self) -> str: return "image-level" - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: # legacy pickle-like saving for image embeddings, as implementation details are not obvious return self.__getstate__() @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": # legacy pickle-like loading for image embeddings, as implementation details are not obvious embedding = cls.__new__(cls) embedding.__setstate__(params) @@ -53,7 +53,7 @@ def __init__(self, transforms) -> None: self.static_embeddings = True super().__init__() - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): for image in images: image_data = self.PIL.Image.open(image.imageURL) image_data.load() @@ -77,7 +77,7 @@ def __init__(self, url2tensor_dict, name) -> None: self.static_embeddings = True super().__init__() - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): for image in images: if image.imageURL in self.url2tensor_dict: image.set_embedding(self.name, self.url2tensor_dict[image.imageURL]) @@ -137,7 +137,7 @@ def __init__(self, name, pretrained=True, transforms=None) -> None: else: raise Exception(f"Image embeddings {name} not available.") - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): image_tensor = torch.stack([self.transforms(image.data) for image in images]) image_embeddings = self.features(image_tensor) image_embeddings = ( @@ -163,7 +163,7 @@ def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms) -> adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d} - convnet_arch: List[Any] = [] if convnet_parms["dropout"][0] <= 0 else 
[Dropout2d(convnet_parms["dropout"][0])] + convnet_arch: list[Any] = [] if convnet_parms["dropout"][0] <= 0 else [Dropout2d(convnet_parms["dropout"][0])] convnet_arch.extend( [ Conv2d( @@ -266,7 +266,7 @@ def forward(self, x): return x - def _add_embeddings_internal(self, images: List[Image]): + def _add_embeddings_internal(self, images: list[Image]): image_tensor = torch.stack([image.data for image in images]) image_embeddings = self.forward(image_tensor) for image_id, image in enumerate(images): diff --git a/flair/embeddings/legacy.py b/flair/embeddings/legacy.py index b2658e2d2f..4b3d2a9517 100644 --- a/flair/embeddings/legacy.py +++ b/flair/embeddings/legacy.py @@ -1,7 +1,7 @@ import logging import re from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import torch from deprecated.sphinx import deprecated @@ -110,12 +110,12 @@ def use_layers_top(self, x): def use_layers_average(self, x): return torch.mean(torch.stack(x), 0) - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn if not getattr(self, "embedding_mode_fn", None): self.embedding_mode_fn = self.use_layers_all - sentence_words: List[List[str]] = [] + sentence_words: list[list[str]] = [] for sentence in sentences: sentence_words.append([token.text for token in sentence]) @@ -394,7 +394,7 @@ def __getstate__(self): def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # if cache is used, try setting embeddings from cache first if "cache" in self.__dict__ and self.cache is not None: # try populating embeddings from cache @@ -463,7 +463,7 @@ class DocumentMeanEmbeddings(DocumentEmbeddings): version="0.3.1", reason="The functionality of this class is moved to 'DocumentPoolEmbeddings'", ) - def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None: + def __init__(self, token_embeddings: list[TokenEmbeddings]) -> None: """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -478,7 +478,7 @@ def __init__(self, token_embeddings: List[TokenEmbeddings]) -> None: def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates only if embeddings are non-static. """ @@ -506,7 +506,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence.set_embedding(self.name, mean_embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass @@ -517,7 +517,7 @@ class DocumentLSTMEmbeddings(DocumentEmbeddings): ) def __init__( self, - embeddings: List[TokenEmbeddings], + embeddings: list[TokenEmbeddings], hidden_size=128, rnn_layers=1, reproject_words: bool = True, @@ -587,7 +587,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def embed(self, sentences: Union[List[Sentence], Sentence]): + def embed(self, sentences: Union[list[Sentence], Sentence]): """Add embeddings to all sentences in the given list of sentences. 
If embeddings are already added, update only if embeddings are non-static. """ @@ -604,7 +604,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): longest_token_sequence_in_batch: int = len(sentences[0]) all_sentence_tensors = [] - lengths: List[int] = [] + lengths: list[int] = [] # go through each sentence in batch for _i, sentence in enumerate(sentences): @@ -669,5 +669,5 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentence = sentences[sentence_no] sentence.set_embedding(self.name, embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): pass diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index b068305800..700eaf4c45 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -4,7 +4,7 @@ import tempfile from collections import Counter from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -64,7 +64,7 @@ def create_from_state(cls, **state): class StackedEmbeddings(TokenEmbeddings): """A stack of embeddings, used if you need to combine several different embedding types.""" - def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = True) -> None: + def __init__(self, embeddings: list[TokenEmbeddings], overwrite_names: bool = True) -> None: """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -88,7 +88,7 @@ def __init__(self, embeddings: List[TokenEmbeddings], overwrite_names: bool = Tr self.__embedding_length += embedding.embedding_length self.eval() - def embed(self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True): + def embed(self, sentences: Union[Sentence, list[Sentence]], static_embeddings: bool = True): # if only one sentence is passed, convert to list of sentence if type(sentences) is Sentence: sentences = [sentences] @@ -104,7 +104,7 @@ def embedding_type(self) -> str: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for embedding in self.embeddings: embedding._add_embeddings_internal(sentences) @@ -113,7 +113,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: def __str__(self) -> str: return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]' - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of this embedding. 
But in some cases, the @@ -126,13 +126,6 @@ def get_names(self) -> List[str]: return self.__names - def get_named_embeddings_dict(self) -> Dict: - named_embeddings_dict = {} - for embedding in self.embeddings: - named_embeddings_dict.update(embedding.get_named_embeddings_dict()) - - return named_embeddings_dict - @classmethod def from_params(cls, params): embeddings = [load_embeddings(p) for p in params["embeddings"]] @@ -154,7 +147,7 @@ def __init__( force_cpu: bool = True, stable: bool = False, no_header: bool = False, - vocab: Optional[Dict[str, int]] = None, + vocab: Optional[dict[str, int]] = None, embedding_length: Optional[int] = None, name: Optional[str] = None, ) -> None: @@ -334,10 +327,10 @@ def get_cached_token_index(self, word: str) -> int: else: return len(self.vocab) # token - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [token for sentence in sentences for token in sentence.tokens] - word_indices: List[int] = [] + word_indices: list[int] = [] for token in tokens: word = token.text if self.field is None else token.get_label(self.field).value word_indices.append(self.get_cached_token_index(word)) @@ -386,7 +379,7 @@ def __getattribute__(self, item): return None return super().__getattribute__(item) - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: dict[str, Any]): state.pop("get_cached_vec", None) state.setdefault("embeddings", state["name"]) state.setdefault("force_cpu", True) @@ -416,10 +409,10 @@ def __setstate__(self, state: Dict[str, Any]): super().__setstate__(state) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings": return cls(embeddings=None, **params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "vocab": self.vocab, "stable": self.stable, @@ -487,7 +480,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): for sentence in sentences: tokens_char_indices = [] @@ -544,10 +537,10 @@ def __str__(self) -> str: return self.name @classmethod - def from_params(cls, params: Dict[str, Any]) -> "CharacterEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "CharacterEmbeddings": return cls(**params) - def to_params(self) -> Dict[str, Any]: + def to_params(self) -> dict[str, Any]: return { "path_to_char_dict": self.char_dictionary, "char_embedding_dim": self.char_embedding_dim, @@ -793,7 +786,7 @@ def train(self, mode=True): def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: # gradients are enable if fine-tuning is enabled gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad() @@ -885,7 +878,7 @@ def from_params(cls, params): lm = LanguageModel(**model_params) return cls(lm, **params) - def __setstate__(self, d: Dict[str, Any]): + def __setstate__(self, d: dict[str, Any]): # make compatible with old models d.setdefault("fine_tune", False) d.setdefault("chars_per_chunk", 512) @@ -920,8 +913,8 @@ def __init__( self.name = self.context_embeddings.name + "-context" # these fields are for the embedding memory - 
self.word_embeddings: Dict[str, torch.Tensor] = {} - self.word_count: Dict[str, int] = {} + self.word_embeddings: dict[str, torch.Tensor] = {} + self.word_count: dict[str, int] = {} # whether to add only capitalized words to memory (faster runtime and lower memory consumption) self.only_capitalized = only_capitalized @@ -940,7 +933,7 @@ def train(self, mode=True): self.word_embeddings = {} self.word_count = {} - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: self.context_embeddings.embed(sentences) # if we keep a pooling, it needs to be updated continuously @@ -989,10 +982,10 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: def embedding_length(self) -> int: return self.__embedding_length - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: return [self.name, self.context_embeddings.name] - def __setstate__(self, d: Dict[str, Any]): + def __setstate__(self, d: dict[str, Any]): super().__setstate__(d) if flair.device.type != "cpu": @@ -1073,7 +1066,7 @@ def get_cached_vec(self, word: str) -> torch.Tensor: word_embedding = torch.tensor(word_embedding.tolist(), device=flair.device, dtype=torch.float) return word_embedding - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for sentence in sentences: for token in sentence.tokens: word = token.text if self.field is None else token.get_label(self.field).value @@ -1152,7 +1145,7 @@ def __init__( def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [t for sentence in sentences for t in sentence.tokens] if self.field == "text": @@ -1240,7 +1233,7 @@ def num_embeddings(self) -> int: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): def get_idx_for_item(text): hash_function = hashlib.new(self.__hash_method) hash_function.update(bytes(str(text), "utf-8")) @@ -1282,7 +1275,7 @@ def __init__( self.name: str = "muse-crosslingual" self.static_embeddings = True self.__embedding_length: int = 300 - self.language_embeddings: Dict[str, Any] = {} + self.language_embeddings: dict[str, Any] = {} (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors") self.kv = KeyedVectors super().__init__() @@ -1304,7 +1297,7 @@ def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor: word_embedding = torch.tensor(word_embedding, device=flair.device, dtype=torch.float) return word_embedding - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: for _i, sentence in enumerate(sentences): language_code = sentence.get_language_code() supported = [ @@ -1465,10 +1458,10 @@ def _preprocess(self, text: str) -> str: def embedding_length(self) -> int: return self.__embedding_length - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: + def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: tokens = [token for sentence in sentences for token in sentence.tokens] - word_indices: 
List[List[int]] = [] + word_indices: list[list[int]] = [] for token in tokens: word = token.text if self.field is None else token.get_label(self.field).value @@ -1601,13 +1594,13 @@ def __init__(self, embeddings: str, model: str = "skip", size: int = 100) -> Non else: embeddings_path = embeddings - log.info("Reading embeddings from %s" % embeddings_path) + log.info("Reading embeddings from %s", embeddings_path) super().__init__( embeddings=str(extract_single_zip_file(embeddings_path, cache_dir=cache_dir)), name="NILC-" + embeddings ) @classmethod - def from_params(cls, params: Dict[str, Any]) -> "WordEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "WordEmbeddings": # no need to recreate as NILCEmbeddings return WordEmbeddings(embeddings=None, **params) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 1e88787deb..d09ed33699 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -8,7 +8,7 @@ from abc import abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast +from typing import Any, Literal, Optional, Union, cast import torch import transformers @@ -44,7 +44,7 @@ @torch.jit.script_if_tracing -def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor: +def pad_sequence_embeddings(all_hidden_states: list[torch.Tensor]) -> torch.Tensor: embedding_length = all_hidden_states[0].shape[1] longest_token_sequence_in_batch = 0 for hidden_states in all_hidden_states: @@ -218,13 +218,13 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: def _legacy_reconstruct_word_ids( - embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]] -) -> List[List[Optional[int]]]: + embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]] +) -> list[list[Optional[int]]]: word_ids_list = [] max_len = 0 for tokens in flair_tokens: token_texts = embedding.tokenizer.tokenize(" ".join(tokens), is_split_into_words=True) - token_ids = cast(List[int], embedding.tokenizer.convert_tokens_to_ids(token_texts)) + token_ids = cast(list[int], embedding.tokenizer.convert_tokens_to_ids(token_texts)) expanded_token_ids = embedding.tokenizer.build_inputs_with_special_tokens(token_ids) j = 0 for _i, token_id in enumerate(token_ids): @@ -264,10 +264,10 @@ def _get_processed_token_text(tokenizer, token: str) -> str: return token_text.strip() -def _reconstruct_word_ids_from_subtokens(embedding, tokens: List[str], subtokens: List[str]): +def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens: list[str]): word_iterator = iter(enumerate(_get_processed_token_text(embedding.tokenizer, token) for token in tokens)) token_id, token_text = next(word_iterator) - word_ids: List[Optional[int]] = [] + word_ids: list[Optional[int]] = [] reconstructed_token = "" subtoken_count = 0 processed_first_token = False @@ -504,10 +504,10 @@ def embedding_type(self) -> str: return "word-level" if self.token_embedding else "sentence-level" @abstractmethod - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: return self(**tensors) - def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.device] = None): + def prepare_tensors(self, sentences: list[Sentence], device: Optional[torch.device] = None): if device is None: device = flair.device flair_tokens, offsets, lengths = 
self.__gather_flair_tokens(sentences) @@ -535,13 +535,13 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi def __build_transformer_model_inputs( self, - sentences: List[Sentence], - offsets: List[int], - sentence_lengths: List[int], - flair_tokens: List[List[Token]], + sentences: list[Sentence], + offsets: list[int], + sentence_lengths: list[int], + flair_tokens: list[list[Token]], device: torch.device, ): - tokenizer_kwargs: Dict[str, Any] = {} + tokenizer_kwargs: dict[str, Any] = {} if self.tokenizer_needs_ocr_boxes: tokenizer_kwargs["boxes"] = [[t.get_metadata("bbox") for t in tokens] for tokens in flair_tokens] else: @@ -662,7 +662,7 @@ def __build_transformer_model_inputs( return model_kwargs - def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[Token]], List[int], List[int]]: + def __gather_flair_tokens(self, sentences: list[Sentence]) -> tuple[list[list[Token]], list[int], list[int]]: offsets = [] lengths = [] if self.context_length > 0: @@ -686,7 +686,7 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To lengths.append(len(sentence)) return sentence_tokens, offsets, lengths - def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]: + def _expand_sentence_with_context(self, sentence) -> tuple[list[Token], int]: # fields to store left and right context left_context = [] right_context = [] @@ -722,7 +722,7 @@ def __extract_token_embeddings(self, sentence_embeddings, sentences): for token_embedding, token in zip(token_embeddings, sentence): token.set_embedding(self.name, token_embedding) - def _add_embeddings_internal(self, sentences: List[Sentence]): + def _add_embeddings_internal(self, sentences: list[Sentence]): tensors = self.prepare_tensors(sentences, device=self.force_device) gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() with gradient_context: @@ -739,7 +739,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): @register_embeddings class TransformerOnnxEmbeddings(TransformerBaseEmbeddings): - def __init__(self, onnx_model: str, providers: List = [], session_options: Optional[Dict] = None, **kwargs) -> None: + def __init__(self, onnx_model: str, providers: list = [], session_options: Optional[dict] = None, **kwargs) -> None: # onnx prepares numpy arrays, no mather if it runs on gpu or cpu, the input is on cpu first. 
super().__init__(**kwargs, force_device=torch.device("cpu")) self.onnx_model = onnx_model @@ -756,7 +756,7 @@ def to_params(self): return params @classmethod - def from_params(cls, params: Dict[str, Any]) -> "TransformerOnnxEmbeddings": + def from_params(cls, params: dict[str, Any]) -> "TransformerOnnxEmbeddings": params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data")) params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None)) return cls(**params) @@ -812,7 +812,7 @@ def quantize_model(self, quantize_model_path, use_external_data_format: bool = F self.onnx_model = quantize_model_path self.create_session() - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: input_array = {k: v.numpy() for k, v in tensors.items()} embeddings = self.session.run([], input_array) @@ -854,9 +854,9 @@ def export_from_embedding( cls, path: Union[str, Path], embedding: "TransformerEmbeddings", - example_sentences: List[Sentence], + example_sentences: list[Sentence], opset_version: int = 14, - providers: Optional[List] = None, + providers: Optional[list] = None, session_options: Optional[dict] = None, ): path = str(path) @@ -903,7 +903,7 @@ def export_from_embedding( @register_embeddings class TransformerJitEmbeddings(TransformerBaseEmbeddings): - def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: List[str], **kwargs) -> None: + def __init__(self, jit_model: Union[bytes, ScriptModule], param_names: list[str], **kwargs) -> None: super().__init__(**kwargs) if isinstance(jit_model, bytes): buffer = BytesIO(jit_model) @@ -925,12 +925,12 @@ def to_params(self): return state @classmethod - def from_params(cls, params: Dict[str, Any]) -> "Embeddings": + def from_params(cls, params: dict[str, Any]) -> "Embeddings": params["tokenizer"] = cls._tokenizer_from_bytes(params.pop("tokenizer_data")) params["feature_extractor"] = cls._feature_extractor_from_bytes(params.pop("feature_extractor_data", None)) return cls(**params) - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: parameters = [] for param in self.param_names: parameters.append(tensors[param]) @@ -945,13 +945,13 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: raise ValueError("either 'token_embedding' or 'document_embedding' needs to be set.") @classmethod - def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: List[str]): + def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbeddings", param_names: list[str]): return cls(jit_model=module, param_names=param_names, **embedding.to_args()) @classmethod def parameter_to_list( - cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence] - ) -> Tuple[List[str], List[torch.Tensor]]: + cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: list[Sentence] + ) -> tuple[list[str], list[torch.Tensor]]: tensors = embedding.prepare_tensors(sentences) param_names = list(inspect.signature(wrapper.forward).parameters.keys()) params = [] @@ -998,7 +998,7 @@ def __init__( @register_embeddings class TransformerEmbeddings(TransformerBaseEmbeddings): - onnx_cls: Type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings + onnx_cls: type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings def __init__( self, @@ -1021,11 +1021,11 @@ def 
__init__( force_max_length: bool = False, needs_manual_ocr: Optional[bool] = None, use_context_separator: bool = True, - transformers_tokenizer_kwargs: Dict[str, Any] = {}, - transformers_config_kwargs: Dict[str, Any] = {}, - transformers_model_kwargs: Dict[str, Any] = {}, + transformers_tokenizer_kwargs: dict[str, Any] = {}, + transformers_config_kwargs: dict[str, Any] = {}, + transformers_model_kwargs: dict[str, Any] = {}, peft_config=None, - peft_gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = {}, + peft_gradient_checkpointing_kwargs: Optional[dict[str, Any]] = {}, **kwargs, ) -> None: """Instantiate transformers embeddings. @@ -1503,11 +1503,11 @@ def forward( result["token_embeddings"] = all_token_embeddings return result - def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]: + def _forward_tensors(self, tensors) -> dict[str, torch.Tensor]: return self.forward(**tensors) def export_onnx( - self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs + self, path: Union[str, Path], example_sentences: list[Sentence], **kwargs ) -> TransformerOnnxEmbeddings: """Export TransformerEmbeddings to OnnxFormat. diff --git a/flair/file_utils.py b/flair/file_utils.py index 7a0118822b..518d69e809 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -12,8 +12,9 @@ import typing import warnings import zipfile +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence, Tuple, Union, cast +from typing import Optional, Union, cast from urllib.parse import urlparse import boto3 @@ -28,10 +29,10 @@ logger = logging.getLogger("flair") -url_proxies: Optional[typing.Dict[str, str]] = None +url_proxies: Optional[dict[str, str]] = None -def set_proxies(proxies: typing.Dict[str, str]) -> None: +def set_proxies(proxies: dict[str, str]) -> None: r"""Allows for data downloaded from urls to be forwarded to a proxy. see https://requests.readthedocs.io/en/latest/user/advanced/#proxies @@ -74,7 +75,7 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str: return decoded -def filename_to_url(filename: str) -> Tuple[str, Optional[str]]: +def filename_to_url(filename: str) -> tuple[str, Optional[str]]: """Recovers the the url from the encoded filename. Returns it and the ETag (which may be ``None``) @@ -374,7 +375,7 @@ def create_cache(self, *args, **kwargs): return decorator -def load_torch_state(model_file: str) -> typing.Dict[str, typing.Any]: +def load_torch_state(model_file: str) -> dict[str, typing.Any]: with warnings.catch_warnings(): warnings.filterwarnings("ignore") # load_big_file is a workaround byhttps://github.com/highway11git diff --git a/flair/inference_utils.py b/flair/inference_utils.py index 0310671534..c811bf39a1 100644 --- a/flair/inference_utils.py +++ b/flair/inference_utils.py @@ -126,7 +126,7 @@ def create_stores(model, backend="sqlite"): Also deletes the original vectors to save memory. 
""" for embedding in WordEmbeddingsStore._word_embeddings(model): - if type(embedding) == WordEmbeddings: + if isinstance(embedding, WordEmbeddings): WordEmbeddingsStore(embedding, backend) del embedding.precomputed_word_embeddings @@ -135,7 +135,7 @@ def load_stores(model, backend="sqlite"): """Loads the db versions of all word embeddings in the model.""" embeds = WordEmbeddingsStore._word_embeddings(model) for i, embedding in enumerate(embeds): - if type(embedding) == WordEmbeddings: + if isinstance(embedding, WordEmbeddings): embeds[i] = WordEmbeddingsStore(embedding, backend) @staticmethod diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 9f516a703c..0f1c916bfb 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -2,7 +2,7 @@ import re from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union, cast +from typing import Any, Callable, Optional, Union, cast from unicodedata import category import torch @@ -19,9 +19,9 @@ class CandidateGenerator: """Given a string, the CandidateGenerator returns possible target classes as candidates.""" - def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = True) -> None: + def __init__(self, candidates: Union[str, dict[str, list[str]]], backoff: bool = True) -> None: # internal candidate lists of generator - self.mention_to_candidates_map: Dict = {} + self.mention_to_candidates_map: dict[str, list[str]] = {} # load Zelda candidates if so passed if isinstance(candidates, str) and candidates.lower() == "zelda": @@ -39,16 +39,15 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = self.mention_to_candidates_map = candidate_lists - elif isinstance(candidates, Dict): + elif isinstance(candidates, dict): self.mention_to_candidates_map = candidates else: raise ValueError(f"'{candidates}' could not be loaded.") - self.mention_to_candidates_map = cast(Dict[str, List[str]], self.mention_to_candidates_map) # if lower casing is enabled, create candidate lists of lower cased versions self.backoff = backoff if self.backoff: # create a new dictionary for lower cased mentions - lowercased_mention_to_candidates_map: Dict = {} + lowercased_mention_to_candidates_map: dict[str, list[str]] = {} # go through each mention and its candidates for mention, candidates_list in self.mention_to_candidates_map.items(): @@ -56,8 +55,8 @@ def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = # check if backoff mention already seen. If so, add candidates. Else, create new entry. 
if backoff_mention in lowercased_mention_to_candidates_map: current_candidates = lowercased_mention_to_candidates_map[backoff_mention] - lowercased_mention_to_candidates_map[backoff_mention] = set(current_candidates).union( - candidates_list + lowercased_mention_to_candidates_map[backoff_mention] = list( + set(current_candidates).union(candidates_list) ) else: lowercased_mention_to_candidates_map[backoff_mention] = candidates_list @@ -72,7 +71,7 @@ def _make_backoff_string(self, mention: str) -> str: backoff_mention = re.sub(" +", " ", backoff_mention) return backoff_mention - def get_candidates(self, mention: str) -> Set[str]: + def get_candidates(self, mention: str) -> set[str]: """Given a mention, this method returns a set of candidate classes.""" if self.backoff: mention = self._make_backoff_string(mention) @@ -125,7 +124,7 @@ def __init__( self._label_type = label_type self._span_label_type = span_label_type - cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = { + cases: dict[str, Callable[[Span, list[str]], torch.Tensor]] = { "average": self.emb_mean, "first": self.emb_first, "last": self.emb_last, @@ -155,7 +154,7 @@ def emb_firstAndLast(self, span: Span, embedding_names): def emb_mean(self, span, embedding_names): return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Span]: if self._span_label_type is not None: spans = sentence.get_spans(self._span_label_type) # only use span label type if there are predictions, otherwise search for output label type (training labels) @@ -223,7 +222,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def label_type(self): return self._label_type - def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]): + def _mask_scores(self, scores: torch.Tensor, data_points: list[Span]): if not self.candidates: return scores @@ -242,9 +241,7 @@ def _mask_scores(self, scores: torch.Tensor, data_points: List[Span]): return masked_scores @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SpanClassifier": - from typing import cast - + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SpanClassifier": return cast("SpanClassifier", super().load(model_path=model_path)) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index cecf2d9e57..5a1382dd60 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -4,9 +4,10 @@ import re import string from abc import ABC, abstractmethod +from collections.abc import Sequence from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from typing import Any, Optional, Union, cast import numpy as np import torch @@ -89,7 +90,7 @@ "chemical": "ctd-chemicals", } -BIOMEDICAL_DICTIONARIES: Dict[str, Type] = { +BIOMEDICAL_DICTIONARIES: dict[str, type] = { "ctd-diseases": CTD_DISEASES_DICTIONARY, "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, @@ -151,7 +152,7 @@ def load_dictionary( class EntityPreprocessor(ABC): """A pre-processor used to transform / clean both entity mentions and entity names.""" - def initialize(self, sentences: List[Sentence]) -> None: + def initialize(self, sentences: list[Sentence]) -> None: """Initializes the pre-processor for a batch of sentences. 
This may be necessary for more sophisticated transformations. @@ -187,14 +188,14 @@ def process_entity_name(self, entity_name: str) -> str: """ @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": + def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor": if inspect.isabstract(cls): cls_name = state_dict.pop("__cls__", None) return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict) else: return cls(**state_dict) - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return {"__cls__": self.__class__.__name__} @@ -237,7 +238,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "lowercase": self.lowercase, @@ -270,9 +271,9 @@ def __init__( self.ab3p = pyab3p.Ab3p() self.preprocessor = preprocessor - self.abbreviation_dict: Dict[str, Dict[str, str]] = {} + self.abbreviation_dict: dict[str, dict[str, str]] = {} - def initialize(self, sentences: List[Sentence]) -> None: + def initialize(self, sentences: list[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = None) -> str: @@ -303,7 +304,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: + def _build_abbreviation_dict(self, sentences: list[flair.data.Sentence]) -> dict[str, dict[str, str]]: """Processes the given sentences with the Ab3P tool. The function returns a (nested) dictionary containing the abbreviations found for each sentence, e.g.: @@ -321,7 +322,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict Returns: abbreviation_dict: abbreviations and their resolution detected in each input sentence """ - abbreviation_dict: Dict[str, Dict[str, str]] = {} + abbreviation_dict: dict[str, dict[str, str]] = {} for sentence in sentences: sentence_text = sentence.to_original_text() @@ -331,14 +332,14 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict return abbreviation_dict - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "preprocessor": None if self.preprocessor is None else self.preprocessor._get_state(), } @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": + def _from_state(cls, state_dict: dict[str, Any]) -> "EntityPreprocessor": return cls( preprocessor=( None @@ -364,7 +365,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. 
Args: @@ -376,14 +377,14 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, """ @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex": if inspect.isabstract(cls): cls_name = state_dict.pop("__cls__", None) return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict) else: return cls(**state_dict) - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return {"__cls__": self.__class__.__name__} @@ -396,7 +397,7 @@ def __init__(self): Args: name_to_id_index: internal state, should only be set when loading an initialized index. """ - self.name_to_id_index: Dict[str, str] = {} + self.name_to_id_index: dict[str, str] = {} def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: def p(text: str) -> str: @@ -407,8 +408,8 @@ def p(text: str) -> str: for synonym in candidate.synonyms: self.name_to_id_index[p(synonym)] = candidate.concept_id - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: - results: List[List[Tuple[str, float]]] = [] + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: + results: list[list[tuple[str, float]]] = [] for mention in entity_mentions: dict_entry = self.name_to_id_index.get(mention) if dict_entry is None: @@ -419,12 +420,12 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, return results @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "CandidateSearchIndex": index = cls() index.name_to_id_index = state_dict["name_to_id_index"] return index - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "name_to_id_index": self.name_to_id_index, @@ -436,7 +437,7 @@ class SemanticCandidateSearchIndex(CandidateSearchIndex): def __init__( self, - embeddings: Dict[str, DocumentEmbeddings], + embeddings: dict[str, DocumentEmbeddings], hybrid_search: bool, similarity_metric: SimilarityMetric = SimilarityMetric.INNER_PRODUCT, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, @@ -460,8 +461,8 @@ def __init__( self.show_progress = show_progress self.batch_size = batch_size - self.ids: List[str] = [] - self._precomputed_embeddings: Dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])} + self.ids: list[str] = [] + self._precomputed_embeddings: dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])} @classmethod def bi_encoder( @@ -479,7 +480,7 @@ def bi_encoder( if model_name_or_path in PRETRAINED_MODELS: similarity_metric = PRETRAINED_MODEL_TO_SIMILARITY_METRIC[model_name_or_path] - embeddings: Dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)} + embeddings: dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)} if hybrid_search: if dictionary is None: @@ -515,7 +516,7 @@ def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[Enti def p(text: str) -> str: return preprocessor.process_entity_name(text) if preprocessor is not None else text - texts: List[str] = [] + texts: list[str] = [] self.ids = [] for candidate in dictionary.candidates: texts.append(p(candidate.concept_name)) @@ -564,8 +565,8 @@ def p(text: str) -> str: sent.clear_embeddings() 
self._precomputed_embeddings["sparse"] = np.stack(sparse_embs, axis=0) - def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: - query_embeddings: Dict[str, List] = {"dense": []} + def embed(self, entity_mentions: list[str]) -> dict[str, np.ndarray]: + query_embeddings: dict[str, list[np.ndarray]] = {"dense": []} inputs = [Sentence(name) for name in entity_mentions] @@ -600,7 +601,7 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: return {k: np.stack(v, axis=0) for k, v in query_embeddings.items()} - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + def search(self, entity_mentions: list[str], top_k: int) -> list[list[tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. Args: @@ -634,10 +635,10 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, return results @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchIndex": + def _from_state(cls, state_dict: dict[str, Any]) -> "SemanticCandidateSearchIndex": index = cls( embeddings=cast( - Dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()} + dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()} ), similarity_metric=SimilarityMetric(state_dict["similarity_metric"]), sparse_weight=state_dict["sparse_weight"], @@ -649,7 +650,7 @@ def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchInde index._precomputed_embeddings = state_dict["precomputed_embeddings"] return index - def _get_state(self) -> Dict[str, Any]: + def _get_state(self) -> dict[str, Any]: return { **super()._get_state(), "embeddings": {k: emb.save_embeddings() for k, emb in self.embeddings.items()}, @@ -670,7 +671,7 @@ def __init__( self, candidate_generator: CandidateSearchIndex, preprocessor: EntityPreprocessor, - entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]], + entity_label_types: Union[str, Sequence[str], dict[str, set[str]]], label_type: str, dictionary: EntityLinkingDictionary, batch_size: int = 1024, @@ -698,8 +699,8 @@ def __init__( super().__init__() def get_entity_label_types( - self, entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]] - ) -> Dict[str, Set[str]]: + self, entity_label_types: Union[str, Sequence[str], dict[str, set[str]]] + ) -> dict[str, set[str]]: """Find out what NER labels to extract from sentence. 
Args: @@ -709,9 +710,9 @@ def get_entity_label_types( To use all labels from 'ner', pass 'ner' """ if isinstance(entity_label_types, str): - entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}}) + entity_label_types = cast(dict[str, set[str]], {entity_label_types: {}}) elif isinstance(entity_label_types, Sequence): - entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types}) + entity_label_types = cast(dict[str, set[str]], {label: {} for label in entity_label_types}) entity_label_types = { label: {normalize_entity_type(e) for e in entity_types} @@ -728,9 +729,9 @@ def label_type(self): def dictionary(self) -> EntityLinkingDictionary: return self._dictionary - def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict[str, Set[str]]) -> List[Label]: + def extract_entities_mentions(self, sentence: Sentence, entity_label_types: dict[str, set[str]]) -> list[Label]: """Extract tagged mentions from sentences.""" - entities_mentions: List[Label] = [] + entities_mentions: list[Label] = [] # NOTE: This is a hacky workaround for the fact that # the `label_type`s in `Classifier.load('hunflair)` are @@ -762,10 +763,10 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], top_k: int = 1, pred_label_type: Optional[str] = None, - entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Set[str]]]] = None, + entity_label_types: Optional[Union[str, Sequence[str], dict[str, set[str]]]] = None, batch_size: Optional[int] = None, ) -> None: """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. 
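The `get_entity_label_types` hunk above accepts the label types as a plain string, a sequence of strings, or an explicit mapping, and normalizes everything to a single `dict[str, set[str]]` shape. A small stand-alone sketch of that normalization pattern, with hypothetical names and without the entity-type normalization the real method also performs:

```python
from collections.abc import Sequence
from typing import Union


def normalize_label_types(
    label_types: Union[str, Sequence[str], dict[str, set[str]]],
) -> dict[str, set[str]]:
    """Accept 'ner', ['ner', 'nel'] or {'ner': {'PER', 'ORG'}} and return one uniform mapping.

    An empty set means "use every label of that layer".
    """
    if isinstance(label_types, str):
        return {label_types: set()}
    # str is itself a Sequence, so the string case must be handled first
    if isinstance(label_types, Sequence):
        return {label_type: set() for label_type in label_types}
    return {layer: set(labels) for layer, labels in label_types.items()}


print(normalize_label_types("ner"))                    # {'ner': set()}
print(normalize_label_types(["ner", "nel"]))           # {'ner': set(), 'nel': set()}
print(normalize_label_types({"ner": {"PER", "ORG"}}))  # {'ner': {'PER', 'ORG'}}
```

Checking against `collections.abc.Sequence` rather than `typing.Sequence` also keeps `isinstance` checks like the one in the hunk valid on Python 3.9+.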
@@ -859,7 +860,7 @@ def _fetch_model(model_name: str) -> str: return hf_download(model_name) @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs) -> "EntityMentionLinker": candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) preprocessor = EntityPreprocessor._from_state(state["entity_preprocessor"]) entity_label_types = state["entity_label_types"] @@ -961,7 +962,7 @@ def __get_model_path_and_entity_type( model_name_or_path: str, entity_type: Optional[str] = None, hybrid_search: bool = False, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Try to figure out what model the user wants.""" if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( @@ -1039,24 +1040,24 @@ def __get_dictionary_path( return dictionary_name_or_path - def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: raise NotImplementedError("The EntityLinker cannot be trained") @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "EntityMentionLinker": from typing import cast return cast("EntityMentionLinker", super().load(model_path=model_path)) def evaluate( self, - data_points: Union[List[Sentence], Dataset], + data_points: Union[list[Sentence], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: str = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("accuracy", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("accuracy", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, k: int = 1, diff --git a/flair/models/language_model.py b/flair/models/language_model.py index ed417f2434..d85db2fb93 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import logsumexp, nn @@ -111,7 +111,7 @@ def init_hidden(self, bsz): def get_representation( self, - strings: List[str], + strings: list[str], start_marker: str, end_marker: str, chars_per_chunk: int = 512, @@ -119,7 +119,7 @@ def get_representation( len_longest_str: int = len(max(strings, key=len)) # pad strings with whitespaces to longest sentence - padded_strings: List[str] = [] + padded_strings: list[str] = [] for string in strings: if not self.is_forward_lm: @@ -141,11 +141,11 @@ def get_representation( padding_char_index = self.dictionary.get_idx_for_item(" ") - batches: List[torch.Tensor] = [] + batches: list[torch.Tensor] = [] # push each chunk through the RNN language model for chunk in chunks: len_longest_chunk: int = len(max(chunk, key=len)) - sequences_as_char_indices: List[List[int]] = [] + sequences_as_char_indices: list[list[int]] = [] for string in chunk: char_indices = self.dictionary.get_idx_for_items(list(string)) char_indices += [padding_char_index] * (len_longest_chunk - len(string)) @@ -176,7 +176,7 @@ def get_output(self, text: str): def repackage_hidden(self, h): """Wraps hidden states in new Variables, to detach them from their history.""" - if 
type(h) == torch.Tensor: + if isinstance(h, torch.Tensor): return h.clone().detach() else: return tuple(self.repackage_hidden(v) for v in h) @@ -296,7 +296,7 @@ def generate_text( number_of_characters: int = 1000, temperature: float = 1.0, break_on_suffix=None, - ) -> Tuple[str, float]: + ) -> tuple[str, float]: if prefix == "": prefix = "\n" diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py index 6f0854d4b5..55fa34698c 100644 --- a/flair/models/lemmatizer_model.py +++ b/flair/models/lemmatizer_model.py @@ -1,6 +1,6 @@ import logging from math import inf -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch from torch import nn @@ -159,7 +159,7 @@ def label_type(self): def words_to_char_indices( self, - tokens: List[str], + tokens: list[str], end_symbol=True, start_symbol=False, padding_in_front=False, @@ -202,7 +202,7 @@ def words_to_char_indices( return tensor - def forward_pass(self, sentences: Union[List[Sentence], Sentence]): + def forward_pass(self, sentences: Union[list[Sentence], Sentence]): if isinstance(sentences, Sentence): sentences = [sentences] @@ -247,7 +247,7 @@ def decode(self, decoder_input_indices, initial_hidden_states, all_encoder_outpu output_vectors = self.character_decoder(output) return output_vectors, hidden - def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[Optional[torch.Tensor], ...]: + def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[Optional[torch.Tensor], ...]: # get all tokens tokens = [token for sentence in sentences for token in sentence] @@ -290,7 +290,7 @@ def forward( encoder_input_indices: Optional[torch.Tensor], lengths: Optional[torch.Tensor], token_embedding_hidden: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: # variable to store initial hidden states for decoder initial_hidden_for_decoder = [] @@ -340,7 +340,7 @@ def forward( return initial_hidden, all_encoder_outputs - def encode(self, sentences: List[Sentence]): + def encode(self, sentences: list[Sentence]): tensors = self._prepare_tensors(sentences) return self.forward(*tensors) @@ -396,14 +396,14 @@ def _calculate_loss(self, scores, labels): return self.loss(scores_in_correct_format, target), len(labels) - def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: scores, labels = self.forward_pass(sentences) return self._calculate_loss(scores, labels) def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size: int = 16, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -474,7 +474,7 @@ def predict( # option 1: greedy decoding if self.beam_size == 1: # predictions - predicted: List[List[Union[int, float]]] = [[] for _ in range(number_tokens)] + predicted: list[list[Union[int, float]]] = [[] for _ in range(number_tokens)] for _decode_step in range(max_length): # decode next character @@ -525,7 +525,7 @@ def predict( # keep track of how many hypothesis were completed for each token n_completed = [0 for _ in range(number_tokens)] # cpu - final_candidates: List[List[Tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu + final_candidates: list[list[tuple[torch.Tensor, float]]] = [[] for _ in range(number_tokens)] # cpu # if all_encoder_outputs returned, expand them to 
beam size (otherwise keep this as None) batched_encoding_output = ( @@ -552,24 +552,24 @@ def predict( # check if an end symbol has been predicted and, in that case, set hypothesis aside end_symbols = (index_candidates == self.end_index).nonzero(as_tuple=False) - for tuple in end_symbols: + for row in end_symbols: # if the sequence is already ended, do not record as candidate - if sequences[tuple[0], -1].item() == self.end_index: + if sequences[row[0], -1].item() == self.end_index: continue # index of token in in list tokens_in_batch - token_number = torch.div(tuple[0], self.beam_size, rounding_mode="trunc") + token_number = torch.div(row[0], self.beam_size, rounding_mode="trunc") # print(token_number) - seq = sequences[tuple[0], :] # hypothesis sequence + seq = sequences[row[0], :] # hypothesis sequence # hypothesis score - score = (scores[tuple[0]] + log_probabilities[tuple[0], tuple[1]]) / (len(seq) + 1) + score = (scores[row[0]] + log_probabilities[row[0], row[1]]) / (len(seq) + 1) final_candidates[token_number].append((seq, score.item())) # TODO: remove token if number of completed hypothesis exceeds given value n_completed[token_number] += 1 # set score of corresponding entry to -inf so it will not be expanded - log_probabilities[tuple[0], tuple[1]] = -inf + log_probabilities[row[0], row[1]] = -inf # get leading_indices for next expansion # find highest scoring hypothesis among beam_size*beam_size possible ones for each token @@ -594,8 +594,8 @@ def predict( # a list of length beam_size*batch_size # where the first three inidices belong to the first token, the next three to the second token, # and so on - beam_numbers: List[int] = [] - seq_numbers: List[int] = [] + beam_numbers: list[int] = [] + seq_numbers: list[int] = [] for i, row in enumerate(indices_per_token): beam_numbers.extend(i * self.beam_size + index.item() // self.beam_size for index in row) diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py index 733751eff7..414eb46197 100644 --- a/flair/models/multitask_model.py +++ b/flair/models/multitask_model.py @@ -2,7 +2,7 @@ import random import typing from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch @@ -27,9 +27,9 @@ class MultitaskModel(flair.nn.Classifier): def __init__( self, - models: List[flair.nn.Classifier], - task_ids: Optional[List[str]] = None, - loss_factors: Optional[List[float]] = None, + models: list[flair.nn.Classifier], + task_ids: Optional[list[str]] = None, + loss_factors: Optional[list[float]] = None, use_all_tasks: bool = False, ) -> None: """Instantiates the MultiTaskModel. 
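
The rename of the beam-search loop variable `tuple` to `row` in the hunk above matters beyond style: once the builtin name is shadowed, any later call to `tuple(...)` in the same scope fails, and the builtin is also what the new PEP 585 annotations refer to. A small illustrative snippet (not Flair code):

```python
# Illustrative only: why shadowing the builtin `tuple` is a bug waiting to happen.
pairs = [(1, 2), (3, 4)]


def shadowed(pairs: list[tuple[int, int]]):
    for tuple in pairs:  # rebinds `tuple` locally for the whole function body
        pass
    return tuple(p[0] for p in pairs)  # TypeError: 'tuple' object is not callable


def renamed(pairs: list[tuple[int, int]]):
    for row in pairs:  # `row`, as the patched code now does
        pass
    return tuple(p[0] for p in pairs)  # works as intended


print(renamed(pairs))  # (1, 3)
```
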
@@ -42,10 +42,10 @@ def __init__( """ super().__init__() - task_ids_internal: List[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))] + task_ids_internal: list[str] = task_ids if task_ids else [f"Task_{i}" for i in range(len(models))] - self.tasks: Dict[str, flair.nn.Classifier] = {} - self.loss_factors: Dict[str, float] = {} + self.tasks: dict[str, flair.nn.Classifier] = {} + self.loss_factors: dict[str, float] = {} self.use_all_tasks = use_all_tasks if not loss_factors: @@ -63,10 +63,10 @@ def __init__( def forward(self, *args) -> torch.Tensor: raise NotImplementedError("`forward` is not used for multitask learning") - def _prepare_tensors(self, data_points: List[DT]) -> Tuple[torch.Tensor, ...]: + def _prepare_tensors(self, data_points: list[DT]) -> tuple[torch.Tensor, ...]: raise NotImplementedError("`_prepare_tensors` is not used for multitask learning") - def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: """Calls the respective forward loss of each model and sums them weighted by their loss factors. Args: @@ -92,7 +92,9 @@ def predict( task.predict(sentences, **predictargs) @staticmethod - def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_tasks: bool = False) -> Dict: + def split_batch_to_task_ids( + sentences: Union[list[Sentence], Sentence], all_tasks: bool = False + ) -> dict[str, list[int]]: """Splits a batch of sentences to its respective model. If single sentence is assigned to several tasks (i.e. same corpus but different tasks), then the model @@ -104,7 +106,7 @@ def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_task Returns: Key-value pairs as (task_id, list of sentences ids in batch) """ - batch_to_task_mapping: Dict[str, List[int]] = {} + batch_to_task_mapping: dict[str, list[int]] = {} for sentence_id, sentence in enumerate(sentences): if all_tasks: multitask_ids = sentence.get_labels("multitask_id") @@ -122,7 +124,7 @@ def evaluate( # type: ignore[override] data_points, gold_label_type: str, out_path: Optional[Union[str, Path]] = None, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), evaluate_all: bool = True, **evalargs, ) -> Result: @@ -161,7 +163,7 @@ def evaluate( # type: ignore[override] loss = torch.tensor(0.0, device=flair.device) main_score = 0.0 all_detailed_results = "" - all_classification_report: Dict[str, Dict[str, Any]] = {} + all_classification_report: dict[str, dict[str, Any]] = {} for task_id, split in batch_split.items(): result = self.tasks[task_id].evaluate( @@ -203,7 +205,7 @@ def evaluate( # type: ignore[override] def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for model in self.tasks.values(): yield from model.get_used_tokens(corpus, context_length, respect_document_boundaries) @@ -272,7 +274,7 @@ def _fetch_model(model_name) -> str: return model_name @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "MultitaskModel": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "MultitaskModel": from typing import cast return cast("MultitaskModel", super().load(model_path=model_path)) diff --git a/flair/models/pairwise_classification_model.py 
b/flair/models/pairwise_classification_model.py index 262fd08cb5..6308573edc 100644 --- a/flair/models/pairwise_classification_model.py +++ b/flair/models/pairwise_classification_model.py @@ -1,5 +1,4 @@ import typing -from typing import List import torch @@ -69,7 +68,7 @@ def __init__( def label_type(self): return self._label_type - def _get_data_points_from_sentence(self, sentence: TextPair) -> List[TextPair]: + def _get_data_points_from_sentence(self, sentence: TextPair) -> list[TextPair]: return [sentence] def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torch.Tensor: @@ -119,7 +118,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py index c3f34e0f69..9a1c2704be 100644 --- a/flair/models/pairwise_regression_model.py +++ b/flair/models/pairwise_regression_model.py @@ -1,5 +1,6 @@ +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -90,7 +91,7 @@ def label_type(self): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> Iterable[List[str]]: + ) -> Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] @@ -99,14 +100,14 @@ def get_used_tokens( yield [t.text for t in sentence_pair.second.left_context(context_length, respect_document_boundaries)] yield [t.text for t in sentence_pair.second.right_context(context_length, respect_document_boundaries)] - def forward_loss(self, pairs: List[TextPair]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, pairs: list[TextPair]) -> tuple[torch.Tensor, int]: loss, num = self._forward_loss_and_scores(pairs=pairs, return_num=True, return_scores=False) assert isinstance(loss, torch.Tensor) assert isinstance(num, int) return loss, num - def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, return_scores=True) -> Tuple: + def _forward_loss_and_scores(self, pairs: list[TextPair], return_num=True, return_scores=True) -> tuple: # make a forward pass to produce embedded data points and labels pairs = [pair for pair in pairs if self._filter_data_point(pair)] @@ -128,7 +129,7 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur # calculate the loss loss, num = self._calculate_loss(scores, target_tensor) - return_value: Tuple[Any, ...] = (loss,) + return_value: tuple[Any, ...] 
= (loss,) if return_num: return_value += (num,) @@ -138,10 +139,10 @@ def _forward_loss_and_scores(self, pairs: List[TextPair], return_num=True, retur return return_value - def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, target_tensor: torch.Tensor) -> tuple[torch.Tensor, int]: return self.loss_function(scores, target_tensor), target_tensor.size(0) - def _prepare_target_tensor(self, pairs: List[TextPair]): + def _prepare_target_tensor(self, pairs: list[TextPair]): target_values = [ torch.tensor([float(label.value) for label in pair.get_labels(self.label_name)], dtype=torch.float) for pair in pairs @@ -152,7 +153,7 @@ def _prepare_target_tensor(self, pairs: List[TextPair]): def _filter_data_point(self, pair: TextPair) -> bool: return len(pair) > 0 - def _encode_data_points(self, data_points: List[TextPair]) -> torch.Tensor: + def _encode_data_points(self, data_points: list[TextPair]) -> torch.Tensor: # get a tensor of data points data_point_tensor = torch.stack([self._get_embedding_for_data_point(data_point) for data_point in data_points]) @@ -203,7 +204,7 @@ def _get_state_dict(self): return model_state @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): """Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict). Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary. @@ -227,12 +228,12 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): def predict( self, - pairs: Union[TextPair, List[TextPair]], + pairs: Union[TextPair, list[TextPair]], mini_batch_size: int = 32, verbose: bool = False, label_name: Optional[str] = None, embedding_storage_mode="none", - ) -> List[TextPair]: + ) -> list[TextPair]: if label_name is None: label_name = self.label_name if self.label_name is not None else "label" @@ -278,13 +279,13 @@ def predict( def evaluate( self, - data_points: Union[List[TextPair], Dataset], + data_points: Union[list[TextPair], Dataset], gold_label_type: str, out_path: Union[str, Path, None] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("correlation", "pearson"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, diff --git a/flair/models/prefixed_tagger.py b/flair/models/prefixed_tagger.py index 05d8fa8c34..b001653bdc 100644 --- a/flair/models/prefixed_tagger.py +++ b/flair/models/prefixed_tagger.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import torch from torch.utils.data import Dataset @@ -26,7 +26,7 @@ class SentenceAugmentationStrategy(ABC): @abstractmethod def augment_sentence( - self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None + self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None ) -> PrefixedSentence: """Augments the given sentence text with additional instructions for working / predicting the task on the given annotations. 
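
As a rough illustration of the prompt-prefix idea behind `EntityTypeTaskPromptAugmentationStrategy` (whose hunks follow), here is a hedged, standalone sketch; the handling of multiple entity types is an assumption for illustration, since only the single-type case (`f"[ Tag {entity_types[0]} ]"`) is visible in this patch:

```python
# Hedged sketch (not the Flair implementation): build a task prompt such as
# "[ Tag gene and disease ]" and prepend it to the sentence text.
# The multi-type branch below is an assumption made for this example.
def build_tag_prompt(entity_types: list[str]) -> str:
    if len(entity_types) == 1:
        return f"[ Tag {entity_types[0]} ]"
    return "[ Tag " + ", ".join(entity_types[:-1]) + " and " + entity_types[-1] + " ]"


def augment(text: str, entity_types: list[str]) -> str:
    return f"{build_tag_prompt(entity_types)} {text}"


print(augment("Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers",
              ["gene", "disease"]))
# [ Tag gene and disease ] Mutations in the TP53 tumour suppressor gene are found in ...
```
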
@@ -64,7 +64,7 @@ def _init_strategy_with_state_dict(cls, state, **kwargs): """Initializes the strategy from the given state.""" def augment_dataset( - self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None + self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None ) -> FlairDatapointDataset[PrefixedSentence]: """Transforms a dataset into a dataset containing augmented sentences specific to the `PrefixedSequenceTagger`. @@ -78,14 +78,14 @@ def augment_dataset( Returns: A dataset of augmented sentences specific to the `PrefixedSequenceTagger` """ data_loader: DataLoader = DataLoader(dataset, batch_size=1) - original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)] + original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)] augmented_sentences = [self.augment_sentence(sentence, annotation_layers) for sentence in original_sentences] return FlairDatapointDataset(augmented_sentences) def augment_corpus( - self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None + self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, list[str]]] = None ) -> Corpus[PrefixedSentence]: """Transforms a corpus into a corpus containing augmented sentences specific to the `PrefixedSequenceTagger`. @@ -120,7 +120,7 @@ class EntityTypeTaskPromptAugmentationStrategy(SentenceAugmentationStrategy): "[Tag gene and disease] Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers" """ - def __init__(self, entity_types: List[str]): + def __init__(self, entity_types: list[str]): if len(entity_types) <= 0: raise AssertionError @@ -128,7 +128,7 @@ def __init__(self, entity_types: List[str]): self.task_prompt = self._build_tag_prompt_prefix(entity_types) def augment_sentence( - self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None + self, sentence: Sentence, annotation_layers: Optional[Union[str, list[str]]] = None ) -> PrefixedSentence: # Prepend the task description prompt to the sentence text augmented_sentence = PrefixedSentence( @@ -182,7 +182,7 @@ def apply_predictions( ] orig_span.add_label(target_annotation_layer, label.value, label.score) - def _build_tag_prompt_prefix(self, entity_types: List[str]) -> List[str]: + def _build_tag_prompt_prefix(self, entity_types: list[str]) -> list[str]: if len(self.entity_types) == 1: prompt = f"[ Tag {entity_types[0]} ]" else: @@ -219,29 +219,29 @@ def _init_model_with_state_dict(cls, state, **kwargs): return super()._init_model_with_state_dict(state, augmentation_strategy=strategy, **kwargs) @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "PrefixedSequenceTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "PrefixedSequenceTagger": from typing import cast return cast("PrefixedSequenceTagger", super().load(model_path=model_path)) - def forward_loss(self, sentences: Union[List[Sentence], List[PrefixedSentence]]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: Union[list[Sentence], list[PrefixedSentence]]) -> tuple[torch.Tensor, int]: # If all sentences are not augmented -> augment them if all(isinstance(sentence, Sentence) for sentence in sentences): # mypy does not infer the type of "sentences" restricted by the if statement - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) sentences = self.augment_sentences(sentences=sentences, annotation_layers=self.tag_type) 
elif not all(isinstance(sentence, PrefixedSentence) for sentence in sentences): raise ValueError("All passed sentences must be either uniformly augmented or not.") # mypy does not infer the type of "sentences" restricted by code above - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) return super().forward_loss(sentences) def predict( self, - sentences: Union[List[Sentence], Sentence, List[PrefixedSentence], PrefixedSentence], + sentences: Union[list[Sentence], Sentence, list[PrefixedSentence], PrefixedSentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -260,7 +260,7 @@ def predict( # If all sentences are already augmented (i.e. compatible with this class), just forward the sentences if all(isinstance(sentence, PrefixedSentence) for sentence in sentences): # mypy does not infer the type of "sentences" restricted by the if statement - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) return super().predict( sentences, @@ -280,12 +280,12 @@ def predict( for sentence in sentences: sentence.remove_labels(prediction_label_type) - sentences = cast(List[Sentence], sentences) + sentences = cast(list[Sentence], sentences) # Augment sentences - copy all annotation of the given tag type augmented_sentences = self.augment_sentences(sentences, self.tag_type) - mypy_safe_augmented_sentences = cast(List[Sentence], augmented_sentences) + mypy_safe_augmented_sentences = cast(list[Sentence], augmented_sentences) # Predict on augmented sentence and store it in an internal annotation layer / label loss_and_count = super().predict( @@ -312,8 +312,8 @@ def predict( return loss_and_count def augment_sentences( - self, sentences: Union[Sentence, List[Sentence]], annotation_layers: Optional[Union[str, List[str]]] = None - ) -> List[PrefixedSentence]: + self, sentences: Union[Sentence, list[Sentence]], annotation_layers: Optional[Union[str, list[str]]] = None + ) -> list[PrefixedSentence]: if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset): sentences = [sentences] diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index 35c244d960..e41981c899 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -1,7 +1,7 @@ import re import typing from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Union +from typing import Union from flair.data import Sentence, Span, Token @@ -15,8 +15,8 @@ class TokenCollection: """ sentence: Sentence - __tokens_start_pos: List[int] = field(init=False, default_factory=list) - __tokens_end_pos: List[int] = field(init=False, default_factory=list) + __tokens_start_pos: list[int] = field(init=False, default_factory=list) + __tokens_end_pos: list[int] = field(init=False, default_factory=list) def __post_init__(self): for token in self.tokens: @@ -24,10 +24,10 @@ def __post_init__(self): self.__tokens_end_pos.append(token.end_position) @property - def tokens(self) -> List[Token]: + def tokens(self) -> list[Token]: return list(self.sentence) - def get_token_span(self, span: Tuple[int, int]) -> Span: + def get_token_span(self, span: tuple[int, int]) -> Span: """Find a span by the token character positions. 
Given an interval specified with start and end pos as tuple, this function returns a Span object @@ -45,7 +45,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span: class RegexpTagger: - def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> None: + def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> None: r"""This tagger is capable of tagging sentence objects with given regexp -> label mappings. I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to @@ -58,14 +58,14 @@ def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]) -> No Args: mapping: A list of tuples or a single tuple representing a mapping as regexp -> label """ - self._regexp_mapping: Dict[str, typing.Pattern] = {} + self._regexp_mapping: dict[str, typing.Pattern] = {} self.register_labels(mapping=mapping) @property def registered_labels(self): return self._regexp_mapping - def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]): + def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]): """Register a regexp -> label mapping. Args: @@ -81,7 +81,7 @@ def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]] f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'" ) - def remove_labels(self, labels: Union[List[str], str]): + def remove_labels(self, labels: Union[list[str], str]): """Remove a registered regexp -> label mapping given by label. Args: @@ -101,7 +101,7 @@ def _listify(element: object) -> list: else: return element - def predict(self, sentences: Union[List[Sentence], Sentence]) -> List[Sentence]: + def predict(self, sentences: Union[list[Sentence], Sentence]) -> list[Sentence]: """Predict the given sentences according to the registered mappings.""" if not isinstance(sentences, list): sentences = [sentences] @@ -122,7 +122,7 @@ def _label(self, sentence: Sentence): for label, pattern in self._regexp_mapping.items(): for match in pattern.finditer(sentence.to_original_text()): - span: Tuple[int, int] = match.span() + span: tuple[int, int] = match.span() try: token_span = collection.get_token_span(span) except ValueError: diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 53ccabac36..9c6c69577f 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -2,17 +2,12 @@ import logging import typing from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence from pathlib import Path from typing import ( Any, - Dict, - Iterator, - List, NamedTuple, Optional, - Sequence, - Set, - Tuple, Union, cast, ) @@ -50,7 +45,7 @@ class EncodedSentence(Sentence): class EncodingStrategy(ABC): """The encoding of the head and tail entities in a sentence with a relation annotation.""" - special_tokens: Set[str] = set() + special_tokens: set[str] = set() def __init__(self, add_special_tokens: bool = False) -> None: self.add_special_tokens = add_special_tokens @@ -84,7 +79,7 @@ class EntityMask(EncodingStrategy): - "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin'). """ - special_tokens: Set[str] = {"[HEAD]", "[TAIL]"} + special_tokens: set[str] = {"[HEAD]", "[TAIL]"} def encode_head(self, head_span: Span, label: Label) -> str: return "[HEAD]" @@ -126,7 +121,7 @@ class EntityMarker(EncodingStrategy): -> Relation(head='Google', tail='Sergey Brin'). 
""" - special_tokens: Set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} + special_tokens: set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) @@ -254,8 +249,8 @@ def __init__( embeddings: DocumentEmbeddings, label_dictionary: Dictionary, label_type: str, - entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], - entity_pair_labels: Optional[Set[Tuple[str, str]]] = None, + entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]], + entity_pair_labels: Optional[set[tuple[str, str]]] = None, entity_threshold: Optional[float] = None, cross_augmentation: bool = True, encoding_strategy: EncodingStrategy = TypedEntityMarker(), @@ -298,7 +293,7 @@ def __init__( ) if isinstance(entity_label_types, str): - self.entity_label_types: Dict[str, Optional[Set[str]]] = {entity_label_types: None} + self.entity_label_types: dict[str, Optional[set[str]]] = {entity_label_types: None} elif isinstance(entity_label_types, Sequence): self.entity_label_types = {entity_label_type: None for entity_label_type in entity_label_types} else: @@ -316,7 +311,7 @@ def __init__( and self.encoding_strategy.special_tokens and isinstance(self.embeddings, TransformerDocumentEmbeddings) ): - special_tokens: List[str] = list(self.encoding_strategy.special_tokens) + special_tokens: list[str] = list(self.encoding_strategy.special_tokens) tokenizer = self.embeddings.tokenizer tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) self.embeddings.model.resize_token_embeddings(len(tokenizer)) @@ -355,7 +350,7 @@ def _valid_entities(self, sentence: Sentence) -> Iterator[_Entity]: def _entity_pair_permutations( self, sentence: Sentence, - ) -> Iterator[Tuple[_Entity, _Entity, Optional[str]]]: + ) -> Iterator[tuple[_Entity, _Entity, Optional[str]]]: """Yields all valid entity pair permutations (relation candidates). If the passed sentence contains relation annotations, @@ -370,10 +365,10 @@ def _entity_pair_permutations( Yields: Tuples of (HEAD, TAIL, gold_label): The head and tail `_Entity`s` have span references to the passed sentence. """ - valid_entities: List[_Entity] = list(self._valid_entities(sentence)) + valid_entities: list[_Entity] = list(self._valid_entities(sentence)) # Use a dictionary to find gold relation annotations for a given entity pair - relation_to_gold_label: Dict[str, str] = { + relation_to_gold_label: dict[str, str] = { relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value for relation in sentence.get_relations(self.label_type) } @@ -420,13 +415,13 @@ def _encode_sentence( assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence." # Pre-compute non-leading head and tail tokens for entity masking - non_leading_head_tokens: List[Token] = head.span.tokens[1:] - non_leading_tail_tokens: List[Token] = tail.span.tokens[1:] + non_leading_head_tokens: list[Token] = head.span.tokens[1:] + non_leading_tail_tokens: list[Token] = tail.span.tokens[1:] # We can not use the plaintext of the head/tail span in the sentence as the mask/marker # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. 
- encoded_sentence_tokens: List[str] = [] + encoded_sentence_tokens: list[str] = [] for token in original_sentence: if token is head.span[0]: encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) @@ -456,7 +451,7 @@ def _encode_sentence( def _encode_sentence_for_inference( self, sentence: Sentence, - ) -> Iterator[Tuple[EncodedSentence, Relation]]: + ) -> Iterator[tuple[EncodedSentence, Relation]]: """Create Encoded Sentences and Relation pairs for Inference. Yields encoded sentences annotated with their gold relation and @@ -505,7 +500,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS yield masked_sentence - def transform_sentence(self, sentences: Union[Sentence, List[Sentence]]) -> List[EncodedSentence]: + def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]: """Transforms sentences into encoded sentences specific to the `RelationClassifier`. For more information on the internal sentence transformation procedure, @@ -541,7 +536,7 @@ def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset Returns: A dataset of encoded sentences specific to the `RelationClassifier` """ data_loader: DataLoader = DataLoader(dataset, batch_size=1) - original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)] + original_sentences: list[Sentence] = [batch[0] for batch in iter(data_loader)] return FlairDatapointDataset(self.transform_sentence(original_sentences)) def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: @@ -568,10 +563,10 @@ def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: ) def _get_embedding_for_data_point(self, prediction_data_point: EncodedSentence) -> torch.Tensor: - embedding_names: List[str] = self.embeddings.get_names() + embedding_names: list[str] = self.embeddings.get_names() return prediction_data_point.get_embedding(embedding_names) - def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[EncodedSentence]: + def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> list[EncodedSentence]: """Returns the encoded sentences to which labels are added. To encode sentences, use the `transform` function of the `RelationClassifier`. @@ -597,14 +592,14 @@ def _get_data_points_from_sentence(self, sentence: EncodedSentence) -> List[Enco def predict( self, - sentences: Union[List[Sentence], List[EncodedSentence], Sentence, EncodedSentence], + sentences: Union[list[Sentence], list[EncodedSentence], Sentence, EncodedSentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> Optional[Tuple[torch.Tensor, int]]: + ) -> Optional[tuple[torch.Tensor, int]]: """Predicts the class labels for the given sentence(s). Standard `Sentence` objects and `EncodedSentences` specific to the `RelationClassifier` are allowed as input. 
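
The `predict` method whose docstring ends above accepts either plain `Sentence`s or already-encoded `EncodedSentence`s, but never a mix. The following standalone sketch (with a stand-in class, not Flair's) shows that all-or-nothing dispatch together with the `cast(list[...], ...)` idiom this patch converts to built-in generics:

```python
# Standalone sketch of the dispatch pattern in RelationClassifier.predict.
# EncodedSentence is a stand-in; cast() only informs mypy and has no runtime effect.
from typing import Union, cast


class EncodedSentence(str):
    """Stand-in for flair's EncodedSentence."""


def dispatch(sentences: list[Union[str, EncodedSentence]]) -> str:
    if all(isinstance(s, EncodedSentence) for s in sentences):
        encoded = cast(list[EncodedSentence], sentences)
        return f"predict directly on {len(encoded)} encoded sentences"
    if all(not isinstance(s, EncodedSentence) for s in sentences):
        return f"encode {len(sentences)} plain sentences, then predict"
    raise ValueError("All passed sentences must be uniformly encoded or not.")


print(dispatch([EncodedSentence("a"), EncodedSentence("b")]))
print(dispatch(["a", "b"]))
```
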
@@ -626,14 +621,14 @@ def predict( if not isinstance(sentences, list): sentences = [sentences] - loss: Optional[Tuple[torch.Tensor, int]] - encoded_sentences: List[EncodedSentence] + loss: Optional[tuple[torch.Tensor, int]] + encoded_sentences: list[EncodedSentence] if all(isinstance(sentence, EncodedSentence) for sentence in sentences): # Deal with the case where all sentences are encoded sentences # mypy does not infer the type of "sentences" restricted by the if statement - encoded_sentences = cast(List[EncodedSentence], sentences) + encoded_sentences = cast(list[EncodedSentence], sentences) loss = super().predict( encoded_sentences, mini_batch_size=mini_batch_size, @@ -646,8 +641,8 @@ def predict( elif all(not isinstance(sentence, EncodedSentence) for sentence in sentences): # Deal with the case where all sentences are standard (non-encoded) sentences - Sentence.set_context_for_sentences(cast(List[Sentence], sentences)) - sentences_with_relation_reference: List[Tuple[EncodedSentence, Relation]] = list( + Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) + sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list( itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences) ) @@ -672,8 +667,8 @@ def predict( return loss if return_loss else None - def _get_state_dict(self) -> Dict[str, Any]: - model_state: Dict[str, Any] = { + def _get_state_dict(self) -> dict[str, Any]: + model_state: dict[str, Any] = { **super()._get_state_dict(), "embeddings": self.embeddings.save_embeddings(use_state_dict=False), "label_dictionary": self.label_dictionary, @@ -689,7 +684,7 @@ def _get_state_dict(self) -> Dict[str, Any]: return model_state @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): return super()._init_model_with_state_dict( state, embeddings=state["embeddings"], @@ -719,7 +714,7 @@ def allow_unk_tag(self) -> bool: def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries) for sentence in _iter_dataset(corpus.get_all_sentences()): for span in sentence.get_spans(self.label_type): @@ -727,7 +722,7 @@ def get_used_tokens( yield self.encoding_strategy.encode_tail(span, span.get_label(self.label_type)).split(" ") @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationClassifier": from typing import cast return cast("RelationClassifier", super().load(model_path=model_path)) diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py index 795e8a517f..0c56abf5bd 100644 --- a/flair/models/relation_extractor_model.py +++ b/flair/models/relation_extractor_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import torch @@ -18,7 +18,7 @@ def __init__( embeddings: flair.embeddings.TokenEmbeddings, label_type: str, entity_label_type: str, - entity_pair_filters: Optional[List[Tuple[str, str]]] = None, + entity_pair_filters: Optional[list[tuple[str, str]]] = None, pooling_operation: str = "first_last", train_on_gold_pairs_only: bool = False, 
**classifierargs, @@ -56,13 +56,13 @@ def __init__( # whether to use gold entity pairs, and whether to filter entity pairs by type if entity_pair_filters is not None: - self.entity_pair_filters: Optional[Set[Tuple[str, str]]] = set(entity_pair_filters) + self.entity_pair_filters: Optional[set[tuple[str, str]]] = set(entity_pair_filters) else: self.entity_pair_filters = None self.to(flair.device) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Relation]: entity_pairs = [] entity_spans = sentence.get_spans(self.entity_label_type) @@ -172,7 +172,7 @@ def _fetch_model(model_name) -> str: return model_name @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationExtractor": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "RelationExtractor": from typing import cast return cast("RelationExtractor", super().load(model_path=model_path)) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 16f20a0ddf..2c0fc00ebb 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -1,7 +1,7 @@ import logging import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import torch import torch.nn @@ -40,7 +40,7 @@ def __init__( word_dropout: float = 0.05, locked_dropout: float = 0.5, train_initial_hidden_state: bool = False, - loss_weights: Optional[Dict[str, float]] = None, + loss_weights: Optional[dict[str, float]] = None, init_from_state_dict: bool = False, allow_unk_predictions: bool = False, ) -> None: @@ -204,7 +204,7 @@ def __init__( def label_type(self): return self.tag_type - def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor: + def _init_loss_weights(self, loss_weights: dict[str, float]) -> torch.Tensor: """Initializes the loss weights based on given dictionary. 
Args: @@ -267,7 +267,7 @@ def RNN( return RNN - def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]: # if there are no sentences, there is no loss if len(sentences) == 0: return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0 @@ -281,7 +281,7 @@ def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: # calculate loss given scores and labels return self._calculate_loss(scores, gold_labels) - def _prepare_tensors(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, torch.LongTensor]: + def _prepare_tensors(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, torch.LongTensor]: sentences = [data_points] if not isinstance(data_points, list) else data_points self.embeddings.embed(sentences) @@ -331,15 +331,15 @@ def forward(self, sentence_tensor: torch.Tensor, lengths: torch.LongTensor): return scores - def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, labels: torch.LongTensor) -> tuple[torch.Tensor, int]: if labels.size(0) == 0: return torch.tensor(0.0, requires_grad=True, device=flair.device), 1 return self.loss_function(scores, labels), len(labels) - def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.LongTensor, torch.Tensor]: + def _make_padded_tensor_for_batch(self, sentences: list[Sentence]) -> tuple[torch.LongTensor, torch.Tensor]: names = self.embeddings.get_names() - lengths: List[int] = [len(sentence.tokens) for sentence in sentences] + lengths: list[int] = [len(sentence.tokens) for sentence in sentences] longest_token_sequence_in_batch: int = max(lengths) pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * longest_token_sequence_in_batch, @@ -382,7 +382,7 @@ def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor): return scores - def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]: + def _get_gold_labels(self, sentences: list[Sentence]) -> list[str]: """Extracts gold labels from each sentence. 
Args: @@ -419,7 +419,7 @@ def _get_gold_labels(self, sentences: List[Sentence]) -> List[str]: return labels - def _prepare_label_tensor(self, sentences: List[Sentence]): + def _prepare_label_tensor(self, sentences: list[Sentence]): gold_labels = self._get_gold_labels(sentences) labels = torch.tensor( [self.label_dictionary.get_idx_for_item(label) for label in gold_labels], @@ -430,7 +430,7 @@ def _prepare_label_tensor(self, sentences: List[Sentence]): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -462,7 +462,7 @@ def predict( if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset): sentences = [sentences] - Sentence.set_context_for_sentences(cast(List[Sentence], sentences)) + Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) # filter empty sentences sentences = [sentence for sentence in sentences if len(sentence) > 0] @@ -542,7 +542,7 @@ def predict( return overall_loss, label_count return None - def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool): + def _standard_inference(self, features: torch.Tensor, batch: list[Sentence], probabilities_for_all_classes: bool): """Softmax over emission scores from forward propagation. Args: @@ -573,7 +573,7 @@ def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], pro return predictions, all_tags - def _all_scores_for_token(self, sentences: List[Sentence], score_tensor: torch.Tensor, lengths: List[int]): + def _all_scores_for_token(self, sentences: list[Sentence], score_tensor: torch.Tensor, lengths: list[int]): """Returns all scores for each tag in tag dictionary.""" scores = score_tensor.numpy() tokens = [token for sentence in sentences for token in sentence] @@ -861,7 +861,7 @@ def push_to_hub( return repo_url @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") @@ -919,7 +919,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SequenceTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "SequenceTagger": from typing import cast return cast("SequenceTagger", super().load(model_path=model_path)) diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py index ed84d1a6b7..5c87f49f0c 100644 --- a/flair/models/sequence_tagger_utils/viterbi.py +++ b/flair/models/sequence_tagger_utils/viterbi.py @@ -1,5 +1,3 @@ -from typing import Tuple - import numpy as np import torch import torch.nn @@ -7,7 +5,7 @@ from torch.nn.utils.rnn import pack_padded_sequence import flair -from flair.data import Dictionary, Label, List, Sentence +from flair.data import Dictionary, Label, Sentence START_TAG: str = "" STOP_TAG: str = "" @@ -141,8 +139,8 @@ def __init__(self, tag_dictionary: Dictionary) -> None: self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG) def decode( - self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence] - ) -> Tuple[List, List]: + 
self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: list[Sentence] + ) -> tuple[list[list[tuple[str, float]]], list[list[list[Label]]]]: """Decoding function returning the most likely sequence of tags. Args: @@ -211,7 +209,7 @@ def decode( scores = softmax(scores_upto_t, dim=2) confidences = torch.max(scores, dim=2) - tags = [] + tags: list[list[tuple[str, float]]] = [] for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths): tags.append( [ @@ -230,7 +228,7 @@ def _all_scores_for_token( score_tensor: torch.Tensor, tag_sequences: torch.Tensor, lengths: torch.IntTensor, - sentences: List[Sentence], + sentences: list[Sentence], ): """Returns all scores for each tag in tag dictionary.""" scores = score_tensor.numpy() diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py index a7a41bdb5d..4f5cb85731 100644 --- a/flair/models/tars_model.py +++ b/flair/models/tars_model.py @@ -3,7 +3,7 @@ from abc import ABC from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -30,14 +30,14 @@ class FewshotClassifier(flair.nn.Classifier[Sentence], ABC): def __init__(self) -> None: self._current_task = None - self._task_specific_attributes: Dict[str, Dict[str, Any]] = {} + self._task_specific_attributes: dict[str, dict[str, Any]] = {} self.label_nearest_map = None self.tars_model: flair.nn.Classifier[Sentence] self.separator: str super().__init__() - def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: Union[list[Sentence], Sentence]) -> tuple[torch.Tensor, int]: if not isinstance(data_points, list): data_points = [data_points] @@ -54,7 +54,7 @@ def tars_embeddings(self): def _get_tars_formatted_sentence(self, label, sentence): raise NotImplementedError - def _get_tars_formatted_sentences(self, sentences: List[Sentence]): + def _get_tars_formatted_sentences(self, sentences: list[Sentence]): label_text_pairs = [] all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item] for sentence in sentences: @@ -173,7 +173,7 @@ def is_current_task_multi_label(self): def add_and_switch_to_new_task( self, task_name: str, - label_dictionary: Union[List, Set, Dictionary, str], + label_dictionary: Union[list, set, Dictionary, str], label_type: str, multi_label: bool = True, force_switch: bool = False, @@ -219,7 +219,7 @@ def add_and_switch_to_new_task( self.switch_to_task(task_name) - def list_existing_tasks(self) -> Set[str]: + def list_existing_tasks(self) -> set[str]: """Lists existing tasks in the loaded TARS model on the console.""" return set(self._task_specific_attributes.keys()) @@ -246,7 +246,7 @@ def _drop_task(self, task_name): log.warning("No task exists with the name `%s`.", task_name) @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") @@ -258,8 +258,8 @@ def label_type(self): def predict_zero_shot( self, - sentences: Union[List[Sentence], Sentence], - candidate_label_set: Union[List[str], Set[str], str], + sentences: Union[list[Sentence], Sentence], + 
candidate_label_set: Union[list[str], set[str], str], multi_label: bool = True, ): """Make zero shot predictions from the TARS model. @@ -307,14 +307,14 @@ def predict_zero_shot( def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: yield from super().get_used_tokens(corpus, context_length, respect_document_boundaries) for label in self.get_current_label_dictionary().idx2item: yield [label.decode("utf-8")] yield [self.separator] @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "FewshotClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "FewshotClassifier": from typing import cast return cast("FewshotClassifier", super().load(model_path=model_path)) @@ -472,7 +472,7 @@ def tars_embeddings(self): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size=32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -532,12 +532,12 @@ def predict( if not batch: continue - tars_sentences: List[Sentence] = [] - all_labels_to_sentence: List[Dict[str, Sentence]] = [] + tars_sentences: list[Sentence] = [] + all_labels_to_sentence: list[dict[str, Sentence]] = [] for sentence in batch: # always remove tags first sentence.remove_labels(label_name) - labels_to_sentence: Dict[str, Sentence] = {} + labels_to_sentence: dict[str, Sentence] = {} for label in all_labels: tars_sentence = self._get_tars_formatted_sentence(label, sentence) tars_sentences.append(tars_sentence) @@ -570,7 +570,7 @@ def predict( if most_probable_first: import operator - already_set_indices: List[int] = [] + already_set_indices: list[int] = [] sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1)) sorted_x.reverse() @@ -648,7 +648,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSTagger": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSTagger": from typing import cast return cast("TARSTagger", super().load(model_path=model_path)) @@ -832,7 +832,7 @@ def tars_embeddings(self): def predict( self, - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], mini_batch_size=32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -907,12 +907,12 @@ def predict( if not batch: continue - tars_sentences: List[Sentence] = [] - all_labels_to_sentence: List[Dict[str, Sentence]] = [] + tars_sentences: list[Sentence] = [] + all_labels_to_sentence: list[dict[str, Sentence]] = [] for sentence in batch: # always remove tags first sentence.remove_labels(label_name) - labels_to_sentence: Dict[str, Sentence] = {} + labels_to_sentence: dict[str, Sentence] = {} for label in all_labels: tars_sentence = self._get_tars_formatted_sentence(label, sentence) tars_sentences.append(tars_sentence) @@ -972,7 +972,7 @@ def predict( return None @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TARSClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TARSClassifier": from typing import cast return cast("TARSClassifier", super().load(model_path=model_path)) diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py index 1b330a0da3..7f4e00d2c4 100644 --- a/flair/models/text_classification_model.py +++ 
b/flair/models/text_classification_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Union import torch @@ -56,7 +56,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Sentence) -> torc embedding_names = self.embeddings.get_names() return prediction_data_point.get_embedding(embedding_names) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Sentence]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Sentence]: return [sentence] def _get_state_dict(self): @@ -133,7 +133,7 @@ def label_type(self): return self._label_type @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextClassifier": from typing import cast return cast("TextClassifier", super().load(model_path=model_path)) diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index 894ce3087e..d1ad98d4e0 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -1,7 +1,7 @@ import logging import typing from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -43,7 +43,7 @@ def __init__( def label_type(self): return self.label_name - def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[torch.Tensor]: + def _prepare_tensors(self, sentences: list[Sentence]) -> tuple[torch.Tensor]: self.document_embeddings.embed(sentences) embedding_names = self.document_embeddings.get_names() text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences] @@ -55,14 +55,14 @@ def forward(self, *args: torch.Tensor) -> torch.Tensor: label_scores = self.decoder(text_embedding_tensor) return label_scores - def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]: labels = self._labels_to_tensor(sentences) text_embedding_tensor = self._prepare_tensors(sentences) scores = self.forward(*text_embedding_tensor) return self.loss_function(scores.squeeze(1), labels), len(sentences) - def _labels_to_tensor(self, sentences: List[Sentence]): + def _labels_to_tensor(self, sentences: list[Sentence]): indices = [ torch.tensor([float(label.value) for label in sentence.get_labels(self.label_name)], dtype=torch.float) for sentence in sentences @@ -74,12 +74,12 @@ def _labels_to_tensor(self, sentences: List[Sentence]): def predict( self, - sentences: Union[Sentence, List[Sentence]], + sentences: Union[Sentence, list[Sentence]], mini_batch_size: int = 32, verbose: bool = False, label_name: Optional[str] = None, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> List[Sentence]: + ) -> list[Sentence]: if label_name is None: label_name = self.label_name if self.label_name is not None else "label" @@ -123,7 +123,7 @@ def predict( return sentences - def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_labels_and_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, torch.Tensor]: labels = self._labels_to_tensor(sentences) text_embedding_tensor = self._prepare_tensors(sentences) scores = self.forward(*text_embedding_tensor) @@ -132,13 +132,13 @@ def forward_labels_and_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tens def 
evaluate( self, - data_points: Union[List[Sentence], Dataset], + data_points: Union[list[Sentence], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -154,7 +154,7 @@ def evaluate( metric = MetricRegression("Evaluation") - lines: List[str] = [] + lines: list[str] = [] total_count = 0 for batch in data_loader: if isinstance(batch, Sentence): @@ -227,21 +227,21 @@ def _init_model_with_state_dict(cls, state, **kwargs): ) @staticmethod - def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + def _filter_empty_sentences(sentences: list[Sentence]) -> list[Sentence]: filtered_sentences = [sentence for sentence in sentences if sentence.tokens] if len(sentences) != len(filtered_sentences): log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.") return filtered_sentences @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextRegressor": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TextRegressor": from typing import cast return cast("TextRegressor", super().load(model_path=model_path)) def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence] yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/triple_classification_model.py b/flair/models/triple_classification_model.py index 1c1337f9b0..9f1a57a23e 100644 --- a/flair/models/triple_classification_model.py +++ b/flair/models/triple_classification_model.py @@ -1,5 +1,4 @@ import typing -from typing import List import torch @@ -69,7 +68,7 @@ def __init__( def label_type(self): return self._label_type - def _get_data_points_from_sentence(self, sentence: TextTriple) -> List[TextTriple]: + def _get_data_points_from_sentence(self, sentence: TextTriple) -> list[TextTriple]: return [sentence] def _get_embedding_for_data_point(self, prediction_data_point: TextTriple) -> torch.Tensor: @@ -121,7 +120,7 @@ def _init_model_with_state_dict(cls, state, **kwargs): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] diff --git a/flair/models/word_tagger_model.py b/flair/models/word_tagger_model.py index 2d32a54b06..5040a63728 100644 --- a/flair/models/word_tagger_model.py +++ b/flair/models/word_tagger_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Union import torch from deprecated.sphinx import deprecated @@ -99,7 +99,7 @@ def _get_embedding_for_data_point(self, prediction_data_point: Token) -> torch.T names = self.embeddings.get_names() return 
prediction_data_point.get_embedding(names) - def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Token]: + def _get_data_points_from_sentence(self, sentence: Sentence) -> list[Token]: # special handling during training if this is a span prediction problem if self.training and self.span_prediction_problem: for token in sentence.tokens: @@ -125,7 +125,7 @@ def _post_process_batch_after_prediction(self, batch, label_name): for sentence in batch: # internal variables previous_tag = "O-" - current_span: List[Token] = [] + current_span: list[Token] = [] for token in sentence: bioes_tag = token.get_label(label_name).value @@ -222,7 +222,7 @@ def _print_predictions(self, batch, gold_label_type): return lines @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TokenClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "TokenClassifier": from typing import cast return cast("TokenClassifier", super().load(model_path=model_path)) diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py index 65f802148a..48cdbf39b0 100644 --- a/flair/nn/decoder.py +++ b/flair/nn/decoder.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional import torch @@ -151,11 +151,11 @@ class LabelVerbalizerDecoder(torch.nn.Module): def __init__(self, label_embedding: Embeddings, label_dictionary: Dictionary): super().__init__() self.label_embedding = label_embedding - self.verbalized_labels: List[Sentence] = self.verbalize_labels(label_dictionary) + self.verbalized_labels: list[Sentence] = self.verbalize_labels(label_dictionary) self.to(flair.device) @staticmethod - def verbalize_labels(label_dictionary: Dictionary) -> List[Sentence]: + def verbalize_labels(label_dictionary: Dictionary) -> list[Sentence]: """Takes a label dictionary and returns a list of sentences with verbalized labels. Args: diff --git a/flair/nn/model.py b/flair/nn/model.py index eeb5b7c84a..bf13baf2f1 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import torch.nn from torch import Tensor @@ -31,7 +31,7 @@ class Model(torch.nn.Module, typing.Generic[DT], ABC): Every new type of model must implement these methods. """ - model_card: Optional[Dict[str, Any]] = None + model_card: Optional[dict[str, Any]] = None @property @abstractmethod @@ -40,7 +40,7 @@ def label_type(self) -> str: raise NotImplementedError @abstractmethod - def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, data_points: list[DT]) -> tuple[torch.Tensor, int]: """Performs a forward pass and returns a loss tensor for backpropagation. Implement this to enable training. 
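The surrounding hunks follow the migration used throughout this patch: with Python 3.8 support dropped, `typing.List`, `Dict`, `Tuple` and `Set` can be replaced by the builtin generics standardised in PEP 585, which Python 3.9 accepts in annotations at runtime. A minimal sketch of the before/after pattern, not part of the patch itself; the helper below is hypothetical and not part of Flair:

```python
# Python 3.8 style (what this patch removes):
#     from typing import Dict, List, Tuple
#     def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: ...
#
# Python 3.9+ style (what this patch switches to): builtin generics per PEP 585,
# so list/dict/tuple/set need no typing import.
def count_tokens(sentences: list[str]) -> tuple[dict[str, int], int]:
    """Hypothetical helper, annotated with builtin generics available on 3.9+."""
    counts: dict[str, int] = {}
    for sentence in sentences:
        for token in sentence.split():
            counts[token] = counts.get(token, 0) + 1
    return counts, sum(counts.values())
```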
@@ -50,13 +50,13 @@ def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: @abstractmethod def evaluate( self, - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -84,7 +84,7 @@ def evaluate( exclude_labels = exclude_labels if exclude_labels is not None else [] raise NotImplementedError - def _get_state_dict(self) -> Dict: + def _get_state_dict(self) -> dict: """Returns the state dictionary for this model.""" # Always include the name of the Model class for which the state dict holds state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__} @@ -92,7 +92,7 @@ def _get_state_dict(self) -> Dict: return state_dict @classmethod - def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): """Initialize the model from a state dictionary.""" if "embeddings" in kwargs: embeddings = kwargs.pop("embeddings") @@ -128,7 +128,7 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: torch.save(model_state, str(model_file), pickle_protocol=4) @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": """Loads the model from the given file. 
Args: @@ -238,7 +238,7 @@ class ReduceTransformerVocabMixin(ABC): @abstractmethod def get_used_tokens( self, corpus: Corpus, context_lenth: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: pass @@ -251,13 +251,13 @@ class Classifier(Model[DT], typing.Generic[DT], ReduceTransformerVocabMixin, ABC def evaluate( self, - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), - exclude_labels: Optional[List[str]] = None, + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, **kwargs, @@ -281,10 +281,10 @@ def evaluate( average_over = 0 # variables for printing - lines: List[str] = [] + lines: list[str] = [] # variables for computing scores - all_spans: Set[str] = set() + all_spans: set[str] = set() all_true_values = {} all_predicted_values = {} @@ -476,7 +476,7 @@ def evaluate( ) # Create and populate score object for logging with all evaluation values, plus the loss - scores: Dict[Union[Tuple[str, ...], str], Any] = {} + scores: dict[Union[tuple[str, ...], str], Any] = {} for avg_type in ("micro avg", "macro avg"): for metric_type in ("f1-score", "precision", "recall"): @@ -514,7 +514,7 @@ def evaluate( @abstractmethod def predict( self, - sentences: Union[List[DT], DT], + sentences: Union[list[DT], DT], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -537,7 +537,7 @@ def predict( """ raise NotImplementedError - def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str]: + def _print_predictions(self, batch: list[DT], gold_label_type: str) -> list[str]: lines = [] for datapoint in batch: # check if there is a label mismatch @@ -557,14 +557,14 @@ def _print_predictions(self, batch: List[DT], gold_label_type: str) -> List[str] def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> typing.Iterable[list[str]]: for sentence in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence] yield [t.text for t in sentence.left_context(context_length, respect_document_boundaries)] yield [t.text for t in sentence.right_context(context_length, respect_document_boundaries)] @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Classifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Classifier": from typing import cast return cast("Classifier", super().load(model_path=model_path)) @@ -589,7 +589,7 @@ def __init__( word_dropout: float = 0.0, multi_label: bool = False, multi_label_threshold: float = 0.5, - loss_weights: Optional[Dict[str, float]] = None, + loss_weights: Optional[dict[str, float]] = None, decoder: Optional[torch.nn.Module] = None, inverse_model: bool = False, train_on_gold_pairs_only: bool = False, @@ -663,21 +663,21 @@ def _get_embedding_for_data_point(self, prediction_data_point: DT2) -> torch.Ten raise NotImplementedError @abstractmethod - def _get_data_points_from_sentence(self, sentence: DT) -> List[DT2]: + def _get_data_points_from_sentence(self, sentence: DT) -> list[DT2]: """Returns the data_points to 
which labels are added. The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects). """ raise NotImplementedError - def _get_data_points_for_batch(self, sentences: List[DT]) -> List[DT2]: + def _get_data_points_for_batch(self, sentences: list[DT]) -> list[DT2]: """Returns the data_points to which labels are added. The results should be of any type that inherits from DataPoint (Sentence, Span, Token, ... objects). """ return [data_point for sentence in sentences for data_point in self._get_data_points_from_sentence(sentence)] - def _get_label_of_datapoint(self, data_point: DT2) -> List[str]: + def _get_label_of_datapoint(self, data_point: DT2) -> list[str]: """Extracts the labels from the data points. Each data point might return a list of strings, representing multiple labels. @@ -701,7 +701,7 @@ def multi_label_threshold(self, x): # setter method else: self._multi_label_threshold = {"default": x} - def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor: + def _prepare_label_tensor(self, prediction_data_points: list[DT2]) -> torch.Tensor: labels = [self._get_label_of_datapoint(dp) for dp in prediction_data_points] if self.multi_label: return torch.tensor( @@ -726,7 +726,7 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tens device=flair.device, ) - def _encode_data_points(self, sentences: List[DT], data_points: List[DT2]) -> Tensor: + def _encode_data_points(self, sentences: list[DT], data_points: list[DT2]) -> Tensor: # embed sentences if self.should_embed_sentence: self.embeddings.embed(sentences) @@ -747,7 +747,7 @@ def _mask_scores(self, scores: Tensor, data_points) -> Tensor: """Classes that inherit from DefaultClassifier may optionally mask scores.""" return scores - def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]: + def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # make a forward pass to produce embedded data points and labels sentences = [sentence for sentence in sentences if self._filter_data_point(sentence)] @@ -773,10 +773,10 @@ def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]: # calculate the loss return self._calculate_loss(scores, label_tensor) - def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]: + def _calculate_loss(self, scores: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, int]: return self.loss_function(scores, labels), labels.size(0) - def _sort_data(self, data_points: List[DT]) -> List[DT]: + def _sort_data(self, data_points: list[DT]) -> list[DT]: if len(data_points) == 0: return [] @@ -784,16 +784,16 @@ def _sort_data(self, data_points: List[DT]) -> List[DT]: return data_points # filter empty sentences - sentences = [sentence for sentence in typing.cast(List[Sentence], data_points) if len(sentence) > 0] + sentences = [sentence for sentence in typing.cast(list[Sentence], data_points) if len(sentence) > 0] # reverse sort all sequences by their length reordered_sentences = sorted(sentences, key=len, reverse=True) - return typing.cast(List[DT], reordered_sentences) + return typing.cast(list[DT], reordered_sentences) def predict( self, - sentences: Union[List[DT], DT], + sentences: Union[list[DT], DT], mini_batch_size: int = 32, return_probabilities_for_all_classes: bool = False, verbose: bool = False, @@ -824,7 +824,7 @@ def predict( sentences = [sentences] if isinstance(sentences[0], Sentence): - 
Sentence.set_context_for_sentences(typing.cast(List[Sentence], sentences)) + Sentence.set_context_for_sentences(typing.cast(list[Sentence], sentences)) reordered_sentences = self._sort_data(sentences) @@ -832,7 +832,7 @@ def predict( return sentences if len(reordered_sentences) > mini_batch_size: - batches: Union[DataLoader, List[List[DT]]] = DataLoader( + batches: Union[DataLoader, list[list[DT]]] = DataLoader( dataset=FlairDatapointDataset(reordered_sentences), batch_size=mini_batch_size, ) @@ -990,7 +990,7 @@ def _get_state_dict(self): return state @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "DefaultClassifier": + def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "DefaultClassifier": from typing import cast return cast("DefaultClassifier", super().load(model_path=model_path)) diff --git a/flair/nn/multitask.py b/flair/nn/multitask.py index 6fa2f20c02..42c5665141 100644 --- a/flair/nn/multitask.py +++ b/flair/nn/multitask.py @@ -1,4 +1,5 @@ -from typing import Iterable, Tuple, Union +from collections.abc import Iterable +from typing import Union from flair.data import Corpus, MultiCorpus from flair.models import MultitaskModel @@ -6,18 +7,18 @@ def make_multitask_model_and_corpus( - mapping: Iterable[Union[Tuple[Classifier, Corpus], Tuple[Classifier, Corpus, float]]] -) -> Tuple[Model, Corpus]: + mapping: Iterable[Union[tuple[Classifier, Corpus], tuple[Classifier, Corpus, float]]] +) -> tuple[Model, Corpus]: models = [] corpora = [] loss_factors = [] ids = [] - for task_id, map in enumerate(mapping): - models.append(map[0]) - corpora.append(map[1]) - if len(map) == 3: - loss_factors.append(map[2]) + for task_id, _map in enumerate(mapping): + models.append(_map[0]) + corpora.append(_map[1]) + if len(_map) == 3: + loss_factors.append(_map[2]) else: loss_factors.append(1.0) diff --git a/flair/samplers.py b/flair/samplers.py index 135dfb3310..53ad40c4c5 100644 --- a/flair/samplers.py +++ b/flair/samplers.py @@ -1,7 +1,6 @@ import logging import random from collections import defaultdict -from typing import Dict import torch from torch.utils.data.sampler import Sampler @@ -36,7 +35,7 @@ def set_dataset(self, data_source): self.indices = list(range(len(data_source))) # first determine the distribution of classes in the dataset - label_count: Dict[str, int] = defaultdict(int) + label_count: dict[str, int] = defaultdict(int) for sentence in data_source: for label in sentence.labels: label_count[label.value] += 1 diff --git a/flair/splitter.py b/flair/splitter.py index 9f7e502c87..2b6c90cd7f 100644 --- a/flair/splitter.py +++ b/flair/splitter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from segtok.segmenter import split_multi @@ -25,7 +25,7 @@ class SentenceSplitter(ABC): the sentence splitter's configuration. 
""" - def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Sentence]: + def split(self, text: str, link_sentences: Optional[bool] = True) -> list[Sentence]: sentences = self._perform_split(text) if not link_sentences: return sentences @@ -34,7 +34,7 @@ def split(self, text: str, link_sentences: Optional[bool] = True) -> List[Senten return sentences @abstractmethod - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: raise NotImplementedError @property @@ -62,11 +62,11 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None: super().__init__() self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: - plain_sentences: List[str] = split_multi(text) + def _perform_split(self, text: str) -> list[Sentence]: + plain_sentences: list[str] = split_multi(text) sentence_offset = 0 - sentences: List[Sentence] = [] + sentences: list[Sentence] = [] for sentence in plain_sentences: try: sentence_offset = text.index(sentence, sentence_offset) @@ -133,7 +133,7 @@ def __init__(self, model: Union[Any, str], tokenizer: Optional[Tokenizer] = None else: self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: document = self.model(text) sentences = [ @@ -192,7 +192,7 @@ def __init__(self, tag: str, tokenizer: Tokenizer = SegtokTokenizer()) -> None: self._tokenizer = tokenizer self.tag = tag - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: plain_sentences = text.split(self.tag) sentences = [] @@ -252,7 +252,7 @@ def __init__(self, tokenizer: Tokenizer = SegtokTokenizer()) -> None: super().__init__() self._tokenizer = tokenizer - def _perform_split(self, text: str) -> List[Sentence]: + def _perform_split(self, text: str) -> list[Sentence]: return [Sentence(text=text, use_tokenizer=self._tokenizer, start_position=0)] @property diff --git a/flair/tokenization.py b/flair/tokenization.py index 185e944d3b..b377c419e6 100644 --- a/flair/tokenization.py +++ b/flair/tokenization.py @@ -1,7 +1,7 @@ import logging import sys from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable from segtok.segmenter import split_single from segtok.tokenizer import split_contractions, word_tokenizer @@ -20,7 +20,7 @@ class Tokenizer(ABC): """ @abstractmethod - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: raise NotImplementedError @property @@ -57,11 +57,11 @@ def __init__(self, model) -> None: "spacy model or the name of the model to load." 
) - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: from spacy.tokens.doc import Doc doc: Doc = self.model.make_doc(text) - words: List[str] = [] + words: list[str] = [] for word in doc: if len(word.text.strip()) == 0: continue @@ -82,12 +82,12 @@ class SegtokTokenizer(Tokenizer): def __init__(self) -> None: super().__init__() - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return SegtokTokenizer.run_tokenize(text) @staticmethod - def run_tokenize(text: str) -> List[str]: - words: List[str] = [] + def run_tokenize(text: str) -> list[str]: + words: list[str] = [] sentences = split_single(text) for sentence in sentences: @@ -105,12 +105,12 @@ class SpaceTokenizer(Tokenizer): def __init__(self) -> None: super().__init__() - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return SpaceTokenizer.run_tokenize(text) @staticmethod - def run_tokenize(text: str) -> List[str]: - tokens: List[str] = [] + def run_tokenize(text: str) -> list[str]: + tokens: list[str] = [] word = "" index = -1 for index, char in enumerate(text): @@ -166,8 +166,8 @@ def __init__(self, tokenizer: str, sudachi_mode: str = "A") -> None: self.sentence_tokenizer = konoha.SentenceTokenizer() self.word_tokenizer = konoha.WordTokenizer(tokenizer, mode=sudachi_mode) - def tokenize(self, text: str) -> List[str]: - words: List[str] = [] + def tokenize(self, text: str) -> list[str]: + words: list[str] = [] sentences = self.sentence_tokenizer.tokenize(text) for sentence in sentences: @@ -184,11 +184,11 @@ def name(self) -> str: class TokenizerWrapper(Tokenizer): """Helper class to wrap tokenizer functions to the class-based tokenizer interface.""" - def __init__(self, tokenizer_func: Callable[[str], List[str]]) -> None: + def __init__(self, tokenizer_func: Callable[[str], list[str]]) -> None: super().__init__() self.tokenizer_func = tokenizer_func - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: return self.tokenizer_func(text) @property @@ -225,7 +225,7 @@ def __init__(self) -> None: " Note that the scispacy version and the version of the model must match to work properly!" ) - def combined_rule_prefixes() -> List[str]: + def combined_rule_prefixes() -> list[str]: """Helper function that returns the prefix pattern for the tokenizer. It is a helper function to accommodate spacy tests that only test prefixes. 
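The `TokenizerWrapper` and `SpaceTokenizer` hunks above only change annotations, not behaviour. A short usage sketch of that class-based tokenizer interface as it reads after the patch; `my_split` is a hypothetical function used purely for illustration:

```python
from flair.tokenization import SpaceTokenizer, TokenizerWrapper

def my_split(text: str) -> list[str]:
    # hypothetical tokenizer function: whitespace split, dropping empty pieces
    return [piece for piece in text.split() if piece]

wrapped = TokenizerWrapper(my_split)   # wraps a Callable[[str], list[str]]
builtin = SpaceTokenizer()

tokens = wrapped.tokenize("Flair now requires Python 3.9")
assert tokens == builtin.tokenize("Flair now requires Python 3.9")
```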
@@ -270,9 +270,9 @@ def combined_rule_prefixes() -> List[str]: self.model.tokenizer.prefix_search = prefix_re.search self.model.tokenizer.infix_finditer = infix_re.finditer - def tokenize(self, text: str) -> List[str]: + def tokenize(self, text: str) -> list[str]: sentence = self.model(text) - words: List[str] = [] + words: list[str] = [] for word in sentence: words.append(word.text) return words diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index eb374ed75d..341cead776 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -3,8 +3,9 @@ import math import random import time +from collections.abc import Iterable from pathlib import Path -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Optional, Union import torch from torch import cuda @@ -155,16 +156,16 @@ def __init__( self, model: LanguageModel, corpus: TextCorpus, - optimizer: Type[Optimizer] = SGD, + optimizer: type[Optimizer] = SGD, test_mode: bool = False, epoch: int = 0, split: int = 0, loss: float = 10000, - optimizer_state: Optional[Dict[str, Any]] = None, - scaler_state: Optional[Dict[str, Any]] = None, + optimizer_state: Optional[dict[str, Any]] = None, + scaler_state: Optional[dict[str, Any]] = None, ) -> None: self.model: LanguageModel = model - self.optimizer: Type[Optimizer] = optimizer + self.optimizer: type[Optimizer] = optimizer self.corpus: TextCorpus = corpus self.test_mode: bool = test_mode @@ -362,7 +363,7 @@ def train( ) with open(loss_txt, "a") as myfile: - myfile.write("%s\n" % summary) + myfile.write(f"{summary}\n") log.info(summary) log.info("-" * 89) @@ -386,7 +387,7 @@ def train( summary = f"TEST: valid loss {test_loss:5.4f} | valid ppl {math.exp(test_loss):8.4f}" with open(loss_txt, "a") as myfile: - myfile.write("%s\n" % summary) + myfile.write(f"{summary}\n") log.info(summary) log.info("-" * 89) @@ -440,7 +441,7 @@ def _repackage_hidden(h): def load_checkpoint( checkpoint_file: Union[str, Path], corpus: TextCorpus, - optimizer: Type[Optimizer] = SGD, + optimizer: type[Optimizer] = SGD, ): if isinstance(checkpoint_file, str): checkpoint_file = Path(checkpoint_file) diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py index 958d57b785..663a78d6d3 100644 --- a/flair/trainers/plugins/base.py +++ b/flair/trainers/plugins/base.py @@ -1,19 +1,14 @@ import logging from collections import defaultdict +from collections.abc import Iterator, Sequence from inspect import isclass, signature from itertools import count from queue import Queue from typing import ( Any, Callable, - Dict, - Iterator, - List, NewType, Optional, - Sequence, - Set, - Type, Union, cast, ) @@ -21,7 +16,7 @@ log = logging.getLogger("flair") -PluginArgument = Union["BasePlugin", Type["BasePlugin"]] +PluginArgument = Union["BasePlugin", type["BasePlugin"]] HookHandleId = NewType("HookHandleId", int) EventIdenifier = str @@ -34,7 +29,7 @@ class TrainingInterrupt(Exception): class Pluggable: """Dispatches events which attached plugins can react to.""" - valid_events: Optional[Set[EventIdenifier]] = None + valid_events: Optional[set[EventIdenifier]] = None def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None: """Initialize a `Pluggable`. @@ -42,11 +37,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None: Args: plugins: Plugins which should be attached to this `Pluggable`. 
""" - self._hook_handles: Dict[EventIdenifier, Dict[HookHandleId, HookHandle]] = defaultdict(dict) + self._hook_handles: dict[EventIdenifier, dict[HookHandleId, HookHandle]] = defaultdict(dict) self._hook_handle_id_counter = count() - self._plugins: List[BasePlugin] = [] + self._plugins: list[BasePlugin] = [] # This flag tracks, whether an event is currently being processed (otherwise it is added to the queue) self._processing_events = False @@ -181,7 +176,7 @@ class BasePlugin: def __init__(self) -> None: """Initialize the base plugin.""" - self._hook_handles: List[HookHandle] = [] + self._hook_handles: list[HookHandle] = [] self._pluggable: Optional[Pluggable] = None def attach_to(self, pluggable: Pluggable): @@ -260,7 +255,7 @@ def pluggable(self) -> Optional[Pluggable]: def __str__(self) -> str: return self.__class__.__name__ - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {"__cls__": f"{self.__module__}.{self.__class__.__name__}"} diff --git a/flair/trainers/plugins/functional/anneal_on_plateau.py b/flair/trainers/plugins/functional/anneal_on_plateau.py index d62b21fba1..ccd330bf0f 100644 --- a/flair/trainers/plugins/functional/anneal_on_plateau.py +++ b/flair/trainers/plugins/functional/anneal_on_plateau.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin, TrainingInterrupt from flair.trainers.plugins.metric_records import MetricRecord @@ -108,7 +108,7 @@ def __str__(self) -> str: f"min_learning_rate: '{self.min_learning_rate}'" ) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index 75ecb9bd98..1936177835 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -29,7 +29,7 @@ def after_training_epoch(self, epoch, **kw): model_name = "model_epoch_" + str(epoch) + ".pt" self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py index 1000be6dd7..2258844129 100644 --- a/flair/trainers/plugins/functional/linear_scheduler.py +++ b/flair/trainers/plugins/functional/linear_scheduler.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.optim import LinearSchedulerWithWarmup from flair.trainers.plugins.base import TrainerPlugin @@ -62,7 +62,7 @@ def after_training_batch(self, optimizer_was_run: bool, **kwargs): def __str__(self) -> str: return f"LinearScheduler | warmup_fraction: '{self.warmup_fraction}'" - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "warmup_fraction": self.warmup_fraction, diff --git a/flair/trainers/plugins/functional/reduce_transformer_vocab.py b/flair/trainers/plugins/functional/reduce_transformer_vocab.py index 162c667f88..86759c2fe2 100644 --- a/flair/trainers/plugins/functional/reduce_transformer_vocab.py +++ 
b/flair/trainers/plugins/functional/reduce_transformer_vocab.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import List from transformer_smaller_training_vocab import reduce_train_vocab @@ -57,7 +56,7 @@ def save_model_at_the_end(self, **kw): self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state) -def get_transformer_embeddings(model: Model) -> List[TransformerEmbeddings]: +def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]: embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None) if embeddings is None: diff --git a/flair/trainers/plugins/functional/weight_extractor.py b/flair/trainers/plugins/functional/weight_extractor.py index ef5afe081e..4ba7c07621 100644 --- a/flair/trainers/plugins/functional/weight_extractor.py +++ b/flair/trainers/plugins/functional/weight_extractor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import WeightExtractor @@ -21,7 +21,7 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw): if (iteration + 1) % modulo == 0: self.weight_extractor.extract_weights(self.model.state_dict(), iteration) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/loggers/log_file.py b/flair/trainers/plugins/loggers/log_file.py index a9b7453a09..21a8c54632 100644 --- a/flair/trainers/plugins/loggers/log_file.py +++ b/flair/trainers/plugins/loggers/log_file.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import add_file_handler @@ -21,5 +21,5 @@ def close_file_handler(self, **kw): self.log_handler.close() log.removeHandler(self.log_handler) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {**super().get_state(), "base_path": str(self.base_path)} diff --git a/flair/trainers/plugins/loggers/loss_file.py b/flair/trainers/plugins/loggers/loss_file.py index 29c42fc930..b53a23a956 100644 --- a/flair/trainers/plugins/loggers/loss_file.py +++ b/flair/trainers/plugins/loggers/loss_file.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from flair.trainers.plugins.base import TrainerPlugin from flair.trainers.plugins.metric_records import MetricName @@ -10,7 +10,7 @@ class LossFilePlugin(TrainerPlugin): """Plugin that manages the loss.tsv file output.""" def __init__( - self, base_path, epoch: int, metrics_to_collect: Optional[Dict[Union[Tuple, str], str]] = None + self, base_path, epoch: int, metrics_to_collect: Optional[dict[Union[tuple, str], str]] = None ) -> None: super().__init__() @@ -56,9 +56,9 @@ def __init__( self.headers[metric_name] = f"{prefix.upper()}_{header}" # initialize the first log line - self.current_row: Optional[Dict[MetricName, str]] = None + self.current_row: Optional[dict[MetricName, str]] = None - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "base_path": str(self.base_path), diff --git a/flair/trainers/plugins/loggers/metric_history.py b/flair/trainers/plugins/loggers/metric_history.py index 8d7c946e8d..a22cf1b0e7 100644 --- 
a/flair/trainers/plugins/loggers/metric_history.py +++ b/flair/trainers/plugins/loggers/metric_history.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, Mapping +from collections.abc import Mapping +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -17,7 +18,7 @@ class MetricHistoryPlugin(TrainerPlugin): def __init__(self, metrics_to_collect: Mapping = default_metrics_to_collect) -> None: super().__init__() - self.metric_history: Dict[str, list] = {} + self.metric_history: dict[str, list] = {} self.metrics_to_collect: Mapping = metrics_to_collect for target in self.metrics_to_collect.values(): self.metric_history[target] = [] @@ -33,7 +34,7 @@ def after_training(self, **kw): """Returns metric history.""" self.trainer.return_values.update(self.metric_history) - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "metrics_to_collect": dict(self.metrics_to_collect), diff --git a/flair/trainers/plugins/loggers/tensorboard.py b/flair/trainers/plugins/loggers/tensorboard.py index 59bba9f2e9..a7af50a521 100644 --- a/flair/trainers/plugins/loggers/tensorboard.py +++ b/flair/trainers/plugins/loggers/tensorboard.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin from flair.training_utils import log_line @@ -59,7 +59,7 @@ def _training_finally(self, **kw): assert self.writer is not None self.writer.close() - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "log_dir": str(self.log_dir) if self.log_dir is not None else None, diff --git a/flair/trainers/plugins/loggers/wandb.py b/flair/trainers/plugins/loggers/wandb.py index 8608fcdbd9..0f8dc89f73 100644 --- a/flair/trainers/plugins/loggers/wandb.py +++ b/flair/trainers/plugins/loggers/wandb.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict +from typing import Any from flair.trainers.plugins.base import TrainerPlugin @@ -72,7 +72,7 @@ def metric_recorded(self, record): def _training_finally(self, **kw): self.writer.close() - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { **super().get_state(), "emit_alerts": self.emit_alerts, diff --git a/flair/trainers/plugins/metric_records.py b/flair/trainers/plugins/metric_records.py index 034c021854..548b54fccd 100644 --- a/flair/trainers/plugins/metric_records.py +++ b/flair/trainers/plugins/metric_records.py @@ -1,14 +1,15 @@ import time +from collections.abc import Iterable, Iterator from dataclasses import dataclass from enum import Enum -from typing import Any, Iterable, Iterator, Optional, Tuple, Union +from typing import Any, Optional, Union RecordType = Enum("RecordType", ["scalar", "image", "histogram", "string", "scalar_list"]) class MetricName: def __init__(self, name) -> None: - self.parts: Tuple[str, ...] + self.parts: tuple[str, ...] 
if isinstance(name, str): self.parts = tuple(name.split("/")) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 03e6edc083..2f32b54c01 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -7,7 +7,7 @@ import warnings from inspect import signature from pathlib import Path -from typing import List, Optional, Tuple, Type, Union +from typing import Optional, Union import torch from torch.optim.sgd import SGD @@ -128,7 +128,7 @@ def train( base_path, anneal_factor: float = 0.5, patience: int = 3, - min_learning_rate: Union[float, List[float]] = 0.0001, + min_learning_rate: Union[float, list[float]] = 0.0001, initial_extra_patience: int = 0, anneal_with_restarts: bool = False, learning_rate: float = 0.1, @@ -137,17 +137,17 @@ def train( eval_batch_size: int = 64, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 100, - optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, + optimizer: type[torch.optim.Optimizer] = torch.optim.SGD, train_with_dev: bool = False, train_with_test: bool = False, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = False, gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling sampler=None, shuffle: bool = True, @@ -164,7 +164,7 @@ def train( create_loss_file: bool = True, write_weights: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, **kwargs, ): @@ -211,17 +211,17 @@ def fine_tune( eval_batch_size: int = 16, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 10, - optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, + optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW, train_with_dev: bool = False, train_with_test: bool = False, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = True, gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling sampler=None, shuffle: bool = True, @@ -240,7 +240,7 @@ def fine_tune( # amp use_amp: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, **kwargs, ): @@ -304,20 +304,20 @@ def train_custom( eval_batch_size: int = 64, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 100, - optimizer: Type[torch.optim.Optimizer] = SGD, + optimizer: type[torch.optim.Optimizer] = SGD, train_with_dev: bool = False, train_with_test: bool = False, max_grad_norm: Optional[float] = 5.0, reduce_transformer_vocab: bool = False, # evaluation and monitoring - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), monitor_test: bool = False, monitor_train_sample: float = 0.0, use_final_model_for_eval: bool = False, 
gold_label_dictionary_for_eval: Optional[Dictionary] = None, - exclude_labels: Optional[List[str]] = None, + exclude_labels: Optional[list[str]] = None, # sampling and shuffling - sampler: Optional[FlairSampler] = None, + sampler: Optional[Union[FlairSampler, type[FlairSampler]]] = None, shuffle: bool = True, shuffle_first_epoch: bool = True, # evaluation and monitoring @@ -334,7 +334,7 @@ def train_custom( # amp use_amp: bool = False, # plugins - plugins: Optional[List[TrainerPlugin]] = None, + plugins: Optional[list[TrainerPlugin]] = None, **kwargs, ) -> dict: """Trains any class that implements the flair.nn.Model interface. @@ -475,7 +475,7 @@ def train_custom( # initialize sampler if provided if sampler is not None: # init with default values if only class is provided - if inspect.isclass(sampler): + if isinstance(sampler, type): sampler = sampler() # set dataset to sample from sampler.set_dataset(train_data) diff --git a/flair/training_utils.py b/flair/training_utils.py index 0b4ef91cbf..9b38ec1ddb 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -5,7 +5,7 @@ from functools import reduce from math import inf from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union from scipy.stats import pearsonr, spearmanr from sklearn.metrics import mean_absolute_error, mean_squared_error @@ -25,7 +25,7 @@ def __init__( main_score: float, detailed_results: str, classification_report: Optional[dict] = None, - scores: Optional[Dict] = None, + scores: Optional[dict] = None, ) -> None: classification_report = classification_report if classification_report is not None else {} assert scores is not None and "loss" in scores, "No loss provided." @@ -47,8 +47,8 @@ class MetricRegression: def __init__(self, name) -> None: self.name = name - self.true: List[float] = [] - self.pred: List[float] = [] + self.true: list[float] = [] + self.pred: list[float] = [] def mean_squared_error(self): return mean_squared_error(self.true, self.pred) @@ -98,7 +98,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> if isinstance(directory, str): directory = Path(directory) self.weights_file = init_output_file(directory, "weights.txt") - self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list)) + self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights def extract_weights(self, state_dict, iteration): @@ -338,7 +338,7 @@ def init_output_file(base_path: Union[str, Path], file_name: str) -> Path: return file -def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionary) -> List[List[int]]: +def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionary) -> list[list[int]]: """Convert list of labels to a one hot list. 
Args: @@ -365,9 +365,9 @@ def add_file_handler(log, output_file): def store_embeddings( - data_points: Union[List[DT], Dataset], + data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[List[str]] = None, + dynamic_embeddings: Optional[list[str]] = None, ): if isinstance(data_points, Dataset): data_points = list(_iter_dataset(data_points)) @@ -391,7 +391,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List]: +def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: diff --git a/flair/visual/ner_html.py b/flair/visual/ner_html.py index c71e108379..5b691a9e60 100644 --- a/flair/visual/ner_html.py +++ b/flair/visual/ner_html.py @@ -1,5 +1,5 @@ import html -from typing import List, Union +from typing import Union from flair.data import Sentence @@ -41,7 +41,7 @@ def split_to_spans(s: Sentence, label_name="ner"): def render_ner_html( - sentences: Union[List[Sentence], Sentence], + sentences: Union[list[Sentence], Sentence], title: str = "Flair", colors={ "PER": "#F7FF53", diff --git a/flair/visual/training_curves.py b/flair/visual/training_curves.py index 1fd856b669..32947c3348 100644 --- a/flair/visual/training_curves.py +++ b/flair/visual/training_curves.py @@ -3,7 +3,7 @@ import math from collections import defaultdict from pathlib import Path -from typing import Dict, List, Union +from typing import Union import matplotlib.pyplot as plt import numpy as np @@ -27,7 +27,7 @@ class Plotter: def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") -> dict: file_name = Path(file_name) - training_curves: Dict[str, Dict[str, List[float]]] = { + training_curves: dict[str, dict[str, list[float]]] = { "train": {"loss": [], "score": []}, "test": {"loss": [], "score": []}, "dev": {"loss": [], "score": []}, @@ -70,7 +70,7 @@ def _extract_weight_data(file_name: Union[str, Path]) -> dict: if isinstance(file_name, str): file_name = Path(file_name) - weights: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list)) + weights: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) with open(file_name) as f: tsvin = csv.reader(f, delimiter="\t") @@ -151,7 +151,7 @@ def plot_weights(self, file_name: Union[str, Path]): log.info(f"Weights plots are saved in {path}") # to let user know the path of the save plots plt.close(fig) - def plot_training_curves(self, file_name: Union[str, Path], plot_values: List[str] = ["loss", "F1"]): + def plot_training_curves(self, file_name: Union[str, Path], plot_values: list[str] = ["loss", "F1"]): file_name = Path(file_name) fig = plt.figure(figsize=(15, 10)) diff --git a/flair/visual/tree_printer.py b/flair/visual/tree_printer.py index fc461d9f81..9753a37a09 100644 --- a/flair/visual/tree_printer.py +++ b/flair/visual/tree_printer.py @@ -1,5 +1,3 @@ -from typing import List - from pptree import print_tree from flair.data import Sentence, Token @@ -9,7 +7,7 @@ class NodeToken: def __init__(self, token: Token, tag_type: str) -> None: self.token: Token = token self.tag_type: str = tag_type - self.children: List[NodeToken] = [] + self.children: list[NodeToken] = [] def set_haed(self, parent): parent.children.append(self) @@ -19,7 +17,7 @@ def __str__(self) -> str: def tree_printer(sentence: Sentence, tag_type: str): - tree: List[NodeToken] = [NodeToken(token, tag_type) for token in sentence] 
+ tree: list[NodeToken] = [NodeToken(token, tag_type) for token in sentence] for x in tree: if x.token.head_id != 0: head_token = x.token.get_head() diff --git a/pyproject.toml b/pyproject.toml index 78d1692a09..9f4c5a7535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 120 -target-version = ['py37'] +target-version = ['py39'] exclude = ''' ( /( @@ -49,7 +49,7 @@ ignore_errors = true [tool.ruff] line-length = 120 -target-version = "py38" +target-version = "py39" [tool.ruff.lint] #select = ["ALL"] # Uncommit to autofix all the things diff --git a/requirements-dev.txt b/requirements-dev.txt index 61d45acf8c..3b8fbde79c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ pytest-black-ng==0.4.* pytest-github-actions-annotate-failures>=0.1.8 pytest-mypy>=0.10.3 pytest-ruff==0.3.* -ruff==0.3.* +ruff==0.7.* types-dataclasses>=0.6.6 types-Deprecated>=1.2.9.2 types-requests>=2.28.11.17 diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md index 69f1a5fbf2..c6bbe72a1c 100644 --- a/resources/docs/EXPERIMENTS.md +++ b/resources/docs/EXPERIMENTS.md @@ -55,7 +55,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ # GloVe embeddings WordEmbeddings('glove'), @@ -124,7 +124,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('de'), PooledFlairEmbeddings('german-forward'), PooledFlairEmbeddings('german-backward'), @@ -225,7 +225,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('crawl'), WordEmbeddings('twitter'), FlairEmbeddings('news-forward'), @@ -292,7 +292,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('crawl'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), @@ -361,7 +361,7 @@ tag_type = 'pos' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('extvec'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), @@ -416,7 +416,7 @@ tag_type = 'np' tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) # initialize embeddings -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('extvec'), FlairEmbeddings('news-forward'), FlairEmbeddings('news-backward'), diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md index 6f2c1474b3..032b4fe075 100644 --- a/resources/docs/HUNFLAIR2.md +++ b/resources/docs/HUNFLAIR2.md @@ -14,7 +14,7 @@ NER tools on unseen corpora. ## Quick Start #### Requirements and Installation -*HunFlair2* is based on Flair 0.14+ and Python 3.8+. If you do not have Python 3.8, install it first. +*HunFlair2* is based on Flair 0.14+ and Python 3.9+. If you do not have Python 3.9, install it first. 
Then, in your favorite virtual environment, simply do: ``` pip install flair diff --git a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md index cd839acc15..382600d5be 100644 --- a/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md +++ b/resources/docs/KOR_docs/TUTORIAL_7_TRAINING_A_MODEL.md @@ -313,7 +313,7 @@ label_type = 'ner' # 3. 말뭉치에서 레이블 사전 만들기 label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False) # 4. 임베딩 초기화하기 -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('glove') ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) diff --git a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md index e821dc2d17..ec3affe947 100644 --- a/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md +++ b/resources/docs/KOR_docs/TUTORIAL_8_MODEL_OPTIMIZATION.md @@ -95,7 +95,7 @@ tag_type = 'ner' tag_dictionary = corpus.make_label_dictionary(label_type=tag_type, add_unk=False) print(tag_dictionary.idx2item) # 4. 임베딩 초기화하기 -embedding_types: List[TokenEmbeddings] = [ +embedding_types: list[TokenEmbeddings] = [ WordEmbeddings('glove'), ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) diff --git a/setup.py b/setup.py index 3c1cc06018..0573896c19 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,5 @@ "word-embeddings": ["gensim>=4.2.0", "bpemb>=0.3.5"], }, include_package_data=True, - python_requires=">=3.8", + python_requires=">=3.9", ) diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py index 554ef32777..c1a0b1a791 100644 --- a/tests/embedding_test_utils.py +++ b/tests/embedding_test_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional import pytest import torch @@ -9,15 +9,15 @@ class BaseEmbeddingsTest: - embedding_cls: Type[Embeddings[Sentence]] + embedding_cls: type[Embeddings[Sentence]] is_token_embedding: bool is_document_embedding: bool - default_args: Dict[str, Any] - valid_args: List[Dict[str, Any]] = [] - invalid_args: List[Dict[str, Any]] = [] - invalid_names: List[str] = [] + default_args: dict[str, Any] + valid_args: list[dict[str, Any]] = [] + invalid_args: list[dict[str, Any]] = [] + invalid_names: list[str] = [] name_field: Optional[str] = None - weired_texts: List[str] = [ + weired_texts: list[str] = [ "Hybrid mesons , qq ̄ states with an admixture", "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .", "🤟 🤟 🤟 hüllo", @@ -33,7 +33,7 @@ def create_embedding_from_name(self, name: str): kwargs.pop(self.name_field) return self.embedding_cls(name, **kwargs) # type: ignore[call-arg] - def create_embedding_with_args(self, args: Dict[str, Any]): + def create_embedding_with_args(self, args: dict[str, Any]): kwargs = dict(self.default_args) for k, v in args.items(): kwargs[k] = v diff --git a/tests/embeddings/test_document_transform_word_embeddings.py b/tests/embeddings/test_document_transform_word_embeddings.py index 6a06372723..73567ffbeb 100644 --- a/tests/embeddings/test_document_transform_word_embeddings.py +++ b/tests/embeddings/test_document_transform_word_embeddings.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from flair.embeddings import ( DocumentCNNEmbeddings, @@ -19,7 +19,7 @@ class BaseDocumentsViaWordEmbeddingsTest(BaseEmbeddingsTest): is_document_embedding = True is_token_embedding = False - 
-    base_embeddings: List[TokenEmbeddings] = [word, flair_embedding]
+    base_embeddings: list[TokenEmbeddings] = [word, flair_embedding]
 
     def create_embedding_from_name(self, name: str):
         """Overwrite this method if it is more complex to load an embedding by name."""
@@ -28,7 +28,7 @@ def create_embedding_from_name(self, name: str):
         kwargs.pop(self.name_field)
         return self.embedding_cls(name, **kwargs)  # type: ignore[call-arg]
 
-    def create_embedding_with_args(self, args: Dict[str, Any]):
+    def create_embedding_with_args(self, args: dict[str, Any]):
         kwargs = dict(self.default_args)
         for k, v in args.items():
             kwargs[k] = v
@@ -63,4 +63,4 @@ class TestDocumentCNNEmbeddings(BaseDocumentsViaWordEmbeddingsTest):
 class TestDocumentLMEmbeddings(BaseDocumentsViaWordEmbeddingsTest):
     embedding_cls = DocumentLMEmbeddings
     base_embeddings = [flair_embedding, flair_embedding_back]
-    default_args: Dict[str, Any] = {}
+    default_args: dict[str, Any] = {}
diff --git a/tests/embeddings/test_word_embeddings.py b/tests/embeddings/test_word_embeddings.py
index 34d0b3b9f7..87f56fec4f 100644
--- a/tests/embeddings/test_word_embeddings.py
+++ b/tests/embeddings/test_word_embeddings.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 
 from flair.embeddings import MuseCrosslingualEmbeddings, NILCEmbeddings, WordEmbeddings
 from tests.embedding_test_utils import BaseEmbeddingsTest
@@ -18,7 +18,7 @@ class TestMuseCrosslingualEmbeddings(BaseEmbeddingsTest):
     embedding_cls = MuseCrosslingualEmbeddings
     is_token_embedding = True
     is_document_embedding = False
-    default_args: Dict[str, Any] = {}
+    default_args: dict[str, Any] = {}
 
 
 class TestNILCEmbeddings(BaseEmbeddingsTest):
diff --git a/tests/model_test_utils.py b/tests/model_test_utils.py
index 10aab0831f..b5afd81bfe 100644
--- a/tests/model_test_utils.py
+++ b/tests/model_test_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
 
 import pytest
 
@@ -11,13 +11,13 @@
 
 
 class BaseModelTest:
-    model_cls: Type[Model]
+    model_cls: type[Model]
    pretrained_model: Optional[str] = None
     empty_sentence = Sentence(" ")
     train_label_type: str
-    multiclass_prediction_labels: List[str]
-    model_args: Dict[str, Any] = {}
-    training_args: Dict[str, Any] = {}
+    multiclass_prediction_labels: list[str]
+    model_args: dict[str, Any] = {}
+    training_args: dict[str, Any] = {}
     finetune_instead_of_train: bool = False
 
     @pytest.fixture()
diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py
index c0ca34bce5..da4de52bfc 100644
--- a/tests/models/test_relation_classifier.py
+++ b/tests/models/test_relation_classifier.py
@@ -1,5 +1,5 @@
 from operator import itemgetter
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Optional
 
 import pytest
 from torch.utils.data import Dataset
@@ -20,7 +20,7 @@
 )
 from tests.model_test_utils import BaseModelTest
 
-encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = {
+encoding_strategies: dict[EncodingStrategy, list[tuple[str, str]]] = {
     EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)],
     TypedEntityMask(): [
         ("[HEAD-ORG]", "[TAIL-PER]"),
@@ -140,7 +140,7 @@ def train_test_sentence(self):
         return sentence
 
     def assert_training_example(self, predicted_training_example):
-        relations: List[Relation] = predicted_training_example.get_relations("relation")
+        relations: list[Relation] = predicted_training_example.get_relations("relation")
         assert len(relations) == 2
 
         # Intel ----founded_by---> Gordon Moore
@@ -164,7 +164,7 @@ def assert_training_example(self, predicted_training_example):
     @staticmethod
     def check_transformation_correctness(
         split: Optional[Dataset],
-        ground_truth: Set[Tuple[str, Tuple[str, ...]]],
+        ground_truth: set[tuple[str, tuple[str, ...]]],
     ) -> None:
         # Ground truth is a set of tuples of (, )
         assert split is not None
@@ -190,7 +190,7 @@ def test_transform_corpus(
         embeddings: TransformerDocumentEmbeddings,
         cross_augmentation: bool,
         encoding_strategy: EncodingStrategy,
-        encoded_entity_pairs: List[Tuple[str, str]],
+        encoded_entity_pairs: list[tuple[str, str]],
     ) -> None:
         label_dictionary = corpus.make_label_dictionary("relation")
         model: RelationClassifier = self.build_model(
@@ -200,7 +200,7 @@ def test_transform_corpus(
 
         # Check sentence masking and relation label annotation on
         # training, validation and test dataset (in this test the splits are the same)
-        ground_truth: Set[Tuple[str, Tuple[str, ...]]] = {
+        ground_truth: set[tuple[str, tuple[str, ...]]] = {
             # Entity pair permutations of: "Larry Page and Sergey Brin founded Google ."
             (f"{encoded_entity_pairs[0][1]} and Sergey Brin founded {encoded_entity_pairs[0][0]} .", ("founded_by",)),
             (f"Larry Page and {encoded_entity_pairs[1][1]} founded {encoded_entity_pairs[1][0]} .", ("founded_by",)),
diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py
index 0264b08394..c15674eb6b 100644
--- a/tests/test_datasets_biomedical.py
+++ b/tests/test_datasets_biomedical.py
@@ -2,7 +2,7 @@
 import os
 import tempfile
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional
 
 from flair.datasets.biomedical import (
     CoNLLWriter,
@@ -84,7 +84,7 @@ def test_conll_writer_one_token_multiple_entities2():
 
 def assert_conll_writer_output(
     dataset: InternalBioNerDataset,
-    expected_output: List[str],
+    expected_output: list[str],
     sentence_splitter: Optional[SentenceSplitter] = None,
 ):
     fd, outfile_path = tempfile.mkstemp()
diff --git a/tests/test_labels.py b/tests/test_labels.py
index 210a215889..099484162c 100644
--- a/tests/test_labels.py
+++ b/tests/test_labels.py
@@ -1,5 +1,3 @@
-from typing import List
-
 from flair.data import Label, Relation, Sentence, Span
 
 
@@ -14,7 +12,7 @@ def test_token_tags():
     sentence[0].add_label("pos", "pronoun")
 
     # check if there are three POS labels with correct text and values
-    labels: List[Label] = sentence.get_labels("pos")
+    labels: list[Label] = sentence.get_labels("pos")
     assert len(labels) == 3
     assert labels[0].data_point.text == "I"
     assert labels[0].value == "pronoun"
@@ -24,7 +22,7 @@ def test_token_tags():
     assert labels[2].value == "proper noun"
 
     # check if there are is one SENTIMENT label with correct text and values
-    labels: List[Label] = sentence.get_labels("sentiment")
+    labels: list[Label] = sentence.get_labels("sentiment")
     assert len(labels) == 1
     assert labels[0].data_point.text == "love"
     assert labels[0].value == "positive"
@@ -45,7 +43,7 @@ def test_token_tags():
     # remove the pos label from the last word
     sentence[2].remove_labels("pos")
     # there should be 2 POS labels left
-    labels: List[Label] = sentence.get_labels("pos")
+    labels: list[Label] = sentence.get_labels("pos")
     assert len(labels) == 2
     assert len(sentence[0].get_labels("pos")) == 1
     assert len(sentence[1].get_labels("pos")) == 1
@@ -72,7 +70,7 @@ def test_span_tags():
     sentence[7:8].add_label("ner", "City")
 
     # check if there are three labels with correct text and values
-    labels: List[Label] = sentence.get_labels("ner")
+    labels: list[Label] = sentence.get_labels("ner")
     assert len(labels) == 3
     assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
     assert labels[0].value == "Organization"
@@ -82,7 +80,7 @@ def test_span_tags():
     assert labels[2].value == "City"
 
     # check if there are two spans with correct text and values
-    spans: List[Span] = sentence.get_spans("ner")
+    spans: list[Span] = sentence.get_spans("ner")
     assert len(spans) == 2
     assert spans[0].text == "Humboldt Universität zu Berlin"
     assert len(spans[0].get_labels("ner")) == 2
@@ -92,12 +90,12 @@ def test_span_tags():
     # now delete the NER tags of "Humboldt-Universität zu Berlin"
     sentence[0:4].remove_labels("ner")
     # should be only one NER label left
-    labels: List[Label] = sentence.get_labels("ner")
+    labels: list[Label] = sentence.get_labels("ner")
     assert len(labels) == 1
     assert labels[0].data_point.text == "Berlin"
     assert labels[0].value == "City"
     # and only one NER span
-    spans: List[Span] = sentence.get_spans("ner")
+    spans: list[Span] = sentence.get_spans("ner")
     assert len(spans) == 1
     assert spans[0].text == "Berlin"
     assert spans[0].get_label("ner").value == "City"
@@ -111,7 +109,7 @@ def test_different_span_tags():
     sentence[7:8].add_label("ner", "City")
 
     # check if there are three labels with correct text and values
-    labels: List[Label] = sentence.get_labels("ner")
+    labels: list[Label] = sentence.get_labels("ner")
     assert len(labels) == 2
     assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
     assert labels[0].value == "Organization"
@@ -119,7 +117,7 @@ def test_different_span_tags():
     assert labels[1].value == "City"
 
     # check if there are two spans with correct text and values
-    spans: List[Span] = sentence.get_spans("ner")
+    spans: list[Span] = sentence.get_spans("ner")
     assert len(spans) == 2
     assert spans[0].text == "Humboldt Universität zu Berlin"
     assert spans[0].get_label("ner").value == "Organization"
@@ -131,22 +129,22 @@ def test_different_span_tags():
     # now delete the NER tags of "Humboldt-Universität zu Berlin"
     sentence[0:4].remove_labels("ner")
     # should be only one NER label left
-    labels: List[Label] = sentence.get_labels("ner")
+    labels: list[Label] = sentence.get_labels("ner")
     assert len(labels) == 1
     assert labels[0].data_point.text == "Berlin"
     assert labels[0].value == "City"
     # and only one NER span
-    spans: List[Span] = sentence.get_spans("ner")
+    spans: list[Span] = sentence.get_spans("ner")
     assert len(spans) == 1
     assert spans[0].text == "Berlin"
     assert spans[0].get_label("ner").value == "City"
     # but there is also one orgtype span and label
-    labels: List[Label] = sentence.get_labels("orgtype")
+    labels: list[Label] = sentence.get_labels("orgtype")
     assert len(labels) == 1
     assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
     assert labels[0].value == "University"
     # and only one NER span
-    spans: List[Span] = sentence.get_spans("orgtype")
+    spans: list[Span] = sentence.get_spans("orgtype")
     assert len(spans) == 1
     assert spans[0].text == "Humboldt Universität zu Berlin"
     assert spans[0].get_label("orgtype").value == "University"
@@ -154,7 +152,7 @@ def test_different_span_tags():
     # let's add the NER tag back
     sentence[0:4].add_label("ner", "Organization")
     # check if there are three labels with correct text and values
-    labels: List[Label] = sentence.get_labels("ner")
+    labels: list[Label] = sentence.get_labels("ner")
     print(labels)
     assert len(labels) == 2
     assert labels[0].data_point.text == "Humboldt Universität zu Berlin"
@@ -163,7 +161,7 @@ def test_different_span_tags():
     assert labels[1].value == "City"
 
     # check if there are two spans with correct text and values
-    spans: List[Span] = sentence.get_spans("ner")
+    spans: list[Span] = sentence.get_spans("ner")
     assert len(spans) == 2
     assert spans[0].text == "Humboldt Universität zu Berlin"
     assert spans[0].get_label("ner").value == "Organization"
@@ -194,17 +192,17 @@ def test_relation_tags():
     Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition")
 
     # there should be two relation labels
-    labels: List[Label] = sentence.get_labels("rel")
+    labels: list[Label] = sentence.get_labels("rel")
     assert len(labels) == 2
     assert labels[0].value == "located in"
     assert labels[1].value == "university of"
 
     # there should be one syntactic labels
-    labels: List[Label] = sentence.get_labels("syntactic")
+    labels: list[Label] = sentence.get_labels("syntactic")
     assert len(labels) == 1
 
     # there should be two relations, one with two and one with one label
-    relations: List[Relation] = sentence.get_relations("rel")
+    relations: list[Relation] = sentence.get_relations("rel")
     assert len(relations) == 2
     assert len(relations[0].labels) == 1
     assert len(relations[1].labels) == 2
diff --git a/tests/test_tokenize_sentence.py b/tests/test_tokenize_sentence.py
index fd049b642e..7fd03ac6ba 100644
--- a/tests/test_tokenize_sentence.py
+++ b/tests/test_tokenize_sentence.py
@@ -1,5 +1,3 @@
-from typing import List
-
 import pytest
 
 import flair
@@ -492,5 +490,5 @@ def test_line_separator_is_ignored():
     assert Sentence(with_separator).to_original_text() == Sentence(without_separator).to_original_text()
 
 
-def no_op_tokenizer(text: str) -> List[str]:
+def no_op_tokenizer(text: str) -> list[str]:
     return [text]

From a30fd78bfbcdd2eb8d203629dd7b827e700bdbd6 Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 25 Oct 2024 15:44:04 +0200
Subject: [PATCH 2/3] fix unrequred type ignore statements

---
 flair/embeddings/document.py | 2 +-
 flair/embeddings/token.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py
index 28867d889a..8f66a198ed 100644
--- a/flair/embeddings/document.py
+++ b/flair/embeddings/document.py
@@ -371,7 +371,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]):
             sentence_tensor = self.word_reprojection_map(sentence_tensor)
 
         # push through RNN
-        packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True)  # type: ignore[arg-type]
+        packed = pack_padded_sequence(sentence_tensor, lengths, enforce_sorted=False, batch_first=True)
         rnn_out, hidden = self.rnn(packed)
         outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True)
 
diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 700eaf4c45..3d95c8ee0b 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -513,7 +513,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]):
 
         character_embeddings = self.char_embedding(chars).transpose(0, 1)
 
-        packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length)  # type: ignore[arg-type]
+        packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length)
 
         lstm_out, self.hidden = self.char_rnn(packed)
 

From 783ebc7c5db43f2d0e3ca0d6473556da34a4e70e Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 25 Oct 2024 19:00:09 +0200
Subject: [PATCH 3/3] fix transformers switching attention implementation

---
 flair/embeddings/transformer.py | 5 +++++
 pyproject.toml                  | 1 +
 2 files changed, 6 insertions(+)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index d09ed33699..245d528e5d 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -1353,6 +1353,11 @@ def from_params(cls, params):
 
     def to_params(self):
         config_dict = self.model.config.to_dict()
+
+        # do not switch the attention implementation upon reload.
+        config_dict["attn_implementation"] = self.model.config._attn_implementation
+        del config_dict["_attn_implementation_autoset"]
+
         super_params = super().to_params()
 
         # those parameters are only from the super class and will be recreated in the constructor.
diff --git a/pyproject.toml b/pyproject.toml
index 9f4c5a7535..9711794abb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ filterwarnings = [
     'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.',  # transformers calls deprecated hf_hub
     "ignore:`torch.cuda.amp.GradScaler",  # GradScaler changes in torch 2.3.0 but we want to be backwards compatible.
     "ignore:`clean_up_tokenization_spaces` was not set",  # Default behavior changes in transformers v4.45, raising irrelevant FutureWarning for serialized models.
+    "ignore:1Torch was not compiled with flash attention",  # You might want to install flash attention, but you don't have to.
 ]
 markers = [
     "integration",