Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change the indexing to support arrow dataset #183

Merged
merged 2 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions baal/active/dataset/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Union, List, Optional, Any, Iterable, Sized
from typing import Union, List, Optional, Any

import numpy as np
from sklearn.utils import check_random_state
Expand Down Expand Up @@ -30,21 +30,22 @@ def __init__(
raise ValueError("last_active_steps must be > 0 or -1 when disabled.")
self.last_active_steps = last_active_steps

def get_indices_for_active_step(self):
def get_indices_for_active_step(self) -> List[int]:
"""Returns the indices required for the active step.

Returns the indices of the labelled items. Also takes into account self.last_active_step.

Returns:
Array of the selected indices for training.
List of the selected indices for training.
"""
if self.last_active_steps == -1:
min_labelled_step = 0
else:
min_labelled_step = max(0, self.current_al_step - self.last_active_steps)
indices = np.arange(len(self.labelled_map))
bool_mask = self.labelled_map > min_labelled_step
return indices[bool_mask]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to edit the docstring and add typing

# we need to work with lists since arrow dataset is not compatible with np.int types!
indices = [indx for indx, val in enumerate(self.labelled_map) if val > min_labelled_step]
return indices

def is_labelled(self, idx: int) -> bool:
"""Check if a datapoint is labelled."""
Expand Down
26 changes: 26 additions & 0 deletions tests/active/dataset/nlp_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from baal.active.dataset import ActiveLearningDataset
from baal.active.dataset.nlp_datasets import HuggingFaceDatasets
import datasets

Expand All @@ -27,6 +28,31 @@ def __getitem__(self, item):
return self.dataset[item]


class ActiveArrowDatasetTest(unittest.TestCase):
def setUp(self):
dataset = datasets.Dataset.from_dict({'sentence': [f'this is test number {i}' for i in range(10)],
'label': ['POS' if (i % 2) == 0 else 'NEG' for i in range(10)]},
features=datasets.Features({'sentence': datasets.Value('string'),
'label': datasets.ClassLabel(2, names=['NEG',
'POS'])}))
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def preprocess(example):
results = tokenizer(example['sentence'], max_length=50,
truncation=True, padding='max_length')
return results

tokenized_dataset = dataset.map(preprocess, batched=True)
self.active_dataset = ActiveLearningDataset(tokenized_dataset)

def test_dataset(self):
assert len(self.active_dataset) == 0
self.active_dataset.label_randomly(2)
assert len(self.active_dataset) == 2
print(self.active_dataset[0])
assert self.active_dataset[0]['sentence'] == 'this is test number 0'


class HuggingFaceDatasetsTest(unittest.TestCase):
def setUp(self):
dataset = MyDataset()
Expand Down