Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torchvision import #1064

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ jobs:
twine upload --repository pypi dist/* --verbose -u __token__ -p ${{ secrets.PYPI_API_KEY }}
- name: Post publish import test
run: |
pip install mct-nightly
pip install mct-nightly tensorflow torch
version=$(python -c 'import model_compression_toolkit; print(model_compression_toolkit.__version__)')
echo $version
1 change: 1 addition & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PYTORCH = 'pytorch'
FOUND_TF = importlib.util.find_spec(TENSORFLOW) is not None
FOUND_TORCH = importlib.util.find_spec("torch") is not None
FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
FOUND_ONNX = importlib.util.find_spec("onnx") is not None
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/data_generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF
from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF, FOUND_TORCHVISION
from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType

if FOUND_TF:
from model_compression_toolkit.data_generation.keras.keras_data_generation import (
keras_data_generation_experimental, get_keras_data_generation_config)

if FOUND_TORCH:
if FOUND_TORCH and FOUND_TORCHVISION:
from model_compression_toolkit.data_generation.pytorch.pytorch_data_generation import (
pytorch_data_generation_experimental, get_pytorch_data_generation_config)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from tqdm import tqdm

from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TORCHVISION
from model_compression_toolkit.core.pytorch.utils import set_model
from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
Expand All @@ -44,7 +44,7 @@
from model_compression_toolkit.data_generation.pytorch.optimization_utils import PytorchImagesOptimizationHandler
from model_compression_toolkit.logger import Logger

if FOUND_TORCH:
if FOUND_TORCH and FOUND_TORCHVISION:
# Importing necessary libraries
import torch
from torch import Tensor
Expand Down Expand Up @@ -354,10 +354,9 @@ def data_generation(
# If torch is not installed,
# we raise an exception when trying to use these functions.
def get_pytorch_data_generation_config(*args, **kwargs):
Logger.critical('PyTorch must be installed to use get_pytorch_data_generation_config. '
"The 'torch' package is missing.") # pragma: no cover

msg = f"PyTorch and torchvision must be installed to use get_pytorch_data_generation_config. " + ("" if FOUND_TORCH else "The 'torch' package is missing. ") + ("" if FOUND_TORCHVISION else "The 'torchvision' package is missing. ") # pragma: no cover
Logger.critical(msg) # pragma: no cover

def pytorch_data_generation_experimental(*args, **kwargs):
Logger.critical("PyTorch must be installed to use 'pytorch_data_generation_experimental'. "
"The 'torch' package is missing.") # pragma: no cover
msg = f"PyTorch and torchvision must be installed to use pytorch_data_generation_experimental. " + ("" if FOUND_TORCH else "The 'torch' package is missing. ") + ("" if FOUND_TORCHVISION else "The 'torchvision' package is missing. ") # pragma: no cover
Logger.critical(msg) # pragma: no cover
Loading