Add documentation for criterion #297

Merged · 1 commit · May 27, 2024
3 changes: 2 additions & 1 deletion docs/api/index.md
@@ -6,12 +6,13 @@
* [baal.bayesian](./bayesian.md)
* [baal.active](./dataset_management.md)
* [baal.active.heuristics](./heuristics.md)
* [baal.active.stopping_criteria](./stopping_criteria.md)
* [baal.calibration](./calibration.md)
* [baal.utils](./utils.md)

### :material-file-tree: Compatibility

* [baal.utils.pytorch_lightning] (./compatibility/pytorch-lightning)
* [baal.utils.pytorch_lightning](./compatibility/pytorch-lightning)
* [baal.transformers_trainer_wrapper](./compatibility/huggingface)


35 changes: 35 additions & 0 deletions docs/api/stopping_criteria.md
@@ -0,0 +1,35 @@
# Stopping Criteria

Stopping criteria are used to determine when to stop your active learning experiment.

They are simple to use, and best put into practice with `ActiveExperiment`.

**Example**
```python
from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
from baal.active.dataset import ActiveLearningDataset

al_dataset: ActiveLearningDataset = ... # len(al_dataset) == 10
criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)

assert not criterion.should_stop({}, [])

# Label 50 more items: len(al_dataset) == 60, 50 of the 100-label budget used.
al_dataset.label_randomly(50)
assert not criterion.should_stop({}, [])

# Label another 50: len(al_dataset) == 110, the budget of 100 new labels is exhausted!
al_dataset.label_randomly(50)
assert criterion.should_stop({}, [])
```
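
The other criteria follow the same `should_stop(metrics, uncertainty)` interface. The sketch below is illustrative only: the constructor arguments (`avg_uncertainty_thresh`, `metric_name`, `patience`) are assumptions inferred from the class names, so check the API reference below for the exact signatures.

```python
from baal.active.stopping_criteria import (
    EarlyStoppingCriterion,
    LowAverageUncertaintyStoppingCriterion,
)

# Illustrative sketch -- argument names are assumptions, see the API reference below.
# Stop once the average uncertainty over the pool falls below a threshold.
low_uncertainty = LowAverageUncertaintyStoppingCriterion(al_dataset, avg_uncertainty_thresh=0.1)
assert not low_uncertainty.should_stop({}, [0.8, 0.7, 0.9])
assert low_uncertainty.should_stop({}, [0.05, 0.02, 0.01])

# Stop once a monitored metric has not improved for `patience` steps.
early_stopping = EarlyStoppingCriterion(al_dataset, metric_name="test_loss", patience=10)
early_stopping.should_stop({"test_loss": 0.45}, [])
```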


### API

### baal.active.stopping_criteria

::: baal.active.stopping_criteria.LabellingBudgetStoppingCriterion

::: baal.active.stopping_criteria.LowAverageUncertaintyStoppingCriterion

::: baal.active.stopping_criteria.EarlyStoppingCriterion
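
In a full experiment, a criterion is typically checked at the end of each active learning step, much like step 5 of the production notebook below. This is a minimal sketch: `train_and_evaluate` and `compute_uncertainty` are hypothetical placeholders for your own `ModelWrapper` training and heuristic code.

```python
# Hypothetical loop -- `train_and_evaluate` and `compute_uncertainty` stand in
# for your own ModelWrapper / heuristic code (see the notebook below).
criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)

for step in range(20):
    metrics = train_and_evaluate(al_dataset)            # e.g. {'test_loss': ...}
    uncertainty = compute_uncertainty(al_dataset.pool)  # one score per pool item
    if criterion.should_stop(metrics, uncertainty):
        break
    al_dataset.label_randomly(10)  # or label the most uncertain samples
```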
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -78,6 +78,7 @@ nav:
- api/calibration.md
- api/dataset_management.md
- api/heuristics.md
- api/stopping_criteria.md
- api/modelwrapper.md
- api/utils.md
- Compatibility:
104 changes: 22 additions & 82 deletions notebooks/production/baal_prod_cls.ipynb
@@ -35,15 +35,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 5174, Valid: 1725, Num. classes : 8\n"
]
}
],
"source": [
"from glob import glob\n",
"import os\n",
@@ -52,7 +43,8 @@
"classes = os.listdir('/tmp/natural_images')\n",
"train, test = train_test_split(files, random_state=1337) # Split 75% train, 25% validation\n",
"print(f\"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -79,7 +71,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active import FileDataset, ActiveLearningDataset\n",
"from torchvision import transforms\n",
@@ -101,7 +92,8 @@
"# We use -1 to specify that the data is unlabeled.\n",
"test_dataset = FileDataset(test, [-1] * len(test), test_transform)\n",
"active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -129,7 +121,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
@@ -149,7 +140,8 @@
"# ModelWrapper is an object similar to keras.Model.\n",
"baal_model = ModelWrapper(model, criterion)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -170,11 +162,11 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active.heuristics import BALD\n",
"heuristic = BALD(shuffle_prop=0.1)\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -193,13 +185,13 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"# This function would do the work that a human would do.\n",
"def get_label(img_path):\n",
" return classes.index(img_path.split('/')[-2])\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -223,15 +215,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num. labeled: 100/5174\n"
]
}
],
"source": [
"import numpy as np\n",
"# 1. Label all the test set and some samples from the training set.\n",
@@ -246,7 +229,8 @@
"active_learning_ds.label(train_idxs, labels)\n",
"\n",
"print(f\"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "code",
@@ -256,56 +240,19 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001B[32minfo ] Starting training dataset=100 epoch=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.9/site-packages/torch/utils/data/dataloader.py:478: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 1, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" warnings.warn(_create_warning_msg(\n",
"/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001B[32minfo ] Training complete train_loss=2.058176279067993\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001B[32minfo ] Evaluation complete test_loss=2.0671451091766357\n",
"Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}\n"
]
}
],
"source": [
"# 2. Train the model for a few epoch on the training set.\n",
"baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)\n",
"baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)\n",
"\n",
"print(\"Metrics:\", {k:v.avg for k,v in baal_model.metrics.items()})\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001B[32minfo ] Start Predict dataset=5074\n"
]
}
],
"source": [
"# 3. Select the K-top uncertain samples according to the heuristic.\n",
"pool = active_learning_ds.pool\n",
@@ -316,29 +263,22 @@
"predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)\n",
"# We will label the 10 most uncertain samples.\n",
"top_uncertainty = heuristic(predictions)[:10]\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]\n"
]
}
],
"source": [
"# 4. Label those samples.\n",
"oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n",
"labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]\n",
"print(list(zip(labels, oracle_indices)))\n",
"active_learning_ds.label(top_uncertainty, labels)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "code",
@@ -348,7 +288,6 @@
"is_executing": true
}
},
"outputs": [],
"source": [
"# 5. If not done, go back to 2.\n",
"for step in range(5): # 5 Active Learning step!\n",
@@ -372,7 +311,8 @@
" active_learning_ds.label(top_uncertainty, labels)\n",
" \n",
" "
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -386,14 +326,14 @@
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"torch.save({\n",
" 'active_dataset': active_learning_ds.state_dict(),\n",
" 'model': baal_model.state_dict(),\n",
" 'metrics': {k:v.avg for k,v in baal_model.metrics.items()}\n",
"}, '/tmp/baal_output.pth')\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",