Skip to content

Commit

Permalink
Merge pull request #4818 from jenshnielsen/ds_path_support
Browse files Browse the repository at this point in the history
Consistently allow pathlib.Path in all relevant public dataset methods
  • Loading branch information
jenshnielsen authored Nov 24, 2022
2 parents 8bf3019 + b521c21 commit 1670bcb
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/changes/newsfragments/4818.improved
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
All public methods avilable in ``qcodes.dataset`` that takes a
path to a db or data file now accepts a ``pathlib.Path`` object in addition to a ``str``.
12 changes: 8 additions & 4 deletions qcodes/dataset/database_extract_runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from pathlib import Path
from warnings import warn

import numpy as np
Expand All @@ -22,10 +23,13 @@
)


def extract_runs_into_db(source_db_path: str,
target_db_path: str, *run_ids: int,
upgrade_source_db: bool = False,
upgrade_target_db: bool = False) -> None:
def extract_runs_into_db(
source_db_path: str | Path,
target_db_path: str | Path,
*run_ids: int,
upgrade_source_db: bool = False,
upgrade_target_db: bool = False,
) -> None:
"""
Extract a selection of runs into another DB file. All runs must come from
the same experiment. They will be added to an experiment with the same name
Expand Down
5 changes: 3 additions & 2 deletions qcodes/dataset/legacy_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -77,7 +78,7 @@ def store_array_to_database_alt(meas: Measurement, array: DataArray) -> int:
return datasaver.run_id


def import_dat_file(location: str, exp: Experiment | None = None) -> list[int]:
def import_dat_file(location: str | Path, exp: Experiment | None = None) -> list[int]:
"""
This imports a QCoDeS legacy :class:`qcodes.data.data_set.DataSet`
into the database.
Expand All @@ -90,7 +91,7 @@ def import_dat_file(location: str, exp: Experiment | None = None) -> list[int]:
"""


loaded_data = load_data(location)
loaded_data = load_data(str(location))
meas = setup_measurement(loaded_data,
exp=exp)
run_ids = []
Expand Down
12 changes: 7 additions & 5 deletions qcodes/dataset/sqlite/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def connect(name: str | Path, debug: bool = False, version: int = -1) -> Connect
return conn


def get_db_version_and_newest_available_version(path_to_db: str) -> tuple[int, int]:
def get_db_version_and_newest_available_version(
path_to_db: str | Path,
) -> tuple[int, int]:
"""
Connect to a DB without performing any upgrades and get the version of
that database file along with the newest available version (the one that
Expand Down Expand Up @@ -240,7 +242,7 @@ def set_journal_mode(conn: ConnectionPlus, journal_mode: JournalMode) -> None:


def initialise_or_create_database_at(
db_file_with_abs_path: str, journal_mode: JournalMode | None = "WAL"
db_file_with_abs_path: str | Path, journal_mode: JournalMode | None = "WAL"
) -> None:
"""
This function sets up QCoDeS to refer to the given database file. If the
Expand All @@ -254,12 +256,12 @@ def initialise_or_create_database_at(
Options are DELETE, TRUNCATE, PERSIST, MEMORY, WAL and OFF. If set to None
no changes are made.
"""
qcodes.config.core.db_location = db_file_with_abs_path
qcodes.config.core.db_location = str(db_file_with_abs_path)
initialise_database(journal_mode)


@contextmanager
def initialised_database_at(db_file_with_abs_path: str) -> Iterator[None]:
def initialised_database_at(db_file_with_abs_path: str | Path) -> Iterator[None]:
"""
Initializes or creates a database and restores the 'db_location' afterwards.
Expand All @@ -277,7 +279,7 @@ def initialised_database_at(db_file_with_abs_path: str) -> Iterator[None]:


def conn_from_dbpath_or_conn(
conn: ConnectionPlus | None, path_to_db: str | None
conn: ConnectionPlus | None, path_to_db: str | Path | None
) -> ConnectionPlus:
"""
A small helper function to abstract the logic needed for functions
Expand Down
36 changes: 36 additions & 0 deletions qcodes/tests/dataset/measurement/test_load_legacy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,39 @@ def test_load_legacy_files_1d():
"loop",
"station",
]


@pytest.mark.usefixtures("experiment")
def test_load_legacy_files_1d_pathlib_path():
full_location = (
Path(__file__).parent.parent
/ "fixtures"
/ "data_2018_01_17"
/ "data_001_testsweep_15_42_57"
)
run_ids = import_dat_file(full_location)
run_id = run_ids[0]
data = load_by_id(run_id)
assert isinstance(data, DataSet)
assert data.parameters == "dac_ch1_set,dmm_voltage"
assert data.number_of_results == 201
expected_names = ["dac_ch1_set", "dmm_voltage"]
expected_labels = ["Gate ch1", "Gate voltage"]
expected_units = ["V", "V"]
expected_depends_on = ["", "dac_ch1_set"]
for i, parameter in enumerate(data.get_parameters()):
assert parameter.name == expected_names[i]
assert parameter.label == expected_labels[i]
assert parameter.unit == expected_units[i]
assert parameter.depends_on == expected_depends_on[i]
assert parameter.type == "numeric"
snapshot = json.loads(data.get_metadata("snapshot"))
assert sorted(list(snapshot.keys())) == [
"__class__",
"arrays",
"formatter",
"io",
"location",
"loop",
"station",
]
10 changes: 10 additions & 0 deletions qcodes/tests/dataset/test_database_creation_and_upgrading.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def test_initialise_database_at_for_nonexisting_db(tmp_path):
assert qc.config["core"]["db_location"] == db_location


def test_initialise_database_at_for_nonexisting_db_pathlib_path(tmp_path):
db_location = tmp_path / "temp.db"
assert not db_location.exists()

initialise_or_create_database_at(db_location)

assert db_location.exists()
assert qc.config["core"]["db_location"] == str(db_location)


def test_initialise_database_at_for_existing_db(tmp_path):
# Define DB location
db_location = str(tmp_path / 'temp.db')
Expand Down
26 changes: 26 additions & 0 deletions qcodes/tests/dataset/test_database_extract_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,32 @@ def test_real_dataset_1d(two_empty_temp_db_connections, inst):
assert_array_equal(source_data, target_data)


def test_real_dataset_1d_pathlib_path(two_empty_temp_db_connections, inst):
source_conn, target_conn = two_empty_temp_db_connections

source_path = Path(path_to_dbfile(source_conn))
target_path = Path(path_to_dbfile(target_conn))

source_exp = load_or_create_experiment(experiment_name="myexp", conn=source_conn)

source_dataset, _, _ = do1d(inst.back, 0, 1, 10, 0, inst.plunger, exp=source_exp)

extract_runs_into_db(source_path, target_path, source_dataset.run_id)

target_dataset = load_by_guid(source_dataset.guid, conn=target_conn)

assert source_dataset.the_same_dataset_as(target_dataset)
# explicit regression test for https://github.com/QCoDeS/Qcodes/issues/3953
assert source_dataset.description.shapes == {"extract_run_inst_plunger": (10,)}
assert source_dataset.description.shapes == target_dataset.description.shapes

source_data = source_dataset.get_parameter_data()["extract_run_inst_plunger"]
target_data = target_dataset.get_parameter_data()["extract_run_inst_plunger"]

for source_data, target_data in zip(source_data.values(), target_data.values()):
assert_array_equal(source_data, target_data)


def test_real_dataset_2d(two_empty_temp_db_connections, inst):
source_conn, target_conn = two_empty_temp_db_connections

Expand Down

0 comments on commit 1670bcb

Please sign in to comment.