Skip to content

Commit

Permalink
move filtering to the first step on each call
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Aug 7, 2022
1 parent 2e8c31c commit fa43b08
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,14 +644,15 @@ def multi_label_threshold(self, x): # setter method
self._multi_label_threshold = {"default": x}

def get_scores_and_labels(self, batch: List[DT]) -> Tuple[torch.Tensor, List[List[str]]]:
batch = [dp for dp in batch if self._filter_data_point(dp)]
predict_data_points = self._get_prediction_data_points(batch)
labels = [self._get_label_of_datapoint(dp) for dp in predict_data_points if self._filter_data_point(dp)]
labels = [self._get_label_of_datapoint(pdp) for pdp in predict_data_points]
embedded_tensor = self._prepare_tensors(batch)
logits = self._transform_embeddings(*embedded_tensor)
return logits, labels

def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor:
labels = [self._get_label_of_datapoint(dp) for dp in prediction_data_points if self._filter_data_point(dp)]
labels = [self._get_label_of_datapoint(dp) for dp in prediction_data_points]
if self.multi_label:
return torch.tensor(
[
Expand All @@ -676,6 +677,7 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tens
def forward_loss(self, sentences: List[DT]) -> Tuple[torch.Tensor, int]:

# make a forward pass to produce embedded data points and labels
sentences = [sentence for sentence in sentences if self._filter_data_point(sentence)]
predict_data_points = self._get_prediction_data_points(sentences)
labels = self._prepare_label_tensor(predict_data_points)

Expand Down Expand Up @@ -760,6 +762,9 @@ def predict(
overall_loss = torch.zeros(1, device=flair.device)
label_count = 0
for batch in batches:

batch = [dp for dp in batch if self._filter_data_point(dp)]

# stop if all sentences are empty
if not batch:
continue
Expand Down

0 comments on commit fa43b08

Please sign in to comment.