diff --git a/tests/common.py b/tests/common.py
new file mode 100644
index 0000000..ceeecbd
--- /dev/null
+++ b/tests/common.py
@@ -0,0 +1,168 @@
+"""
+Common strategies and utilities used across multiple test modules.
+
+Any real-world details or samples used as constants were correct when
+taken on 2024-03-06.
+"""
+
+import datetime as dt
+import string
+
+from dateutil import relativedelta as rd
+from hypothesis import strategies as st
+from langchain_community.chat_models import ChatOllama
+
+from parliai_public.readers.base import BaseReader
+
+
+class ToyReader(BaseReader):
+    """A toy class to allow testing our abstract base class."""
+
+    def retrieve_latest_entries(self):
+        """Allow testing with toy method."""
+
+    @staticmethod
+    def _read_metadata(url, soup):
+        """Allow testing with toy static method."""
+
+    @staticmethod
+    def _read_contents(soup):
+        """Allow testing with toy static method."""
+
+    def render(self, response, page):
+        """Allow testing with toy method."""
+
+    def _summary_template(self):
+        """Allow testing with toy method."""
+
+
+def where_what(reader):
+    """
+    Get the right location and class for testing a reader.
+
+    `ToyReader` stands in for `BaseReader`, so the base class is returned
+    in its place. `where` is the dotted `module.ClassName` path of `what`,
+    which the tests use to build `mock.patch` targets.
+    """
+
+    what = reader
+    if reader is ToyReader:
+        what = BaseReader
+
+    where = ".".join((what.__module__, what.__name__))
+
+    return where, what
+
+
+def default_llm() -> ChatOllama:
+    """Instantiate default LLM object for use in testing."""
+
+    llm = ChatOllama(
+        model_name="gemma",
+        temperature=0,
+        # max_output_tokens=2048,
+    )
+
+    return llm
+
+
+MPS_SAMPLE = [
+    (
+        "Bob Seely",
+        "Conservative, Isle of Wight",
+        "https://www.theyworkforyou.com/mp/25645/bob_seely/isle_of_wight",
+    ),
+    (
+        "Mark Logan",
+        "Conservative, Bolton North East",
+        "https://www.theyworkforyou.com/mp/25886/mark_logan/bolton_north_east",
+    ),
+    (
+        "Nigel Huddleston",
+        "Conservative, Mid Worcestershire",
+        "https://www.theyworkforyou.com/mp/25381/nigel_huddleston/mid_worcestershire",
+    ),
+    (
+        "Heather Wheeler",
+        "Conservative, South Derbyshire",
+        "https://www.theyworkforyou.com/mp/24769/heather_wheeler/south_derbyshire",
+    ),
+    (
+        "Ian Paisley Jnr",
+        "DUP, North Antrim",
+        "https://www.theyworkforyou.com/mp/13852/ian_paisley_jnr/north_antrim",
+    ),
+    (
+        "Matthew Offord",
+        "Conservative, Hendon",
+        "https://www.theyworkforyou.com/mp/24955/matthew_offord/hendon",
+    ),
+    (
+        "John Howell",
+        "Conservative, Henley",
+        "https://www.theyworkforyou.com/mp/14131/john_howell/henley",
+    ),
+    (
+        "Robert Goodwill",
+        "Conservative, Scarborough and Whitby",
+        "https://www.theyworkforyou.com/mp/11804/robert_goodwill/scarborough_and_whitby",
+    ),
+    (
+        "Naseem Shah",
+        "Labour, Bradford West",
+        "https://www.theyworkforyou.com/mp/25385/naseem_shah/bradford_west",
+    ),
+    (
+        "Amy Callaghan",
+        "Scottish National Party, East Dunbartonshire",
+        "https://www.theyworkforyou.com/mp/25863/amy_callaghan/east_dunbartonshire",
+    ),
+]
+
+GOV_DEPARTMENTS = [
+    "Attorney General's Office",
+    "Cabinet Office",
+    "Department for Business and Trade",
+    "Department for Culture, Media and Sport",
+    "Department for Education",
+    "Department for Energy Security and Net Zero",
+    "Department for Environment, Food and Rural Affairs",
+    "Department for Levelling Up, Housing and Communities",
+    "Department for Science, Innovation and Technology",
+    "Department for Transport",
+    "Department for Work and Pensions",
+    "Department of Health and Social Care",
+    "Export Credits Guarantee Department",
+    "Foreign, Commonwealth and Development Office",
+    "HM Treasury",
+    "Home Office",
+    "Ministry of Defence",
+    "Ministry of Justice",
+    "Northern Ireland Office",
+    "Office of the Advocate General for Scotland",
+    "Office of the Leader of the House of Commons",
+    "Office of the Leader of the House of Lords",
+    "Office of the Secretary of State for Scotland",
+    "Office of the Secretary of State for Wales",
+]
+
+SEARCH_TERMS = (
+    "ONS",
+    "Office for National Statistics",
+    "National Statistician",
+)
+
+TODAY = dt.date.today()
+ST_DATES = st.dates(TODAY - rd.relativedelta(years=4), TODAY)
+
+ST_FREE_TEXT = st.text(
+    string.ascii_letters + string.digits + ".:;!?-", min_size=1
+)
+
+MODEL_NAMES = ["llama3", "mistral", "openhermes"]
+
+GEMMA_PREAMBLES = [
+    "Sure! Here is the text you are looking for: \nMy right honourable friend...",
+    "Sure - here is the quote: My right honourable friend...",
+    "Sure!The following contains references to your search terms:My right honourable friend...",
+]
diff --git a/tests/readers/base/__init__.py b/tests/readers/base/__init__.py
new file mode 100644
index 0000000..99eeea0
--- /dev/null
+++ b/tests/readers/base/__init__.py
@@ -0,0 +1 @@
+"""Tests for the BaseReader class."""
diff --git a/tests/readers/test_creation.py b/tests/readers/test_creation.py
new file mode 100644
index 0000000..667a24a
--- /dev/null
+++ b/tests/readers/test_creation.py
@@ -0,0 +1,269 @@
+"""Unit tests for instantiation methods of our readers."""
+
+import datetime as dt
+from unittest import mock
+
+import pytest
+from hypothesis import given, provisional
+from hypothesis import strategies as st
+
+from parliai_public.readers import Debates, WrittenAnswers
+
+from ..common import ST_DATES, ST_FREE_TEXT, TODAY, ToyReader, where_what
+
+ST_OPTIONAL_STRINGS = st.one_of((st.just(None), ST_FREE_TEXT))
+YESTERDAY = TODAY - dt.timedelta(days=1)
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    st.lists(provisional.urls(), max_size=5),
+    st.one_of((st.just(None), st.lists(ST_FREE_TEXT, max_size=5))),
+    st.one_of(st.just(None), st.lists(ST_DATES, min_size=1, max_size=5)),
+    ST_FREE_TEXT,
+    ST_OPTIONAL_STRINGS,
+)
+def test_init(reader_class, urls, terms, dates, outdir, prompt):
+    """Test instantiation occurs correctly."""
+
+    where, what = where_what(reader_class)
+    if reader_class is WrittenAnswers:
+        urls = reader_class._supported_urls
+
+    config = {
+        "prompt": "",
+        "llm_name": "gemma",
+    }
+    with mock.patch(f"{where}._load_config") as load:
+        load.return_value = config
+        reader = reader_class(urls, terms, dates, outdir, prompt)
+
+    default_terms = ["Office for National Statistics", "ONS"]
+    assert isinstance(reader, what)
+    assert reader.urls == urls
+    # The conditional expressions must be bracketed; otherwise `assert`
+    # parses as `(a == b) if cond else c` and the else-branch passes for
+    # any truthy value without comparing against the reader at all.
+    assert reader.terms == (default_terms if not terms else terms)
+    assert reader.dates == ([YESTERDAY] if dates is None else dates)
+    assert reader.outdir == outdir
+    assert reader.prompt == ("" if prompt is None else prompt)
+    assert reader.llm_name == "gemma"
+
+    load.assert_called_once_with()
+
+
+@pytest.mark.skip("Skipping - requires diagnostics re keywords")
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_OPTIONAL_STRINGS,
+    st.lists(provisional.urls(), max_size=5),
+    st.lists(ST_FREE_TEXT, max_size=5),
+    ST_FREE_TEXT,
+)
+def test_from_toml_no_dates(reader_class, path, urls, terms, text):
+    """
+    Test that an instance can be made from a configuration file.
+
+    In this test, we do not configure any of the date parameters, so
+    every reader instance should have the same `dates` attribute:
+    yesterday.
+    """
+
+    where, what = where_what(reader_class)
+    if reader_class is WrittenAnswers:
+        urls = reader_class._supported_urls
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": urls,
+            "terms": terms,
+            "outdir": text,
+            "prompt": text,
+            "llm_name": "gemma",
+        }
+        reader = reader_class.from_toml(path)
+
+    assert isinstance(reader, what)
+    assert reader.urls == urls
+    assert reader.terms == terms
+    assert reader.dates == [YESTERDAY]
+    assert reader.outdir == text
+    assert reader.prompt == text
+    assert reader.llm_name == "gemma"
+
+    assert loader.return_value["dates"] is None
+    assert loader.call_count == 2
+    assert loader.call_args_list == [mock.call(path), mock.call()]
+
+    lister.assert_not_called()
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_DATES.map(dt.date.isoformat),
+)
+def test_from_toml_with_start(reader_class, start):
+    """
+    Check the config constructor works with a start date.
+
+    The actual date list construction is mocked here.
+    """
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "start": start,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "start" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(start, None, None, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_DATES.map(dt.date.isoformat),
+)
+def test_from_toml_with_end(reader_class, end):
+    """Check the config constructor works with an end date."""
+
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "end": end,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "end" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(None, end, None, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    st.tuples(ST_DATES, ST_DATES).map(
+        lambda dates: sorted(map(dt.date.isoformat, dates))
+    ),
+)
+def test_from_toml_with_endpoints(reader_class, endpoints):
+    """Check the config constructor works with two endpoints."""
+
+    where, what = where_what(reader_class)
+    start, end = endpoints
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "start": start,
+            "end": end,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "start" not in loader.return_value
+    assert "end" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(start, end, None, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    st.integers(1, 14),
+)
+def test_from_toml_with_window(reader_class, window):
+    """Check the config constructor works with a window."""
+
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "window": window,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    # Only "window" was configured, so that is the key to check here.
+    assert "window" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(None, None, window, "%Y-%m-%d")
+
+
+@given(
+    st.sampled_from((ToyReader, Debates, WrittenAnswers)),
+    ST_DATES.map(dt.date.isoformat),
+    st.integers(1, 14),
+)
+def test_from_toml_with_end_and_window(reader_class, end, window):
+    """Check the config constructor works with an end and a window."""
+
+    where, what = where_what(reader_class)
+
+    with (
+        mock.patch(f"{where}._load_config") as loader,
+        mock.patch("parliai_public.dates.list_dates") as lister,
+    ):
+        loader.return_value = {
+            "urls": [],
+            "end": end,
+            "window": window,
+            "prompt": "",
+            "llm_name": "gemma",
+        }
+        lister.return_value = ["dates"]
+        reader = reader_class.from_toml()
+
+    assert isinstance(reader, what)
+    assert reader.dates == ["dates"]
+
+    assert "end" not in loader.return_value
+    assert "window" not in loader.return_value
+    assert loader.return_value.get("dates") == ["dates"]
+
+    lister.assert_called_once_with(None, end, window, "%Y-%m-%d")
diff --git a/tests/readers/theyworkforyou/__init__.py b/tests/readers/theyworkforyou/__init__.py
new file mode 100644
index 0000000..7b94998
--- /dev/null
+++ b/tests/readers/theyworkforyou/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for the debates reader."""