From bb2a2e00f38e9711269fd3a657d0c903a07ffd47 Mon Sep 17 00:00:00 2001
From: Sebastian Puetz
Date: Thu, 24 Oct 2019 12:46:48 +0200
Subject: [PATCH] Properly implement PyProtocols for PyVocab.

In order to get arbitrary keys, PyMappingProtocol::__getitem__ needs to
be implemented. To get O(1) __contains__, PySequenceProtocol::__contains__
needs to be implemented. To get proper iteration support,
PyIterProtocol::__iter__ needs to be implemented.

https://github.com/PyO3/pyo3/issues/611

This commit adds the correct implementation of the three traits to
PyVocab.
---
 src/iter.rs         |  34 ++++++++++++
 src/vocab.rs        | 123 ++++++++++++++++++++++++++++++++++++--------
 tests/test_vocab.py |  94 ++++++++++++++++++++++++++++++++-
 3 files changed, 228 insertions(+), 23 deletions(-)

diff --git a/src/iter.rs b/src/iter.rs
index 3aa2bea..cce164a 100644
--- a/src/iter.rs
+++ b/src/iter.rs
@@ -80,3 +80,37 @@ impl PyEmbedding {
         self.norm
     }
 }
+
+#[pyclass(name=VocabIterator)]
+pub struct PyVocabIterator {
+    embeddings: Rc<RefCell<EmbeddingsWrap>>,
+    idx: usize,
+}
+
+impl PyVocabIterator {
+    pub fn new(embeddings: Rc<RefCell<EmbeddingsWrap>>, idx: usize) -> Self {
+        PyVocabIterator { embeddings, idx }
+    }
+}
+
+#[pyproto]
+impl PyIterProtocol for PyVocabIterator {
+    fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<PyVocabIterator>> {
+        Ok(slf.into())
+    }
+
+    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<String>> {
+        let slf = &mut *slf;
+
+        let embeddings = slf.embeddings.borrow();
+        let vocab = embeddings.vocab();
+
+        if slf.idx < vocab.words_len() {
+            let word = vocab.words()[slf.idx].to_string();
+            slf.idx += 1;
+            Ok(Some(word))
+        } else {
+            Ok(None)
+        }
+    }
+}
diff --git a/src/vocab.rs b/src/vocab.rs
index 4f60daf..90d1a6b 100644
--- a/src/vocab.rs
+++ b/src/vocab.rs
@@ -4,9 +4,12 @@ use std::rc::Rc;
 use finalfusion::chunks::vocab::{NGramIndices, SubwordIndices, VocabWrap, WordIndex};
 use finalfusion::prelude::*;
 use pyo3::class::sequence::PySequenceProtocol;
-use pyo3::exceptions;
+use pyo3::exceptions::{IndexError, KeyError, ValueError};
 use pyo3::prelude::*;
+use 
pyo3::types::PySlice;
+use pyo3::{PyIterProtocol, PyMappingProtocol};
 
+use crate::iter::PyVocabIterator;
 use crate::EmbeddingsWrap;
 
 type NGramIndex = (String, Option<usize>);
@@ -25,16 +28,18 @@ impl PyVocab {
 
 #[pymethods]
 impl PyVocab {
-    fn item_to_indices(&self, key: String) -> Option<PyObject> {
+    #[args(default = "Python::acquire_gil().python().None()")]
+    fn get(&self, key: &str, default: PyObject) -> Option<PyObject> {
         let embeds = self.embeddings.borrow();
-
-        embeds.vocab().idx(key.as_str()).map(|idx| {
-            let gil = pyo3::Python::acquire_gil();
-            match idx {
-                WordIndex::Word(idx) => [idx].to_object(gil.python()),
-                WordIndex::Subword(indices) => indices.to_object(gil.python()),
-            }
-        })
+        let gil = pyo3::Python::acquire_gil();
+        let idx = embeds.vocab().idx(key).map(|idx| match idx {
+            WordIndex::Word(idx) => idx.to_object(gil.python()),
+            WordIndex::Subword(indices) => indices.to_object(gil.python()),
+        });
+        if !default.is_none() && idx.is_none() {
+            return Some(default);
+        }
+        idx
     }
 
     fn ngram_indices(&self, word: &str) -> PyResult<Option<Vec<NGramIndex>>> {
@@ -44,7 +49,7 @@
             VocabWrap::FinalfusionSubwordVocab(inner) => inner.ngram_indices(word),
             VocabWrap::FinalfusionNGramVocab(inner) => inner.ngram_indices(word),
             VocabWrap::SimpleVocab(_) => {
-                return Err(exceptions::ValueError::py_err(
+                return Err(ValueError::py_err(
                     "querying n-gram indices is not supported for this vocabulary",
                 ))
             }
@@ -57,29 +62,105 @@
             VocabWrap::FastTextSubwordVocab(inner) => Ok(inner.subword_indices(word)),
             VocabWrap::FinalfusionSubwordVocab(inner) => Ok(inner.subword_indices(word)),
             VocabWrap::FinalfusionNGramVocab(inner) => Ok(inner.subword_indices(word)),
-            VocabWrap::SimpleVocab(_) => Err(exceptions::ValueError::py_err(
+            VocabWrap::SimpleVocab(_) => Err(ValueError::py_err(
                 "querying subwords' indices is not supported for this vocabulary",
             )),
         }
     }
 }
 
-#[pyproto]
-impl PySequenceProtocol for PyVocab {
-    fn __len__(&self) -> PyResult<usize> {
+impl PyVocab {
+    fn str_to_indices(&self, query: &str) -> 
PyResult<WordIndex> {
         let embeds = self.embeddings.borrow();
-        Ok(embeds.vocab().words_len())
+        embeds
+            .vocab()
+            .idx(query)
+            .ok_or_else(|| KeyError::py_err(format!("key not found: {}", query)))
     }
 
-    fn __getitem__(&self, idx: isize) -> PyResult<String> {
+    fn maybe_convert_negative(&self, idx: isize) -> isize {
         let embeds = self.embeddings.borrow();
-        let words = embeds.vocab().words();
+        let vocab = embeds.vocab();
+        if idx < 0 {
+            idx + vocab.words_len() as isize
+        } else {
+            idx
+        }
+    }
 
-        if idx >= words.len() as isize || idx < 0 {
-            Err(exceptions::IndexError::py_err("list index out of range"))
+    fn validate_and_convert_isize_idx(&self, idx: isize) -> PyResult<usize> {
+        let embeds = self.embeddings.borrow();
+        let vocab = embeds.vocab();
+        let idx = self.maybe_convert_negative(idx);
+        if idx >= vocab.words_len() as isize || idx < 0 {
+            Err(IndexError::py_err("list index out of range"))
         } else {
-            Ok(words[idx as usize].clone())
+            Ok(idx as usize)
+        }
+    }
+}
+
+#[pyproto]
+impl PyMappingProtocol for PyVocab {
+    fn __getitem__(&self, query: PyObject) -> PyResult<PyObject> {
+        let embeds = self.embeddings.borrow();
+        let vocab = embeds.vocab();
+        let gil = Python::acquire_gil();
+        if let Ok(idx) = query.extract::<isize>(gil.python()) {
+            let idx = self.validate_and_convert_isize_idx(idx)?;
+            return Ok(vocab.words()[idx].clone().into_py(gil.python()));
         }
+
+        if let Ok(indices) = query.extract::<&PySlice>(gil.python()) {
+            let embeds = self.embeddings.borrow();
+            let indices = indices.indices(embeds.vocab().words_len() as i64)?;
+
+            let start =
+                (self.maybe_convert_negative(indices.start) as usize).min(vocab.words_len());
+            let stop = (self.maybe_convert_negative(indices.stop) as usize).min(vocab.words_len());
+
+            let words = if start > stop {
+                if indices.step >= 0 {
+                    return Ok(Vec::<String>::new().into_py(gil.python()));
+                }
+                (stop + 1..=start)
+                    .rev()
+                    .step_by(indices.step.abs() as usize)
+                    .map(|idx| vocab.words()[idx].clone())
+                    .collect::<Vec<_>>()
+            } else {
+                (start..stop)
+                    .step_by(indices.step as 
usize)
+                    .map(|idx| vocab.words()[idx].clone())
+                    .collect::<Vec<_>>()
+            };
+
+            return Ok(words.into_py(gil.python()));
+        }
+
+        if let Ok(query) = query.extract::<String>(gil.python()) {
+            return self.str_to_indices(&query).map(|idx| match idx {
+                WordIndex::Subword(indices) => indices.into_py(gil.python()),
+                WordIndex::Word(idx) => idx.into_py(gil.python()),
+            });
+        }
+
+        Err(KeyError::py_err("key must be integers, slices or string"))
+    }
+}
+
+#[pyproto]
+impl PyIterProtocol for PyVocab {
+    fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyVocabIterator> {
+        Ok(PyVocabIterator::new(slf.embeddings.clone(), 0))
+    }
+}
+
+#[pyproto]
+impl PySequenceProtocol for PyVocab {
+    fn __len__(&self) -> PyResult<usize> {
+        let embeds = self.embeddings.borrow();
+        Ok(embeds.vocab().words_len())
     }
 
     fn __contains__(&self, word: String) -> PyResult<bool> {
diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 2592851..c6f53ff 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -1,3 +1,5 @@
+import pytest
+
 
 TEST_NGRAM_INDICES = [
     ('tüb', 14),
@@ -53,9 +55,19 @@
              1007)]
 
 
-def test_embeddings_with_norms_oov(embeddings_fifu):
+def test_get(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert vocab.get("one") == 0
+
+
+def test_get_oov(embeddings_fifu):
+    vocab = embeddings_fifu.vocab()
+    assert vocab.get("Something out of vocabulary") is None
+
+
+def test_get_oov_with_default(embeddings_fifu):
     vocab = embeddings_fifu.vocab()
-    assert vocab.item_to_indices("Something out of vocabulary") is None
+    assert vocab.get("Something out of vocabulary", default=-1) == -1
 
 
 def test_ngram_indices(subword_fifu):
@@ -72,3 +84,81 @@ def test_subword_indices(subword_fifu):
     for subword_index, test_ngram_index in zip(
             subword_indices, TEST_NGRAM_INDICES):
         assert subword_index == test_ngram_index[1]
+
+
+def test_int_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert vocab[0] == "one"
+
+
+def test_int_idx_out_of_range(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    with 
pytest.raises(IndexError):
+        _ = vocab[42]
+
+
+def test_negative_int_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert vocab[-1] == "seven"
+
+
+def test_negative_int_idx_out_of_range(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    with pytest.raises(IndexError):
+        _ = vocab[-42]
+
+
+def test_slice_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:4] == vocab[:4]
+
+
+def test_slice_with_step_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:4:2] == vocab[:4:2]
+
+
+def test_slice_negative_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:-1] == vocab[:-1]
+
+
+def test_slice_negative_idx_with_positive_step(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:-1:2] == vocab[:-1:2]
+
+
+def test_slice_negative_idx_with_negative_step(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:-1:-2] == vocab[:-1:-2]
+
+
+def test_slice_negative_to_negative_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[-3:-1] == vocab[-3:-1]
+
+
+def test_slice_out_of_range(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:42] == vocab[:42]
+
+
+def test_slice_negative_out_of_range(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert list(vocab)[:-42] == vocab[:-42]
+
+
+def test_string_idx(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    assert vocab["one"] == 0
+
+
+def test_string_oov(embeddings_text_dims):
+    vocab = embeddings_text_dims.vocab()
+    with pytest.raises(KeyError):
+        vocab["definitely not in vocab"]
+
+
+def test_string_oov_subwords(subword_fifu):
+    vocab = subword_fifu.vocab()
+    assert sorted(vocab["tübingen"]) == [x[1] for x in TEST_NGRAM_INDICES]
\ No newline at end of file