Skip to content

Commit

Permalink
Properly implement PyProtocols for PyVocab.
Browse files Browse the repository at this point in the history
In order to get arbitrary keys, PyMappingProtocol::__getitem__
needs to be implemented. To get O(1) __contains__,
PySequenceProtocol::__contains__ needs to be implemented. To get
proper iteration support, PyIterProtocol::__iter__ needs to be
implemented.

PyO3/pyo3#611

This commit adds the correct implementation of the three traits
to PyVocab.
  • Loading branch information
sebpuetz committed Oct 25, 2019
1 parent e16c867 commit 216c316
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 23 deletions.
34 changes: 34 additions & 0 deletions src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,37 @@ impl PyEmbedding {
self.norm
}
}

/// Python iterator over the words of a vocabulary.
///
/// Exposed to Python as `VocabIterator`. Holds a shared handle to the
/// wrapped embeddings and the index of the next word to yield.
#[pyclass(name=VocabIterator)]
pub struct PyVocabIterator {
    // Shared, interior-mutable handle to the embeddings whose vocab is iterated.
    embeddings: Rc<RefCell<EmbeddingsWrap>>,
    // Index of the next word to be returned by __next__.
    idx: usize,
}

impl PyVocabIterator {
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
PyVocabIterator { embeddings, idx }
}
}

#[pyproto]
impl PyIterProtocol for PyVocabIterator {
    /// An iterator is its own iterator: `iter(it)` returns `it`.
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyVocabIterator>> {
        Ok(slf.into())
    }

    /// Yield the next word of the vocabulary, or `None` when exhausted
    /// (PyO3 translates `Ok(None)` into `StopIteration`).
    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<String>> {
        let slf = &mut *slf;

        let embeddings = slf.embeddings.borrow();
        let vocab = embeddings.vocab();

        // Guard: past the last word, the iterator stays exhausted.
        if slf.idx >= vocab.words_len() {
            return Ok(None);
        }

        let word = vocab.words()[slf.idx].to_string();
        slf.idx += 1;
        Ok(Some(word))
    }
}
126 changes: 105 additions & 21 deletions src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use std::rc::Rc;
use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex};
use finalfusion::prelude::*;
use pyo3::class::sequence::PySequenceProtocol;
use pyo3::exceptions;
use pyo3::exceptions::{IndexError, KeyError, ValueError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PySlice};
use pyo3::{PyIterProtocol, PyMappingProtocol};

use crate::iter::PyVocabIterator;
use crate::EmbeddingsWrap;

type NGramIndex = (String, Option<usize>);
Expand All @@ -25,16 +28,21 @@ impl PyVocab {

#[pymethods]
impl PyVocab {
fn item_to_indices(&self, key: String) -> Option<PyObject> {
#[allow(clippy::option_option)]
#[args(default = "None")]
fn get(&self, key: &str, default: Option<Option<&PyAny>>) -> Option<PyObject> {
let embeds = self.embeddings.borrow();

embeds.vocab().idx(key.as_str()).map(|idx| {
let gil = pyo3::Python::acquire_gil();
match idx {
WordIndex::Word(idx) => [idx].to_object(gil.python()),
WordIndex::Subword(indices) => indices.to_object(gil.python()),
}
})
let gil = pyo3::Python::acquire_gil();
let idx = embeds.vocab().idx(key).map(|idx| match idx {
WordIndex::Word(idx) => idx.to_object(gil.python()),
WordIndex::Subword(indices) => indices.to_object(gil.python()),
});
if default.is_some() && idx.is_none() {
return default
.map(|i| i.into_py(gil.python()))
.into_py(gil.python());
}
idx
}

fn ngram_indices(&self, word: &str) -> PyResult<Option<Vec<NGramIndex>>> {
Expand All @@ -44,7 +52,7 @@ impl PyVocab {
VocabWrap::FinalfusionSubwordVocab(inner) => inner.ngram_indices(word),
VocabWrap::FinalfusionNGramVocab(inner) => inner.ngram_indices(word),
VocabWrap::SimpleVocab(_) => {
return Err(exceptions::ValueError::py_err(
return Err(ValueError::py_err(
"querying n-gram indices is not supported for this vocabulary",
))
}
Expand All @@ -57,29 +65,105 @@ impl PyVocab {
VocabWrap::FastTextSubwordVocab(inner) => Ok(inner.subword_indices(word)),
VocabWrap::FinalfusionSubwordVocab(inner) => Ok(inner.subword_indices(word)),
VocabWrap::FinalfusionNGramVocab(inner) => Ok(inner.subword_indices(word)),
VocabWrap::SimpleVocab(_) => Err(exceptions::ValueError::py_err(
VocabWrap::SimpleVocab(_) => Err(ValueError::py_err(
"querying subwords' indices is not supported for this vocabulary",
)),
}
}
}

#[pyproto]
impl PySequenceProtocol for PyVocab {
fn __len__(&self) -> PyResult<usize> {
impl PyVocab {
fn str_to_indices(&self, query: &str) -> PyResult<WordIndex> {
let embeds = self.embeddings.borrow();
Ok(embeds.vocab().words_len())
embeds
.vocab()
.idx(query)
.ok_or_else(|| KeyError::py_err(format!("key not found: {}", query)))
}

fn __getitem__(&self, idx: isize) -> PyResult<String> {
fn maybe_convert_negative(&self, idx: isize) -> isize {
let embeds = self.embeddings.borrow();
let words = embeds.vocab().words();
let vocab = embeds.vocab();
if idx < 0 {
idx + vocab.words_len() as isize
} else {
idx
}
}

if idx >= words.len() as isize || idx < 0 {
Err(exceptions::IndexError::py_err("list index out of range"))
fn validate_and_convert_isize_idx(&self, idx: isize) -> PyResult<usize> {
let embeds = self.embeddings.borrow();
let vocab = embeds.vocab();
let idx = self.maybe_convert_negative(idx);
if idx >= vocab.words_len() as isize || idx < 0 {
Err(IndexError::py_err("list index out of range"))
} else {
Ok(words[idx as usize].clone())
Ok(idx as usize)
}
}
}

#[pyproto]
impl PyMappingProtocol for PyVocab {
    /// Look a key up in the vocabulary.
    ///
    /// Supports three kinds of keys:
    /// * integer — the word at that position (negative counts from the
    ///   end); out-of-range indices raise `IndexError`;
    /// * slice — the list of words selected by the slice, following
    ///   CPython slice semantics (including negative steps);
    /// * string — the word's index, or the subword indices for subword
    ///   vocabularies; unknown words raise `KeyError`.
    fn __getitem__(&self, query: PyObject) -> PyResult<PyObject> {
        let embeds = self.embeddings.borrow();
        let vocab = embeds.vocab();
        let gil = Python::acquire_gil();

        if let Ok(idx) = query.extract::<isize>(gil.python()) {
            let idx = self.validate_and_convert_isize_idx(idx)?;
            return Ok(vocab.words()[idx].clone().into_py(gil.python()));
        }

        if let Ok(slice) = query.extract::<&PySlice>(gil.python()) {
            // PySlice::indices already clamps and normalizes start/stop/step
            // per CPython's slice.indices() and reports the resulting length.
            // Walking `slicelength` steps from `start` is therefore correct
            // for every step sign. Re-normalizing negative values ourselves
            // would be wrong: for negative steps running down to the first
            // element, the normalized stop is -1, which must NOT be remapped
            // to len - 1 (doing so made e.g. `vocab[::-1]` come out empty).
            let indices = slice.indices(vocab.words_len() as i64)?;

            let mut words = Vec::with_capacity(indices.slicelength as usize);
            let mut idx = indices.start;
            for _ in 0..indices.slicelength {
                words.push(vocab.words()[idx as usize].clone());
                idx += indices.step;
            }

            return Ok(words.into_py(gil.python()));
        }

        if let Ok(query) = query.extract::<String>(gil.python()) {
            return self.str_to_indices(&query).map(|idx| match idx {
                WordIndex::Subword(indices) => indices.into_py(gil.python()),
                WordIndex::Word(idx) => idx.into_py(gil.python()),
            });
        }

        Err(KeyError::py_err("key must be integers, slices or string"))
    }
}

#[pyproto]
impl PyIterProtocol for PyVocab {
    /// Return a fresh iterator over the vocabulary's words, starting
    /// at the first word. The iterator shares the embeddings handle.
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyVocabIterator> {
        let embeddings = slf.embeddings.clone();
        Ok(PyVocabIterator::new(embeddings, 0))
    }
}

#[pyproto]
impl PySequenceProtocol for PyVocab {
fn __len__(&self) -> PyResult<usize> {
let embeds = self.embeddings.borrow();
Ok(embeds.vocab().words_len())
}

fn __contains__(&self, word: String) -> PyResult<bool> {
Expand Down
94 changes: 92 additions & 2 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

TEST_NGRAM_INDICES = [
('tüb',
14),
Expand Down Expand Up @@ -53,9 +55,19 @@
1007)]


def test_embeddings_with_norms_oov(embeddings_fifu):
def test_get(embeddings_text_dims):
    # "one" is the first word of the vocabulary, so get() returns index 0.
    vocab = embeddings_text_dims.vocab()
    # Compare with ==, not `is`: identity checks against int literals rely
    # on CPython's small-int caching (and warn on modern CPython).
    assert vocab.get("one") == 0


def test_get_oov(embeddings_fifu):
    # Looking up an out-of-vocabulary word without a default yields None.
    vocab = embeddings_fifu.vocab()
    result = vocab.get("Something out of vocabulary")
    assert result is None


def test_get_oov_with_default(embeddings_fifu):
    # An explicit default is returned for out-of-vocabulary words.
    vocab = embeddings_fifu.vocab()
    fallback = vocab.get("Something out of vocabulary", default=-1)
    assert fallback == -1


def test_ngram_indices(subword_fifu):
Expand All @@ -72,3 +84,81 @@ def test_subword_indices(subword_fifu):
for subword_index, test_ngram_index in zip(
subword_indices, TEST_NGRAM_INDICES):
assert subword_index == test_ngram_index[1]


def test_int_idx(embeddings_text_dims):
    # Integer indexing returns the word at that position.
    vocab = embeddings_text_dims.vocab()
    first_word = vocab[0]
    assert first_word == "one"


def test_int_idx_out_of_range(embeddings_text_dims):
    # Indexing past the end of the vocabulary raises IndexError.
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(IndexError):
        vocab[42]


def test_negative_int_idx(embeddings_text_dims):
    # Negative indices count back from the end of the vocabulary.
    vocab = embeddings_text_dims.vocab()
    last_word = vocab[-1]
    assert last_word == "seven"


def test_negative_int_idx_out_of_range(embeddings_text_dims):
    # A negative index beyond the start of the vocabulary raises IndexError.
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(IndexError):
        vocab[-42]


def test_slice_idx(embeddings_text_dims):
    # Slicing the vocab matches slicing the materialized word list.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[:4]
    assert expected == vocab[:4]


def test_slice_with_step_idx(embeddings_text_dims):
    # Slices with a step behave like list slicing with the same step.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[:4:2]
    assert expected == vocab[:4:2]


def test_slice_negative_idx(embeddings_text_dims):
    # Negative stop indices behave like list slicing.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[:-1]
    assert expected == vocab[:-1]


def test_slice_negative_idx_with_positive_step(embeddings_text_dims):
    # Negative stop with a positive step behaves like list slicing.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[:-1:2]
    assert expected == vocab[:-1:2]


def test_slice_negative_idx_with_negative_step(embeddings_text_dims):
    # Negative stop with a negative step behaves like list slicing.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[:-1:-2]
    assert expected == vocab[:-1:-2]


def test_slice_negative_to_negative_idx(embeddings_text_dims):
    # A fully negative start/stop slice behaves like list slicing.
    vocab = embeddings_text_dims.vocab()
    expected = list(vocab)[-3:-1]
    assert expected == vocab[-3:-1]


def test_slice_out_of_range(embeddings_text_dims):
    # Slicing past the end must be clamped to the vocabulary length.
    vocab = embeddings_text_dims.vocab()
    # Slice the materialized word list, not the slice result itself:
    # the previous `list(vocab[:42]) == vocab[:42]` compared the result
    # with itself and was trivially true.
    assert list(vocab)[:42] == vocab[:42]


def test_slice_negative_out_of_range(embeddings_text_dims):
    # A negative stop beyond the start must yield an empty selection,
    # matching list slicing semantics.
    vocab = embeddings_text_dims.vocab()
    # Slice the materialized word list, not the slice result itself:
    # the previous `list(vocab[:-42]) == vocab[:-42]` compared the result
    # with itself and was trivially true.
    assert list(vocab)[:-42] == vocab[:-42]


def test_string_idx(embeddings_text_dims):
    # String lookup returns the word's index in the vocabulary.
    vocab = embeddings_text_dims.vocab()
    index = vocab["one"]
    assert index == 0


def test_string_oov(embeddings_text_dims):
    # __getitem__ with an unknown word raises KeyError.
    vocab = embeddings_text_dims.vocab()
    # Use a key that reads as absent; the previous literal said
    # "definitely in vocab", contradicting the test's intent.
    with pytest.raises(KeyError):
        vocab["definitely not in vocab"]


def test_string_oov_subwords(subword_fifu):
    # For subword vocabs, an unknown word resolves to its n-gram indices.
    vocab = subword_fifu.vocab()
    indices = sorted(vocab["tübingen"])
    expected = [pair[1] for pair in TEST_NGRAM_INDICES]
    assert indices == expected

0 comments on commit 216c316

Please sign in to comment.