Skip to content

Commit

Permalink
update path handling
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Mar 8, 2025
1 parent 7c7e60e commit 3bfad4c
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions ch02/05_bpe-from-scratch/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,47 @@
import tiktoken


def import_definitions_from_notebook(fullname, names):
current_dir = os.getcwd()
path = os.path.join(current_dir, fullname + ".ipynb")
path = os.path.normpath(path)
def import_definitions_from_notebook(notebooks):
imported_modules = {}

if not os.path.exists(path):
raise FileNotFoundError(f"Notebook file not found at: {path}")
for fullname, names in notebooks.items():
# Get the directory of the current test file
current_dir = os.path.dirname(__file__)
path = os.path.join(current_dir, "..", fullname + ".ipynb")
path = os.path.normpath(path)

with io.open(path, "r", encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)
# Load the notebook
if not os.path.exists(path):
raise FileNotFoundError(f"Notebook file not found at: {path}")

mod = types.ModuleType(fullname)
sys.modules[fullname] = mod
with io.open(path, "r", encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)

for cell in nb.cells:
if cell.cell_type == "code":
cell_code = cell.source
for name in names:
if f"def {name}" in cell_code or f"class {name}" in cell_code:
exec(cell_code, mod.__dict__)
return mod
# Create a module to store the imported functions and classes
mod = types.ModuleType(fullname)
sys.modules[fullname] = mod

# Go through the notebook cells and only execute function or class definitions
for cell in nb.cells:
if cell.cell_type == "code":
cell_code = cell.source
for name in names:
# Check for function or class definitions
if f"def {name}" in cell_code or f"class {name}" in cell_code:
exec(cell_code, mod.__dict__)

imported_modules[fullname] = mod

return imported_modules


@pytest.fixture(scope="module")
def imported_module():
fullname = "bpe-from-scratch"
names = ["BPETokenizerSimple", "download_file_if_absent"]
return import_definitions_from_notebook(fullname, names)
notebooks = {
"bpe-from-scratch": ["BPETokenizerSimple", "download_file_if_absent"],
}

return import_definitions_from_notebook(notebooks)["bpe-from-scratch"]


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 3bfad4c

Please sign in to comment.