Skip to content

Commit

Permalink
Properly implement PyProtocols for PyVocab.
Browse files Browse the repository at this point in the history
In order to get arbitrary keys, PyMappingProtocol::__getitem__
needs to be implemented. To get O(1) __contains__,
PySequenceProtocol::__contains__ needs to be implemented. To get
proper iteration support, PyIterProtocol::__iter__ needs to be
implemented.

PyO3/pyo3#611

This commit adds the correct implementation of the three traits
to PyVocab.
  • Loading branch information
sebpuetz committed Oct 24, 2019
1 parent e16c867 commit 1ba091c
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 20 deletions.
34 changes: 34 additions & 0 deletions src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,37 @@ impl PyEmbedding {
self.norm
}
}

/// Python-visible iterator over the words of a vocabulary.
///
/// Exposed to Python as `VocabIterator`; produced by `PyVocab.__iter__`.
#[pyclass(name=VocabIterator)]
pub struct PyVocabIterator {
    // Shared handle to the embeddings whose vocabulary is being iterated.
    embeddings: Rc<RefCell<EmbeddingsWrap>>,
    // Index of the next word to yield.
    idx: usize,
}

impl PyVocabIterator {
pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
PyVocabIterator { embeddings, idx }
}
}

#[pyproto]
impl PyIterProtocol for PyVocabIterator {
    /// An iterator is its own iterator.
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyVocabIterator>> {
        Ok(slf.into())
    }

    /// Yield the next word, or `None` when the vocabulary is exhausted.
    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<String>> {
        let state = &mut *slf;
        let embeddings = state.embeddings.borrow();
        let vocab = embeddings.vocab();

        // `get` returns None past the end, which is exactly Python's
        // StopIteration signal; advance the cursor only on a hit.
        let word = vocab.words().get(state.idx).map(ToString::to_string);
        if word.is_some() {
            state.idx += 1;
        }
        Ok(word)
    }
}
105 changes: 87 additions & 18 deletions src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use std::rc::Rc;
use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex};
use finalfusion::prelude::*;
use pyo3::class::sequence::PySequenceProtocol;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::{exceptions, PyIterProtocol, PyMappingProtocol};
use pyo3::types::PyAny;

use crate::iter::PyVocabIterator;
use crate::EmbeddingsWrap;

type NGramIndex = (String, Option<usize>);
Expand All @@ -25,16 +27,21 @@ impl PyVocab {

#[pymethods]
impl PyVocab {
fn item_to_indices(&self, key: String) -> Option<PyObject> {
#[allow(clippy::option_option)]
#[args(default = "None")]
fn get(&self, key: String, default: Option<Option<&PyAny>>) -> Option<PyObject> {
let embeds = self.embeddings.borrow();

embeds.vocab().idx(key.as_str()).map(|idx| {
let gil = pyo3::Python::acquire_gil();
match idx {
WordIndex::Word(idx) => [idx].to_object(gil.python()),
WordIndex::Subword(indices) => indices.to_object(gil.python()),
}
})
let gil = pyo3::Python::acquire_gil();
let idx = embeds.vocab().idx(key.as_str()).map(|idx| match idx {
WordIndex::Word(idx) => idx.to_object(gil.python()),
WordIndex::Subword(indices) => indices.to_object(gil.python()),
});
if default.is_some() && idx.is_none() {
return default
.map(|i| i.into_py(gil.python()))
.into_py(gil.python());
}
idx
}

fn ngram_indices(&self, word: &str) -> PyResult<Option<Vec<NGramIndex>>> {
Expand Down Expand Up @@ -64,24 +71,86 @@ impl PyVocab {
}
}

#[pyproto]
impl PySequenceProtocol for PyVocab {
fn __len__(&self) -> PyResult<usize> {
let embeds = self.embeddings.borrow();
Ok(embeds.vocab().words_len())
}

fn __getitem__(&self, idx: isize) -> PyResult<String> {
impl PyVocab {
    /// Resolve a (possibly negative, Python-style) integer index to the
    /// word at that position, or raise `IndexError` when out of range.
    fn isize_to_str(&self, idx: isize) -> PyResult<String> {
        let embeds = self.embeddings.borrow();
        let words = embeds.vocab().words();
        let len = words.len() as isize;

        // Negative indices count from the end, as in Python.
        let resolved = if idx < 0 { idx + len } else { idx };
        if resolved < 0 || resolved >= len {
            return Err(exceptions::IndexError::py_err("list index out of range"));
        }
        Ok(words[resolved as usize].clone())
    }

    /// Resolve a word to its index (or subword indices), raising
    /// `KeyError` when the word cannot be looked up at all.
    fn str_to_indices(&self, query: &str) -> PyResult<WordIndex> {
        let embeds = self.embeddings.borrow();
        match embeds.vocab().idx(query) {
            Some(idx) => Ok(idx),
            None => Err(exceptions::KeyError::py_err(format!(
                "key not found: {}",
                query
            ))),
        }
    }
}

#[pyproto]
impl PyMappingProtocol for PyVocab {
fn __getitem__(&self, query: &PyAny) -> PyResult<PyObject> {
let gil = Python::acquire_gil();
if let Ok(idx) = query.extract::<isize>() {
return self.isize_to_str(idx).map(|s| s.into_py(gil.python()));
}

if let Ok(indices) = query.extract::<Vec<isize>>() {
let mut words = Vec::with_capacity(indices.len());
for idx in indices {
words.push(self.isize_to_str(idx)?)
}
return Ok(words.into_py(gil.python()));
}

if let Ok(query) = query.extract::<String>() {
return self.str_to_indices(&query).map(|idx| match idx {
WordIndex::Subword(indices) => indices.into_py(gil.python()),
WordIndex::Word(idx) => idx.into_py(gil.python()),
});
}

if let Ok(queries) = query.extract::<Vec<String>>() {
let mut words = Vec::with_capacity(queries.len());
for query in queries {
let idx = self.str_to_indices(&query)?;
let obj: PyObject = match idx {
WordIndex::Subword(indices) => indices.into_py(gil.python()),
WordIndex::Word(idx) => idx.into_py(gil.python()),
};
words.push(obj)
}
return Ok(words.into_py(gil.python()));
}
Err(exceptions::KeyError::py_err(
"expected int, [int], str or [str]",
))
}
}

#[pyproto]
impl PyIterProtocol for PyVocab {
    /// Iterate over the words of the vocabulary, starting at index 0.
    fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyVocabIterator> {
        let embeddings = slf.embeddings.clone();
        Ok(PyVocabIterator::new(embeddings, 0))
    }
}

#[pyproto]
impl PySequenceProtocol for PyVocab {
/// `len(vocab)`: the number of in-vocabulary words.
fn __len__(&self) -> PyResult<usize> {
    Ok(self.embeddings.borrow().vocab().words_len())
}

fn __contains__(&self, word: String) -> PyResult<bool> {
let embeds = self.embeddings.borrow();
Ok(embeds
Expand Down
78 changes: 76 additions & 2 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

TEST_NGRAM_INDICES = [
('tüb',
14),
Expand Down Expand Up @@ -53,9 +55,19 @@
1007)]


def test_get(embeddings_text_dims):
    """vocab.get resolves an in-vocabulary word to its index."""
    vocab = embeddings_text_dims.vocab()
    # `is 0` relied on CPython small-int interning (SyntaxWarning on 3.8+);
    # compare by value instead.
    assert vocab.get("one") == 0


def test_get_oov(embeddings_fifu):
    """vocab.get returns None for an out-of-vocabulary key."""
    vocab = embeddings_fifu.vocab()
    result = vocab.get("Something out of vocabulary")
    assert result is None

def test_get_oov_with_default(embeddings_fifu):
    """A caller-supplied default is returned for an OOV key."""
    vocab = embeddings_fifu.vocab()
    result = vocab.get("Something out of vocabulary", default=-1)
    assert result == -1


def test_ngram_indices(subword_fifu):
Expand All @@ -72,3 +84,65 @@ def test_subword_indices(subword_fifu):
for subword_index, test_ngram_index in zip(
subword_indices, TEST_NGRAM_INDICES):
assert subword_index == test_ngram_index[1]


def test_int_idx(embeddings_text_dims):
    """Integer indexing returns the word at that position."""
    vocab = embeddings_text_dims.vocab()
    assert vocab[0] == "one"


def test_int_idx_out_of_range(embeddings_text_dims):
    """Indexing past the end raises IndexError."""
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(IndexError):
        _ = vocab[42]


def test_negative_int_idx(embeddings_text_dims):
    """Negative indices count from the end, as for lists."""
    vocab = embeddings_text_dims.vocab()
    assert vocab[-1] == "seven"


def test_negative_int_idx_out_of_range(embeddings_text_dims):
    """A too-negative index raises IndexError."""
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(IndexError):
        _ = vocab[-42]


def test_int_list_idx(embeddings_text_dims):
    """A list of ints maps to the corresponding words."""
    vocab = embeddings_text_dims.vocab()
    assert vocab[[0, 2, 3, 4]] == ["one", "three", "four", "five"]


def test_int_list_idx_out_of_range(embeddings_text_dims):
    """A list lookup fails with IndexError on any bad index."""
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(IndexError):
        _ = vocab[[-42]]


def test_string_idx(embeddings_text_dims):
    """String indexing returns the word's index."""
    vocab = embeddings_text_dims.vocab()
    assert vocab["one"] == 0


def test_string_oov(embeddings_text_dims):
    """Looking up an out-of-vocabulary word raises KeyError."""
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(KeyError):
        # The key must be OOV for this test; the old literal
        # "definitely in vocab" said the opposite of what it was.
        _ = vocab["definitely not in vocab"]


def test_string_list_idx(embeddings_text_dims):
    """A list of words maps to the corresponding indices."""
    vocab = embeddings_text_dims.vocab()
    assert vocab[["one", "three", "four", "five"]] == [0, 2, 3, 4]


def test_string_list_idx_oov(embeddings_text_dims):
    """A list lookup fails with KeyError if any word is OOV."""
    vocab = embeddings_text_dims.vocab()
    with pytest.raises(KeyError):
        _ = vocab[["definitely not oov"]]


def test_string_oov_subwords(subword_fifu):
    """An OOV word in a subword vocab yields its ngram indices."""
    vocab = subword_fifu.vocab()
    assert sorted(vocab["tübingen"]) == [x[1] for x in TEST_NGRAM_INDICES]

0 comments on commit 1ba091c

Please sign in to comment.