Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistently allow pathlib.Path in all relevant public dataset methods #4818

Merged
merged 3 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes/newsfragments/4818.improved
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
All public methods available in ``qcodes.dataset`` that take a
path to a db or data file now accept a ``pathlib.Path`` object in addition to a ``str``.
12 changes: 8 additions & 4 deletions qcodes/dataset/database_extract_runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from pathlib import Path
from warnings import warn

import numpy as np
Expand All @@ -22,10 +23,13 @@
)


def extract_runs_into_db(source_db_path: str,
target_db_path: str, *run_ids: int,
upgrade_source_db: bool = False,
upgrade_target_db: bool = False) -> None:
def extract_runs_into_db(
source_db_path: str | Path,
target_db_path: str | Path,
*run_ids: int,
upgrade_source_db: bool = False,
upgrade_target_db: bool = False,
) -> None:
"""
Extract a selection of runs into another DB file. All runs must come from
the same experiment. They will be added to an experiment with the same name
Expand Down
5 changes: 3 additions & 2 deletions qcodes/dataset/legacy_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -77,7 +78,7 @@ def store_array_to_database_alt(meas: Measurement, array: DataArray) -> int:
return datasaver.run_id


def import_dat_file(location: str, exp: Experiment | None = None) -> list[int]:
def import_dat_file(location: str | Path, exp: Experiment | None = None) -> list[int]:
"""
This imports a QCoDeS legacy :class:`qcodes.data.data_set.DataSet`
into the database.
Expand All @@ -90,7 +91,7 @@ def import_dat_file(location: str, exp: Experiment | None = None) -> list[int]:
"""


loaded_data = load_data(location)
loaded_data = load_data(str(location))
meas = setup_measurement(loaded_data,
exp=exp)
run_ids = []
Expand Down
12 changes: 7 additions & 5 deletions qcodes/dataset/sqlite/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def connect(name: str | Path, debug: bool = False, version: int = -1) -> Connect
return conn


def get_db_version_and_newest_available_version(path_to_db: str) -> tuple[int, int]:
def get_db_version_and_newest_available_version(
path_to_db: str | Path,
) -> tuple[int, int]:
"""
Connect to a DB without performing any upgrades and get the version of
that database file along with the newest available version (the one that
Expand Down Expand Up @@ -240,7 +242,7 @@ def set_journal_mode(conn: ConnectionPlus, journal_mode: JournalMode) -> None:


def initialise_or_create_database_at(
db_file_with_abs_path: str, journal_mode: JournalMode | None = "WAL"
db_file_with_abs_path: str | Path, journal_mode: JournalMode | None = "WAL"
) -> None:
"""
This function sets up QCoDeS to refer to the given database file. If the
Expand All @@ -254,12 +256,12 @@ def initialise_or_create_database_at(
Options are DELETE, TRUNCATE, PERSIST, MEMORY, WAL and OFF. If set to None
no changes are made.
"""
qcodes.config.core.db_location = db_file_with_abs_path
qcodes.config.core.db_location = str(db_file_with_abs_path)
initialise_database(journal_mode)


@contextmanager
def initialised_database_at(db_file_with_abs_path: str) -> Iterator[None]:
def initialised_database_at(db_file_with_abs_path: str | Path) -> Iterator[None]:
"""
Initializes or creates a database and restores the 'db_location' afterwards.

Expand All @@ -277,7 +279,7 @@ def initialised_database_at(db_file_with_abs_path: str) -> Iterator[None]:


def conn_from_dbpath_or_conn(
conn: ConnectionPlus | None, path_to_db: str | None
conn: ConnectionPlus | None, path_to_db: str | Path | None
) -> ConnectionPlus:
"""
A small helper function to abstract the logic needed for functions
Expand Down
36 changes: 36 additions & 0 deletions qcodes/tests/dataset/measurement/test_load_legacy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,39 @@ def test_load_legacy_files_1d():
"loop",
"station",
]


@pytest.mark.usefixtures("experiment")
def test_load_legacy_files_1d_pathlib_path():
    """Importing a legacy ``.dat`` file works when the location is a pathlib.Path."""
    fixture_dir = Path(__file__).parent.parent / "fixtures"
    dat_location = fixture_dir.joinpath(
        "data_2018_01_17", "data_001_testsweep_15_42_57"
    )

    first_run_id = import_dat_file(dat_location)[0]
    data = load_by_id(first_run_id)

    assert isinstance(data, DataSet)
    assert data.parameters == "dac_ch1_set,dmm_voltage"
    assert data.number_of_results == 201

    # (name, label, unit, depends_on) per parameter, in registration order.
    expected = [
        ("dac_ch1_set", "Gate ch1", "V", ""),
        ("dmm_voltage", "Gate voltage", "V", "dac_ch1_set"),
    ]
    for parameter, (name, label, unit, depends_on) in zip(
        data.get_parameters(), expected
    ):
        assert parameter.name == name
        assert parameter.label == label
        assert parameter.unit == unit
        assert parameter.depends_on == depends_on
        assert parameter.type == "numeric"

    snapshot = json.loads(data.get_metadata("snapshot"))
    assert sorted(snapshot.keys()) == [
        "__class__",
        "arrays",
        "formatter",
        "io",
        "location",
        "loop",
        "station",
    ]
10 changes: 10 additions & 0 deletions qcodes/tests/dataset/test_database_creation_and_upgrading.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def test_initialise_database_at_for_nonexisting_db(tmp_path):
assert qc.config["core"]["db_location"] == db_location


def test_initialise_database_at_for_nonexisting_db_pathlib_path(tmp_path):
    """A ``pathlib.Path`` to a not-yet-existing DB file is accepted and the file is created."""
    db_path = tmp_path / "temp.db"

    # Precondition: the target DB file must not exist yet.
    assert not db_path.exists()

    initialise_or_create_database_at(db_path)

    # The DB file was created and the config records its location as a str.
    assert db_path.exists()
    assert qc.config["core"]["db_location"] == str(db_path)


def test_initialise_database_at_for_existing_db(tmp_path):
# Define DB location
db_location = str(tmp_path / 'temp.db')
Expand Down
26 changes: 26 additions & 0 deletions qcodes/tests/dataset/test_database_extract_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,32 @@ def test_real_dataset_1d(two_empty_temp_db_connections, inst):
assert_array_equal(source_data, target_data)


def test_real_dataset_1d_pathlib_path(two_empty_temp_db_connections, inst):
    """``extract_runs_into_db`` accepts ``pathlib.Path`` DB locations.

    Mirrors ``test_real_dataset_1d`` but passes the source/target DB
    locations as ``pathlib.Path`` objects instead of ``str``.
    """
    source_conn, target_conn = two_empty_temp_db_connections

    source_path = Path(path_to_dbfile(source_conn))
    target_path = Path(path_to_dbfile(target_conn))

    source_exp = load_or_create_experiment(experiment_name="myexp", conn=source_conn)

    source_dataset, _, _ = do1d(inst.back, 0, 1, 10, 0, inst.plunger, exp=source_exp)

    extract_runs_into_db(source_path, target_path, source_dataset.run_id)

    target_dataset = load_by_guid(source_dataset.guid, conn=target_conn)

    assert source_dataset.the_same_dataset_as(target_dataset)
    # explicit regression test for https://github.com/QCoDeS/Qcodes/issues/3953
    assert source_dataset.description.shapes == {"extract_run_inst_plunger": (10,)}
    assert source_dataset.description.shapes == target_dataset.description.shapes

    source_data = source_dataset.get_parameter_data()["extract_run_inst_plunger"]
    target_data = target_dataset.get_parameter_data()["extract_run_inst_plunger"]

    # Use distinct loop-variable names: the original reused ``source_data``/
    # ``target_data`` as loop variables, shadowing the dicts being iterated.
    for source_vals, target_vals in zip(source_data.values(), target_data.values()):
        assert_array_equal(source_vals, target_vals)


def test_real_dataset_2d(two_empty_temp_db_connections, inst):
source_conn, target_conn = two_empty_temp_db_connections

Expand Down