From 1ba091c73d659f2b660fb172b377ba9af9acab7e Mon Sep 17 00:00:00 2001 From: Sebastian Puetz Date: Thu, 24 Oct 2019 12:46:48 +0200 Subject: [PATCH] Properly implement PyProtocols for PyVocab. In order to get arbitrary keys, PyMappingProtocol::__getitem__ needs to be implemented. To get O(1) __contains__, PySequenceProtocol::__contains__ needs to be implemented. To get proper Iteration support, PyIterProtocol::__iter__ needs to be implemented. https://github.com/PyO3/pyo3/issues/611 This commit adds the correct implementation of the three traits to PyVocab. --- src/iter.rs | 34 ++++++++++++++ src/vocab.rs | 105 ++++++++++++++++++++++++++++++++++++-------- tests/test_vocab.py | 78 +++++++++++++++++++++++++++++++- 3 files changed, 197 insertions(+), 20 deletions(-) diff --git a/src/iter.rs b/src/iter.rs index 3aa2bea..cce164a 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -80,3 +80,37 @@ impl PyEmbedding { self.norm } } + +#[pyclass(name=VocabIterator)] +pub struct PyVocabIterator { + embeddings: Rc>, + idx: usize, +} + +impl PyVocabIterator { + pub fn new(embeddings: Rc>, idx: usize) -> Self { + PyVocabIterator { embeddings, idx } + } +} + +#[pyproto] +impl PyIterProtocol for PyVocabIterator { + fn __iter__(slf: PyRefMut) -> PyResult> { + Ok(slf.into()) + } + + fn __next__(mut slf: PyRefMut) -> PyResult> { + let slf = &mut *slf; + + let embeddings = slf.embeddings.borrow(); + let vocab = embeddings.vocab(); + + if slf.idx < vocab.words_len() { + let word = vocab.words()[slf.idx].to_string(); + slf.idx += 1; + Ok(Some(word)) + } else { + Ok(None) + } + } +} diff --git a/src/vocab.rs b/src/vocab.rs index 4f60daf..3888480 100644 --- a/src/vocab.rs +++ b/src/vocab.rs @@ -4,9 +4,11 @@ use std::rc::Rc; use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex}; use finalfusion::prelude::*; use pyo3::class::sequence::PySequenceProtocol; -use pyo3::exceptions; use pyo3::prelude::*; +use pyo3::{exceptions, PyIterProtocol, PyMappingProtocol}; +use pyo3::types::PyAny; +use crate::iter::PyVocabIterator; use crate::EmbeddingsWrap; type NGramIndex = (String, Option); @@ -25,16 +27,21 @@ impl PyVocab { #[pymethods] impl PyVocab { - fn item_to_indices(&self, key: String) -> Option { + #[allow(clippy::option_option)] + #[args(default = "None")] + fn get(&self, key: String, default: Option>) -> Option { let embeds = self.embeddings.borrow(); - - embeds.vocab().idx(key.as_str()).map(|idx| { - let gil = pyo3::Python::acquire_gil(); - match idx { - WordIndex::Word(idx) => [idx].to_object(gil.python()), - WordIndex::Subword(indices) => indices.to_object(gil.python()), - } - }) + let gil = pyo3::Python::acquire_gil(); + let idx = embeds.vocab().idx(key.as_str()).map(|idx| match idx { + WordIndex::Word(idx) => idx.to_object(gil.python()), + WordIndex::Subword(indices) => indices.to_object(gil.python()), + }); + if default.is_some() && idx.is_none() { + return default + .map(|i| i.into_py(gil.python())) + .into_py(gil.python()); + } + idx } fn ngram_indices(&self, word: &str) -> PyResult>> { @@ -64,17 +71,15 @@ impl PyVocab { } } -#[pyproto] -impl PySequenceProtocol for PyVocab { - fn __len__(&self) -> PyResult { - let embeds = self.embeddings.borrow(); - Ok(embeds.vocab().words_len()) - } - - fn __getitem__(&self, idx: isize) -> PyResult { +impl PyVocab { + fn isize_to_str(&self, mut idx: isize) -> PyResult { let embeds = self.embeddings.borrow(); let words = embeds.vocab().words(); + if idx < 0 { + idx += words.len() as isize; + } + if idx >= words.len() as isize || idx < 0 { Err(exceptions::IndexError::py_err("list index out of range")) } else { @@ -82,6 +87,70 @@ impl PySequenceProtocol for PyVocab { } } + fn str_to_indices(&self, query: &str) -> PyResult { + let embeds = self.embeddings.borrow(); + embeds + .vocab() + .idx(query) + .ok_or_else(|| exceptions::KeyError::py_err(format!("key not found: {}", query))) + } +} + +#[pyproto] +impl PyMappingProtocol for PyVocab { + fn __getitem__(&self, query: &PyAny) -> PyResult { + let gil = Python::acquire_gil(); + if let Ok(idx) = query.extract::() { + return self.isize_to_str(idx).map(|s| s.into_py(gil.python())); + } + + if let Ok(indices) = query.extract::>() { + let mut words = Vec::with_capacity(indices.len()); + for idx in indices { + words.push(self.isize_to_str(idx)?) + } + return Ok(words.into_py(gil.python())); + } + + if let Ok(query) = query.extract::() { + return self.str_to_indices(&query).map(|idx| match idx { + WordIndex::Subword(indices) => indices.into_py(gil.python()), + WordIndex::Word(idx) => idx.into_py(gil.python()), + }); + } + + if let Ok(queries) = query.extract::>() { + let mut words = Vec::with_capacity(queries.len()); + for query in queries { + let idx = self.str_to_indices(&query)?; + let obj: PyObject = match idx { + WordIndex::Subword(indices) => indices.into_py(gil.python()), + WordIndex::Word(idx) => idx.into_py(gil.python()), + }; + words.push(obj) + } + return Ok(words.into_py(gil.python())); + } + Err(exceptions::KeyError::py_err( + "expected int, [int], str or [str]", + )) + } +} + +#[pyproto] +impl PyIterProtocol for PyVocab { + fn __iter__(slf: PyRefMut) -> PyResult { + Ok(PyVocabIterator::new(slf.embeddings.clone(), 0)) + } +} + +#[pyproto] +impl PySequenceProtocol for PyVocab { + fn __len__(&self) -> PyResult { + let embeds = self.embeddings.borrow(); + Ok(embeds.vocab().words_len()) + } + fn __contains__(&self, word: String) -> PyResult { let embeds = self.embeddings.borrow(); Ok(embeds diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 2592851..6ae3954 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -1,3 +1,5 @@ +import pytest + TEST_NGRAM_INDICES = [ ('tüb', 14), @@ -53,9 +55,19 @@ 1007)] -def test_embeddings_with_norms_oov(embeddings_fifu): +def test_get(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + assert vocab.get("one") is 0 + + +def test_get_oov(embeddings_fifu): + vocab = embeddings_fifu.vocab() + assert vocab.get("Something out of vocabulary") is None + + +def test_get_oov_with_default(embeddings_fifu): vocab = embeddings_fifu.vocab() - assert vocab.item_to_indices("Something out of vocabulary") is None + assert vocab.get("Something out of vocabulary", default=-1) == -1 def test_ngram_indices(subword_fifu): @@ -72,3 +84,65 @@ def test_subword_indices(subword_fifu): for subword_index, test_ngram_index in zip( subword_indices, TEST_NGRAM_INDICES): assert subword_index == test_ngram_index[1] + + +def test_int_idx(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + assert vocab[0] == "one" + + +def test_int_idx_out_of_range(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + with pytest.raises(IndexError): + _ = vocab[42] + + +def test_negative_int_idx(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + assert vocab[-1] == "seven" + + +def test_negative_int_idx_out_of_range(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + with pytest.raises(IndexError): + _ = vocab[-42] + + +def test_int_list_idx(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + words = vocab[[0, 2, 3, 4]] + assert words == ["one", "three", "four", "five"] + + +def test_int_list_idx_out_of_range(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + with pytest.raises(IndexError): + _ = vocab[[-42]] + + +def test_string_idx(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + assert vocab["one"] == 0 + + +def test_string_oov(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + with pytest.raises(KeyError): + vocab["definitely in vocab"] + + +def test_string_list_idx(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + words = vocab[["one", "three", "four", "five"]] + assert words == [0, 2, 3, 4] + + +def test_string_list_idx_oov(embeddings_text_dims): + vocab = embeddings_text_dims.vocab() + with pytest.raises(KeyError): + _ = vocab[["definitely not oov"]] + + +def test_string_oov_subwords(subword_fifu): + vocab = subword_fifu.vocab() + assert sorted(vocab["tübingen"]) == [x[1] for x in TEST_NGRAM_INDICES] \ No newline at end of file