Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand CI tests using matrix; make dependencies less restrictive; fix ONNX tests #233

Merged
merged 9 commits into from
Dec 14, 2022
31 changes: 24 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,34 @@ jobs:

test_sampling:
name: Run unit tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
os: [ubuntu-latest, windows-latest]
requirements: ['.[tests]', '.[compat_tests]']
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Setup Python environment
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Try to load cached dependencies
uses: actions/cache@v3
id: restore-cache
with:
python-version: 3.7
- name: Install dependencies
path: ${{ env.pythonLocation }}
key: python-dependencies-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.requirements }}-${{ hashFiles('setup.py') }}-${{ env.pythonLocation }}

- name: Install dependencies on cache miss
run: |
python -m pip install --upgrade pip
python -m pip install ".[tests]"
python -m pip install --no-cache-dir --upgrade pip
python -m pip install --no-cache-dir ${{ matrix.requirements }}
if: steps.restore-cache.outputs.cache-hit != 'true'

- name: Run unit tests
run: pytest -sv tests/
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@

INTEGRATIONS_REQUIRE = ["optuna"]

REQUIRED_PKGS = ["datasets==2.3.2", "sentence-transformers==2.2.2", "evaluate==0.3.0"]
REQUIRED_PKGS = ["datasets>=2.3.0", "sentence-transformers>=2.2.1", "evaluate>=0.3.0"]

QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]

ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"]

TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE

EXTRAS_REQUIRE = {"optuna": INTEGRATIONS_REQUIRE, "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, "onnx": ONNX_REQUIRE}
COMPAT_TESTS_REQUIRE = [requirement.replace(">=", "==") for requirement in REQUIRED_PKGS] + TESTS_REQUIRE

EXTRAS_REQUIRE = {
"optuna": INTEGRATIONS_REQUIRE,
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"onnx": ONNX_REQUIRE,
"compat_tests": COMPAT_TESTS_REQUIRE,
}


def combine_requirements(base_keys):
Expand Down
6 changes: 5 additions & 1 deletion src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,15 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
for output_name in output_names:
dynamic_axes_output[output_name] = {0: "batch_size"}

# Move inputs to the right device
target = setfit_model.model_body.device
args = tuple(value.to(target) for value in inputs.values())

setfit_model.eval()
with torch.no_grad():
torch.onnx.export(
setfit_model,
args=tuple(inputs.values()),
args=args,
f=output_path,
opset_version=opset,
input_names=["input_ids", "attention_mask", "token_type_ids"],
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_export_onnx_sklearn_head():
return_token_type_ids=True,
return_tensors="np",
)
# Map inputs to int64 from int32
inputs = {key: value.astype("int64") for key, value in inputs.items()}

session = onnxruntime.InferenceSession(output_path)

Expand Down