Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

113 increase efficiency of wiki context creation #121

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions prompting/tools/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ..selector import Selector
from .context import Context
from prompting.utils.exceptions import MaxRetryError


class Dataset(ABC):
Expand All @@ -43,6 +44,19 @@ def get(self, name):
...

def next(self, method: str = 'random', selector: Selector = Selector(), **kwargs) -> Dict:
"""Base method for getting the next sample from the dataset.

Args:
method (str, optional): Method to use for getting the next sample; must be one of 'random', 'search', or 'get'. Defaults to 'random'.
selector (Selector, optional): Selector to use for getting the next sample. Defaults to Selector().

Raises:
ValueError: If an unknown dataset get method is used.
MaxRetryError: If the maximum number of tries is reached.

Returns:
Dict: The next sample from the dataset.
"""
tries = 1
t0 = time.time()

Expand All @@ -62,11 +76,11 @@ def next(self, method: str = 'random', selector: Selector = Selector(), **kwargs
if info:
break

bt.logging.warning(f"Could not find an sample which meets {self.__class__.__name__} requirements after {tries} tries. Retrying... ({self.max_tries - tries} tries remaining.)")
bt.logging.debug(f"Could not find an sample which meets {self.__class__.__name__} requirements after {tries} tries. Retrying... ({self.max_tries - tries} tries remaining.)")

tries += 1
if tries == self.max_tries:
raise Exception(
raise MaxRetryError(
f"Could not find an sample which meets {self.__class__.__name__} requirements after {tries} tries."
)

Expand Down
59 changes: 43 additions & 16 deletions prompting/tools/datasets/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

import re
import sys
import asyncio
import random
import datetime
import bittensor as bt
import wikipedia as wiki
from functools import partial
from typing import Dict, Union, List, Tuple

from functools import lru_cache
Expand Down Expand Up @@ -148,16 +150,13 @@ def get(self, name: str, selector: Selector = None, include: List = None, exclud
"""Get a specified Wikipedia page and extract a section based on the selector.

Args:
name (_type_): _description_
pageid (_type_, optional): _description_. Defaults to None.
auto_suggest (bool, optional): _description_. Defaults to True.
redirect (bool, optional): _description_. Defaults to True.
selector (Selector, optional): _description_. Defaults to None.
include (List, optional): _description_. Defaults to None.
exclude (List, optional): _description_. Defaults to None.
name (str): Title of the Wikipedia page.
selector (Selector, optional): Selector to choose a section. Defaults to None.
include (List, optional): List of section headers to include. Defaults to None.
exclude (List, optional): List of section headers to exclude. Defaults to None.

Returns:
Dict: _description_
Dict: Context dictionary
"""

page = _get_page(title=name, **kwargs)
Expand All @@ -182,21 +181,37 @@ def get(self, name: str, selector: Selector = None, include: List = None, exclud
'subtopic': section_title,
'content': content,
'internal_links': list(filter(lambda x: x not in exclude, page.sections)),
'external_links': most_relevant_links(page, num_links=self.max_links),
'external_links': most_relevant_links(page, num_links=self.max_links),
'tags': filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES),
'source': 'Wikipedia',
'extra': {'url': page.url, 'page_length': len(page.content.split()), 'section_length': section_length},
}

def search(self, name, results=3, selector: Selector = None) -> Dict:
async def get_pages(self, titles, selector):
    """Load the context for every title concurrently, one worker thread per title.

    Args:
        titles: Iterable of Wikipedia page titles to load.
        selector: Selector forwarded to ``self.get`` for each title.

    Returns:
        List with the result of ``self.get`` for each title, in the same
        order as ``titles``.
    """
    fetch = partial(self.get, selector=selector)
    # self.get is a synchronous method, so each call is handed to a worker
    # thread and all of them are awaited together.
    pending = [asyncio.to_thread(fetch, title) for title in titles]
    return await asyncio.gather(*pending)

def search(self, name, results=3, selector: Selector = None, threaded=False) -> Dict:
titles = _wiki_search(name, results=results)
title = selector(titles)
return self.get(title, selector=selector)
if threaded:
pages = asyncio.run(self.get_pages(titles, selector=selector))
return selector(pages)
else:
title = selector(titles)
return self.get(title, selector=selector)


def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Dict:
def random(self, pages=10, seed=None, selector: Selector = None, threaded=False, **kwargs) -> Dict:
titles = wiki.random(pages=pages) if seed is None else _get_random_titles(pages=pages, seed=seed)
title = selector(titles)
return self.get(title, selector=selector)
if threaded:
pages = asyncio.run(self.get_pages(titles, selector=selector))
return selector(pages)
else:
title = selector(titles)
return self.get(title, selector=selector)



Expand Down Expand Up @@ -228,7 +243,19 @@ def _random_date(self, year: int = None, month: int = None) -> int:
# Step 2: Format the date for Wikipedia URL
return random_date.strftime("%B_%d") # E.g., "January_01"

def get(self, name, pageid=None, auto_suggest=True, redirect=True, selector: Selector = None) -> Dict:
def get(self, name: str, pageid: int=None, auto_suggest: bool=True, redirect: bool=True, selector: Selector = None) -> Dict:
"""Get a specified Wikipedia date page and extract an event based on the selector.

Args:
name (str): Title of the Wikipedia page.
pageid (int, optional): ID of the Wikipedia page. Defaults to None.
auto_suggest (bool, optional): Whether to use auto-suggest. Defaults to True.
redirect (bool, optional): Whether to follow redirects. Defaults to True.
selector (Selector, optional): Selector to choose a section. Defaults to None.

Returns:
Dict: Context dictionary
"""

# Check that name is correctly formatted e.g., "January_01"
date = name.split('_')
Expand Down
Loading