From 9d7324155ca47ae939a795df966f0fd8ed02c777 Mon Sep 17 00:00:00 2001 From: Byeongman Lee Date: Mon, 10 Feb 2025 23:56:47 +0900 Subject: [PATCH] Add create convert task api (#444) --- app/api/api.py | 5 +- app/api/v1/endpoints/conversion_task.py | 63 ++++++ .../{train_task.py => training_task.py} | 4 +- app/api/v1/schemas/project.py | 1 + .../v1/schemas/task/conversion/__init__.py | 0 .../task/conversion/conversion_task.py | 73 +++++++ app/api/v1/schemas/task/conversion/device.py | 81 +++++++ .../v1/schemas/task/train/hyperparameter.py | 9 +- app/api/v1/schemas/task/train/train_task.py | 16 +- app/services/conversion_task.py | 205 ++++++++++++++++++ app/services/model.py | 43 +++- app/services/project.py | 2 +- .../{train_task.py => training_task.py} | 78 +++---- app/worker/__init__.py | 0 app/worker/celery_app.py | 80 +++++++ .../clients/compressor/v2/schemas/common.py | 4 +- .../clients/launcher/v2/schemas/common.py | 6 +- netspresso/constant/project.py | 2 +- netspresso/converter/v2/converter.py | 141 +++++++++--- netspresso/enums/__init__.py | 5 +- netspresso/enums/conversion.py | 53 +++++ netspresso/enums/device.py | 109 +++++++++- netspresso/enums/model.py | 15 ++ netspresso/enums/project.py | 1 + netspresso/netspresso.py | 26 +++ netspresso/trainer/trainer.py | 55 +++-- netspresso/utils/db/models/__init__.py | 10 +- netspresso/utils/db/models/conversion.py | 44 ++++ netspresso/utils/db/models/model.py | 22 +- netspresso/utils/db/models/project.py | 10 +- .../utils/db/models/{train.py => training.py} | 34 +-- .../utils/db/repositories/conversion.py | 55 +++++ netspresso/utils/db/repositories/model.py | 14 +- .../db/repositories/{task.py => training.py} | 15 +- 34 files changed, 1111 insertions(+), 170 deletions(-) create mode 100644 app/api/v1/endpoints/conversion_task.py rename app/api/v1/endpoints/{train_task.py => training_task.py} (96%) create mode 100644 app/api/v1/schemas/task/conversion/__init__.py create mode 100644 app/api/v1/schemas/task/conversion/conversion_task.py create mode 100644 app/api/v1/schemas/task/conversion/device.py create mode 100644 app/services/conversion_task.py rename app/services/{train_task.py => training_task.py} (79%) create mode 100644 app/worker/__init__.py create mode 100644 app/worker/celery_app.py create mode 100644 netspresso/enums/conversion.py create mode 100644 netspresso/utils/db/models/conversion.py rename netspresso/utils/db/models/{train.py => training.py} (79%) create mode 100644 netspresso/utils/db/repositories/conversion.py rename netspresso/utils/db/repositories/{task.py => training.py} (53%) diff --git a/app/api/api.py b/app/api/api.py index 3553a3dd..6fafe00e 100644 --- a/app/api/api.py +++ b/app/api/api.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.v1.endpoints import model, project, system, train_task, user +from app.api.v1.endpoints import conversion_task, model, project, system, training_task, user api_router = APIRouter() api_router.include_router(user.router, prefix="/users", tags=["user"]) api_router.include_router(project.router, prefix="/projects", tags=["project"]) api_router.include_router(model.router, prefix="/models", tags=["model"]) -api_router.include_router(train_task.router, prefix="/tasks", tags=["task"]) +api_router.include_router(training_task.router, prefix="/tasks", tags=["training"]) +api_router.include_router(conversion_task.router, prefix="/tasks", tags=["conversion"]) api_router.include_router(system.router, prefix="/system", tags=["system"]) diff --git a/app/api/v1/endpoints/conversion_task.py b/app/api/v1/endpoints/conversion_task.py new file mode 100644 index 00000000..6bde78c1 --- /dev/null +++ b/app/api/v1/endpoints/conversion_task.py @@ -0,0 +1,63 @@ +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +from app.api.deps import api_key_header +from app.api.v1.schemas.task.conversion.conversion_task import ( + ConversionCreate, + ConversionCreateResponse, + ConversionResponse, + SupportedDevicesResponse, +) +from app.services.conversion_task import conversion_task_service +from netspresso.enums.conversion import SourceFramework +from netspresso.utils.db.session import get_db + +router = APIRouter() + + +@router.get( + "/conversions/configuration/devices", + response_model=SupportedDevicesResponse, + description="Get supported devices and frameworks for model conversion based on the source framework.", +) +def get_supported_conversion_devices( + framework: SourceFramework = Query(..., description="Source framework of the model to be converted."), + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), +) -> SupportedDevicesResponse: + supported_devices = conversion_task_service.get_supported_devices(db=db, framework=framework, api_key=api_key) + + return SupportedDevicesResponse(data=supported_devices) + + +@router.post("/conversions", response_model=ConversionCreateResponse, status_code=201) +def create_conversions_task( + request_body: ConversionCreate, + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), +) -> ConversionCreateResponse: + conversion_task = conversion_task_service.create_conversion_task(db=db, conversion_in=request_body, api_key=api_key) + + return ConversionCreateResponse(data=conversion_task) + + +@router.get("/conversions/{task_id}", response_model=ConversionResponse) +def get_conversions_task( + task_id: str, + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), +) -> ConversionResponse: + conversion_task = conversion_task_service.get_conversion_task(db=db, task_id=task_id, api_key=api_key) + + return ConversionResponse(data=conversion_task) + + +@router.post("/conversions/{task_id}/cancel", response_model=ConversionResponse) +def cancel_conversion_task( + task_id: str, + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), +) -> ConversionResponse: + conversion_task = conversion_task_service.cancel_conversion_task(db=db, task_id=task_id, api_key=api_key) + + return ConversionResponse(data=conversion_task) diff --git a/app/api/v1/endpoints/train_task.py b/app/api/v1/endpoints/training_task.py similarity index 96% rename from app/api/v1/endpoints/train_task.py rename to app/api/v1/endpoints/training_task.py index c91b81ca..b2f6c470 100644 --- a/app/api/v1/endpoints/train_task.py +++ b/app/api/v1/endpoints/training_task.py @@ -1,5 +1,3 @@ -from typing import List - from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -10,7 +8,7 @@ SupportedSchedulersResponse, ) from app.api.v1.schemas.task.train.train_task import TrainingCreate, TrainingResponse -from app.services.train_task import train_task_service +from app.services.training_task import train_task_service from netspresso.utils.db.session import get_db router = APIRouter() diff --git a/app/api/v1/schemas/project.py b/app/api/v1/schemas/project.py index d34f7364..a438ee8a 100644 --- a/app/api/v1/schemas/project.py +++ b/app/api/v1/schemas/project.py @@ -28,6 +28,7 @@ class ProjectPayload(ProjectCreate): project_id: str = Field(..., description="The unique identifier for the project.") model_ids: List[str] = Field(default_factory=list, description="The list of models associated with the project.") user_id: str = Field(..., description="The unique identifier for the user associated with the project.") + project_abs_path: str = Field(..., description="The absolute path of the project.") created_at: datetime = Field(..., description="The timestamp when the project was created.") updated_at: datetime = Field(..., description="The timestamp when the project was last updated.") diff --git a/app/api/v1/schemas/task/conversion/__init__.py b/app/api/v1/schemas/task/conversion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/api/v1/schemas/task/conversion/conversion_task.py b/app/api/v1/schemas/task/conversion/conversion_task.py new file mode 100644 index 00000000..7d9c0513 --- /dev/null +++ b/app/api/v1/schemas/task/conversion/conversion_task.py @@ -0,0 +1,73 @@ +from datetime import datetime +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from app.api.v1.schemas.base import ResponseItem +from app.api.v1.schemas.task.conversion.device import ( + PrecisionPayload, + SoftwareVersionPayload, + SupportedDevicePayload, + TargetDevicePayload, +) +from netspresso.enums.conversion import TARGET_FRAMEWORK_DISPLAY_MAP, Precision, TargetFramework, TargetFrameworkDisplay +from netspresso.enums.device import DeviceName, SoftwareVersion +from netspresso.enums.model import DataType + + +class TargetFrameworkPayload(BaseModel): + name: TargetFramework = Field(description="Framework name") + display_name: Optional[TargetFrameworkDisplay] = Field(default=None, description="Framework display name") + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = TARGET_FRAMEWORK_DISPLAY_MAP.get(self.name) + + return self + + +class SupportedDeviceResponse(BaseModel): + framework: TargetFrameworkPayload + devices: List[SupportedDevicePayload] + + +class SupportedDevicesResponse(BaseModel): + data: List[SupportedDeviceResponse] + + +class ConversionCreate(BaseModel): + input_model_id: str = Field(description="Input model ID") + framework: TargetFramework = Field(description="Framework name") + device_name: DeviceName = Field(description="Device name") + software_version: Optional[SoftwareVersion] = Field(default=None, description="Software version") + precision: Precision = Field(description="Precision") + calibration_dataset_path: Optional[str] = Field(default=None, description="Path to the calibration dataset") + + +class ConversionPayload(BaseModel): + model_config = ConfigDict(from_attributes=True) + + task_id: str + model_id: Optional[str] = None + framework: TargetFrameworkPayload + device: TargetDevicePayload + software_version: Optional[SoftwareVersionPayload] = None + precision: PrecisionPayload + status: str + is_deleted: bool + error_detail: Optional[Dict] = None + input_model_id: str + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +class ConversionCreatePayload(BaseModel): + task_id: str + + +class ConversionCreateResponse(ResponseItem): + data: ConversionCreatePayload + + +class ConversionResponse(ResponseItem): + data: ConversionPayload diff --git a/app/api/v1/schemas/task/conversion/device.py b/app/api/v1/schemas/task/conversion/device.py new file mode 100644 index 00000000..e3ab23ff --- /dev/null +++ b/app/api/v1/schemas/task/conversion/device.py @@ -0,0 +1,81 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field, model_validator + +from netspresso.enums.conversion import PRECISION_DISPLAY_MAP, Precision, PrecisionDisplay +from netspresso.enums.device import ( + DEVICE_BRAND_MAP, + DEVICE_DISPLAY_MAP, + HARDWARE_TYPE_DISPLAY_MAP, + SOFTWARE_VERSION_DISPLAY_MAP, + DeviceBrand, + DeviceDisplay, + DeviceName, + HardwareType, + HardwareTypeDisplay, + SoftwareVersion, + SoftwareVersionDisplay, +) + + +class SoftwareVersionPayload(BaseModel): + name: SoftwareVersion + display_name: Optional[SoftwareVersionDisplay] = Field(default=None, description="Software version display name") + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = SOFTWARE_VERSION_DISPLAY_MAP.get(self.name) + + return self + + +class PrecisionPayload(BaseModel): + name: Precision + display_name: Optional[PrecisionDisplay] = Field(default=None, description="Precision display name") + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = PRECISION_DISPLAY_MAP.get(self.name) + + return self + + +class HardwareTypePayload(BaseModel): + name: HardwareType + display_name: Optional[HardwareTypeDisplay] = Field(default=None, description="Hardware type display name") + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = HARDWARE_TYPE_DISPLAY_MAP.get(self.name) + + return self + + +class SupportedDevicePayload(BaseModel): + name: DeviceName + display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name") + brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name") + software_versions: List[SoftwareVersionPayload] + precisions: List[PrecisionPayload] + hardware_types: List[HardwareTypePayload] + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = DEVICE_DISPLAY_MAP.get(self.name) + self.brand_name = DEVICE_BRAND_MAP.get(self.name) + + return self + + + +class TargetDevicePayload(BaseModel): + name: DeviceName + display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name") + brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name") + + @model_validator(mode="after") + def set_display_name(self) -> str: + self.display_name = DEVICE_DISPLAY_MAP.get(self.name) + self.brand_name = DEVICE_BRAND_MAP.get(self.name) + + return self diff --git a/app/api/v1/schemas/task/train/hyperparameter.py b/app/api/v1/schemas/task/train/hyperparameter.py index 937e00fa..b0ead13f 100644 --- a/app/api/v1/schemas/task/train/hyperparameter.py +++ b/app/api/v1/schemas/task/train/hyperparameter.py @@ -2,7 +2,14 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator -from netspresso.enums.train import OPTIMIZER_DISPLAY_MAP, SCHEDULER_DISPLAY_MAP, Optimizer, OptimizerDisplay, Scheduler, SchedulerDisplay +from netspresso.enums.train import ( + OPTIMIZER_DISPLAY_MAP, + SCHEDULER_DISPLAY_MAP, + Optimizer, + OptimizerDisplay, + Scheduler, + SchedulerDisplay, +) class TrainerModel(BaseModel): diff --git a/app/api/v1/schemas/task/train/train_task.py b/app/api/v1/schemas/task/train/train_task.py index fcca81ac..f6e62b0c 100644 --- a/app/api/v1/schemas/task/train/train_task.py +++ b/app/api/v1/schemas/task/train/train_task.py @@ -5,19 +5,17 @@ from app.api.v1.schemas.base import ResponseItem from netspresso.enums.train import ( - Framework, - FrameworkDisplay, - Task, - TaskDisplay, - PretrainedModel, - PretrainedModelDisplay, - PretrainedModelGroup, - TASK_DISPLAY_MAP, FRAMEWORK_DISPLAY_MAP, MODEL_DISPLAY_MAP, MODEL_GROUP_MAP, TASK_DISPLAY_MAP, - FRAMEWORK_DISPLAY_MAP, + Framework, + FrameworkDisplay, + PretrainedModel, + PretrainedModelDisplay, + PretrainedModelGroup, + Task, + TaskDisplay, ) from .dataset import DatasetCreate, DatasetPayload diff --git a/app/services/conversion_task.py b/app/services/conversion_task.py new file mode 100644 index 00000000..31825d3d --- /dev/null +++ b/app/services/conversion_task.py @@ -0,0 +1,205 @@ +from pathlib import Path +from typing import List + +from sqlalchemy.orm import Session + +from app.api.v1.schemas.task.conversion.conversion_task import ( + ConversionCreate, + ConversionCreatePayload, + ConversionPayload, + SupportedDeviceResponse, + TargetFrameworkPayload, +) +from app.api.v1.schemas.task.conversion.device import ( + HardwareTypePayload, + PrecisionPayload, + SoftwareVersionPayload, + SupportedDevicePayload, + TargetDevicePayload, +) +from app.services.project import project_service +from app.services.user import user_service +from app.worker.celery_app import convert_model_task +from netspresso.clients.launcher.v2.schemas.common import DeviceInfo +from netspresso.enums import Status, TaskStatusForDisplay +from netspresso.enums.conversion import SourceFramework +from netspresso.utils.db.repositories.conversion import conversion_task_repository +from netspresso.utils.db.repositories.model import model_repository + + +class ConversionTaskService: + def get_supported_devices(self, db: Session, framework: SourceFramework, api_key: str) -> List[SupportedDeviceResponse]: + """Get supported devices for conversion tasks. + + Args: + db (Session): Database session + framework (SourceFramework): Framework to get supported devices for + api_key (str): API key for authentication + + Returns: + List[SupportedDeviceResponse]: List of supported devices grouped by framework + """ + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + converter = netspresso.converter_v2() + supported_options = converter.get_supported_options(framework=framework) + + return [ + self._create_supported_device_response(option) + for option in supported_options + ] + + def _create_supported_device_response(self, option) -> SupportedDeviceResponse: + """Create SupportedDeviceResponse from converter option. + + Args: + option: Converter option containing framework and devices information + + Returns: + SupportedDeviceResponse: Response containing framework and supported devices + """ + return SupportedDeviceResponse( + framework=TargetFrameworkPayload(name=option.framework), + devices=[ + self._create_device_payload(device) + for device in option.devices + ] + ) + + def _create_device_payload(self, device: DeviceInfo) -> SupportedDevicePayload: + """Create SupportedDevicePayload from device information. + + Args: + device: Device information containing name, versions, precisions, and hardware types + + Returns: + SupportedDevicePayload: Payload containing device information + """ + return SupportedDevicePayload( + name=device.device_name, + software_versions=[ + SoftwareVersionPayload(name=version.software_version) + for version in device.software_versions + ], + precisions=[ + PrecisionPayload(name=precision) + for precision in device.data_types + ], + hardware_types=[ + HardwareTypePayload(name=hardware_type) + for hardware_type in device.hardware_types + ], + ) + + def create_conversion_task(self, db: Session, conversion_in: ConversionCreate, api_key: str) -> ConversionCreatePayload: + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + + # Get model from trained models repository + model = model_repository.get_by_model_id(db=db, model_id=conversion_in.input_model_id, user_id=netspresso.user_info.user_id) + project = project_service.get_project(db=db, project_id=model.project_id, api_key=api_key) + + # Create output directory path as a 'converted' subfolder of input model path + project_abs_path = Path(project.project_abs_path) + input_model_dir = project_abs_path / model.object_path + + print(f"Input model path: {input_model_dir}") + + # Find .onnx file in the directory + onnx_files = list(input_model_dir.glob('*.onnx')) + if not onnx_files: + raise FileNotFoundError(f"No .onnx file found in directory: {input_model_dir}") + input_model_path = onnx_files[0] # Use the first .onnx file found + + output_dir = input_model_dir / 'converted' + output_dir.mkdir(exist_ok=True) + + task = convert_model_task.delay( + api_key=api_key, + input_model_path=input_model_path.as_posix(), + output_dir=output_dir.as_posix(), + target_framework=conversion_in.framework, + target_device_name=conversion_in.device_name, + target_data_type=conversion_in.precision, + target_software_version=conversion_in.software_version, + input_model_id=conversion_in.input_model_id + ) + task_id = task.get() + return ConversionCreatePayload(task_id=task_id) + + def get_conversion_task(self, db: Session, task_id: str, api_key: str): + conversion_task = conversion_task_repository.get_by_task_id(db, task_id) + + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + converter = netspresso.converter_v2() + + if conversion_task.status == Status.NOT_STARTED or conversion_task.status == Status.IN_PROGRESS: + # Check launcher server status + launcher_status = converter.get_conversion_task(conversion_task.convert_task_uuid) + + if launcher_status.status in [TaskStatusForDisplay.FINISHED]: + conversion_task.status = Status.COMPLETED + elif launcher_status.status in [TaskStatusForDisplay.ERROR, TaskStatusForDisplay.TIMEOUT]: + conversion_task.status = Status.ERROR + conversion_task.error_detail = launcher_status.error_log + elif launcher_status.status in [TaskStatusForDisplay.USER_CANCEL]: + conversion_task.status = Status.STOPPED + + conversion_task = conversion_task_repository.save(db, conversion_task) + + framework = TargetFrameworkPayload(name=conversion_task.framework) + device_name = TargetDevicePayload(name=conversion_task.device_name) + software_version = SoftwareVersionPayload(name=conversion_task.software_version) if conversion_task.software_version else None + precision = PrecisionPayload(name=conversion_task.precision) + + conversion_payload = ConversionPayload( + task_id=conversion_task.task_id, + model_id=conversion_task.model_id, + framework=framework, + device_name=device_name, + software_version=software_version, + precision=precision, + status=conversion_task.status, + is_deleted=conversion_task.is_deleted, + error_detail=conversion_task.error_detail, + input_model_id=conversion_task.input_model_id, + created_at=conversion_task.created_at, + updated_at=conversion_task.updated_at, + ) + + return conversion_payload + + def cancel_conversion_task(self, db: Session, task_id: str, api_key: str): + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + converter = netspresso.converter_v2() + conversion_task = conversion_task_repository.get_by_task_id(db, task_id) + convert_task = converter.cancel_conversion_task(conversion_task.convert_task_uuid) + + if convert_task.status == TaskStatusForDisplay.USER_CANCEL: + conversion_task.status = Status.STOPPED + conversion_task = conversion_task_repository.save(db, conversion_task) + else: + raise ValueError(f"Failed to cancel conversion task: {convert_task.status}") + + framework = TargetFrameworkPayload(name=conversion_task.framework) + device_name = TargetDevicePayload(name=conversion_task.device_name) + software_version = SoftwareVersionPayload(name=conversion_task.software_version) if conversion_task.software_version else None + precision = PrecisionPayload(name=conversion_task.precision) + + conversion_payload = ConversionPayload( + task_id=conversion_task.task_id, + model_id=conversion_task.model_id, + framework=framework, + device_name=device_name, + software_version=software_version, + precision=precision, + status=conversion_task.status, + is_deleted=conversion_task.is_deleted, + error_detail=conversion_task.error_detail, + input_model_id=conversion_task.input_model_id, + created_at=conversion_task.created_at, + updated_at=conversion_task.updated_at, + ) + + return conversion_payload + + +conversion_task_service = ConversionTaskService() diff --git a/app/services/model.py b/app/services/model.py index 8e124fb3..db37e6f7 100644 --- a/app/services/model.py +++ b/app/services/model.py @@ -4,19 +4,38 @@ from app.api.v1.schemas.model import ModelPayload from app.services.user import user_service -from netspresso.utils.db.repositories.model import trained_model_repository +from netspresso.utils.db.repositories.conversion import conversion_task_repository +from netspresso.utils.db.repositories.model import model_repository +from netspresso.utils.db.repositories.training import training_task_repository class ModelService: def get_models(self, db: Session, api_key: str) -> List[ModelPayload]: netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) - models = trained_model_repository.get_all_by_user_id(db=db, user_id=netspresso.user_info.user_id) + models = model_repository.get_all_by_user_id(db=db, user_id=netspresso.user_info.user_id) new_models = [] for model in models: - task_status = model.train_task.status + if model.type == 'converted_models': + continue + + training_task = training_task_repository.get_by_model_id(db=db, model_id=model.model_id) + task_status = training_task.status + model.train_task_id = training_task.task_id + model = ModelPayload.model_validate(model) + + # Get conversion tasks ordered by created_at desc + conversion_tasks = conversion_task_repository.get_all_by_model_id(db=db, model_id=model.model_id) + + if conversion_tasks: + # Set latest experiment status from the most recent conversion task + model.latest_experiments.convert = conversion_tasks[0].status + # Collect all task IDs + for conversion_task in conversion_tasks: + model.convert_task_ids.append(conversion_task.task_id) + model.status = task_status new_models.append(model) @@ -25,9 +44,23 @@ def get_models(self, db: Session, api_key: str) -> List[ModelPayload]: def get_model(self, db: Session, model_id: str, api_key: str) -> ModelPayload: netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) - model = trained_model_repository.get_by_model_id(db=db, model_id=model_id, user_id=netspresso.user_info.user_id) - task_status = model.train_task.status + model = model_repository.get_by_model_id(db=db, model_id=model_id, user_id=netspresso.user_info.user_id) + training_task = training_task_repository.get_by_model_id(db=db, model_id=model_id) + task_status = training_task.status + model.train_task_id = training_task.task_id + model = ModelPayload.model_validate(model) + + # Get conversion tasks ordered by created_at desc + conversion_tasks = conversion_task_repository.get_all_by_model_id(db=db, model_id=model.model_id) + + if conversion_tasks: + # Set latest experiment status from the most recent conversion task + model.latest_experiments.convert = conversion_tasks[0].status + # Collect all task IDs + for conversion_task in conversion_tasks: + model.convert_task_ids.append(conversion_task.task_id) + model.status = task_status return model diff --git a/app/services/project.py b/app/services/project.py index e8568eb0..74f4be2a 100644 --- a/app/services/project.py +++ b/app/services/project.py @@ -45,7 +45,7 @@ def count_project_by_user_id(self, *, db: Session, api_key: str) -> int: return project_repository.count_by_user_id(db=db, user_id=netspresso.user_info.user_id) def get_project(self, *, db: Session, project_id: str, api_key: str) -> Project: - netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + # netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) project = project_repository.get_by_project_id(db=db, project_id=project_id) project = ProjectPayload.model_validate(project) diff --git a/app/services/train_task.py b/app/services/training_task.py similarity index 79% rename from app/services/train_task.py rename to app/services/training_task.py index 4e1e86ea..1deb833b 100644 --- a/app/services/train_task.py +++ b/app/services/training_task.py @@ -1,31 +1,28 @@ -from typing import Dict, List from pathlib import Path +from typing import Dict, List from sqlalchemy.orm import Session -from app.api.v1.schemas.task.train.hyperparameter import ( - TrainerModel, - OptimizerPayload, - SchedulerPayload -) +from app.api.v1.schemas.task.train.hyperparameter import OptimizerPayload, SchedulerPayload, TrainerModel from app.api.v1.schemas.task.train.train_task import ( - TrainingCreate, - TrainingPayload, - PretrainedModelPayload, - TaskPayload, - FrameworkPayload + FrameworkPayload, + PretrainedModelPayload, + TaskPayload, + TrainingCreate, + TrainingPayload, ) from app.services.user import user_service +from netspresso.enums.train import MODEL_DISPLAY_MAP, MODEL_GROUP_MAP from netspresso.trainer.augmentations.augmentation import Normalize, Resize, ToTensor from netspresso.trainer.models import get_all_available_models -from netspresso.enums.train import MODEL_DISPLAY_MAP, MODEL_GROUP_MAP from netspresso.trainer.optimizers.optimizer_manager import OptimizerManager from netspresso.trainer.optimizers.optimizers import get_supported_optimizers from netspresso.trainer.schedulers.scheduler_manager import SchedulerManager from netspresso.trainer.schedulers.schedulers import get_supported_schedulers from netspresso.trainer.trainer import Trainer -from netspresso.utils.db.models.train import TrainTask -from netspresso.utils.db.repositories.task import train_task_repository +from netspresso.utils.db.models.training import TrainingTask +from netspresso.utils.db.repositories.model import model_repository +from netspresso.utils.db.repositories.training import training_task_repository class TrainTaskService: @@ -102,7 +99,7 @@ def _setup_trainer(self, trainer, training_in: TrainingCreate) -> Trainer: return trainer - def _convert_to_payload_format(self, training_task: TrainTask) -> TrainingPayload: + def _convert_to_payload_format(self, training_task: TrainingTask) -> TrainingPayload: """Convert training task to payload format.""" # Set task information training_task.task = TaskPayload(name=training_task.task) @@ -121,46 +118,49 @@ def _convert_to_payload_format(self, training_task: TrainTask) -> TrainingPayloa def _generate_unique_model_name(self, db: Session, project_id: str, name: str, api_key: str) -> str: """Generate a unique model name by adding numbering if necessary. - + Args: + db (Session): Database session project_id (str): Project ID to check existing models - base_name (str): Original model name - + name (str): Original model name + api_key (str): API key for authentication + Returns: str: Unique model name with numbering if needed """ - netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) - project = netspresso.get_project(project_id=project_id) - models_dir = Path(project.project_abs_path) / "trained_models" - - if not models_dir.exists(): - return name - - existing_names = [d.name for d in models_dir.iterdir() if d.is_dir()] - - if name not in existing_names: + # Get existing model names from the database for the same project + models = model_repository.get_all_by_project_id( + db=db, + project_id=project_id, + ) + + # Extract existing names from models and count occurrences of base name + base_name_count = sum( + 1 for model in models + if model.type == 'trained_models' and model.name.startswith(name) + ) + + # If no models with this name exist, return original name + if base_name_count == 0: return name - - counter = 1 - while f"{name} ({counter})" in existing_names: - counter += 1 - - return f"{name} ({counter})" + + # If models exist, return name with count + return f"{name} ({base_name_count})" def create_training_task(self, db: Session, training_in: TrainingCreate, api_key: str) -> TrainingPayload: """Create and execute a new training task.""" netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) trainer = netspresso.trainer(task=training_in.task) - + trainer = self._setup_trainer(trainer, training_in) - + unique_model_name = self._generate_unique_model_name( db=db, project_id=training_in.project_id, name=training_in.name, api_key=api_key, ) - + training_task = trainer.train( gpus=training_in.environment.gpus, model_name=unique_model_name, @@ -171,8 +171,8 @@ def create_training_task(self, db: Session, training_in: TrainingCreate, api_key def get_training_task(self, db: Session, task_id: str, api_key: str) -> TrainingPayload: """Get training task by task ID.""" - netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) - training_task = train_task_repository.get_by_task_id(db=db, task_id=task_id) + # netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + training_task = training_task_repository.get_by_task_id(db=db, task_id=task_id) return self._convert_to_payload_format(training_task) diff --git a/app/worker/__init__.py b/app/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/worker/celery_app.py b/app/worker/celery_app.py new file mode 100644 index 00000000..9f1e021a --- /dev/null +++ b/app/worker/celery_app.py @@ -0,0 +1,80 @@ +from celery import Celery, chain + +from app.services.user import user_service +from netspresso.enums import Status, TaskStatusForDisplay +from netspresso.utils.db.repositories.conversion import conversion_task_repository +from netspresso.utils.db.session import get_db + +REDIS_URL = "localhost:6379" +REDIS_PASSWORD = "" +POLLING_INTERVAL = 10 # seconds + +connection_url = f"redis://:{REDIS_PASSWORD}@{REDIS_URL}" if REDIS_PASSWORD else f"redis://{REDIS_URL}" + +app = Celery('netspresso_converter', + broker=f"{connection_url}/0", + backend=f"{connection_url}/0") + +@app.task +def convert_model_task( + api_key: str, + input_model_path: str, + output_dir: str, + target_framework: str, + target_device_name: str, + target_data_type: str, + target_software_version: str = None, + input_layer = None, + dataset_path: str = None, + input_model_id: str = None, +): + db = next(get_db()) + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + converter = netspresso.converter_v2() + task_id = converter.convert_model( + input_model_path=input_model_path, + output_dir=output_dir, + target_framework=target_framework, + target_device_name=target_device_name, + target_data_type=target_data_type, + target_software_version=target_software_version, + input_layer=input_layer, + dataset_path=dataset_path, + wait_until_done=False, + input_model_id=input_model_id + ) + # 폴링 태스크 체이닝 + chain(poll_conversion_status.s(api_key, task_id).set(countdown=POLLING_INTERVAL))() + return task_id + +@app.task +def poll_conversion_status(api_key: str, task_id: str): + db = next(get_db()) + netspresso = user_service.build_netspresso_with_api_key(db=db, api_key=api_key) + converter = netspresso.converter_v2() + conversion_task = conversion_task_repository.get_by_task_id(db, task_id) + + # launcher server에서 상태 확인 + launcher_status = converter.get_conversion_task(conversion_task.convert_task_uuid) + + status_updated = False + if launcher_status.status == TaskStatusForDisplay.FINISHED: + conversion_task.status = Status.COMPLETED + status_updated = True + elif launcher_status.status in [TaskStatusForDisplay.ERROR, TaskStatusForDisplay.TIMEOUT]: + conversion_task.status = Status.ERROR + conversion_task.error_detail = launcher_status.error_log + status_updated = True + elif launcher_status.status == TaskStatusForDisplay.USER_CANCEL: + conversion_task.status = Status.STOPPED + status_updated = True + + if status_updated: + conversion_task_repository.save(db, conversion_task) + print(f"Conversion task {task_id} status updated to {conversion_task.status}") + else: + # 아직 완료되지 않았으면 다시 폴링 예약 + poll_conversion_status.apply_async( + args=[api_key, task_id], + countdown=POLLING_INTERVAL + ) diff --git a/netspresso/clients/compressor/v2/schemas/common.py b/netspresso/clients/compressor/v2/schemas/common.py index 90e547f2..97306810 100644 --- a/netspresso/clients/compressor/v2/schemas/common.py +++ b/netspresso/clients/compressor/v2/schemas/common.py @@ -2,7 +2,7 @@ from enum import Enum from typing import List, Optional, Union -from netspresso.enums.device import DisplaySoftwareVersion, HardwareType, SoftwareVersion +from netspresso.enums.device import HardwareType, SoftwareVersion, SoftwareVersionDisplay from netspresso.enums.model import DataType, Framework from netspresso.metadata import common from netspresso.metadata.common import AvailableOption @@ -69,7 +69,7 @@ class SoftwareVersionInfo: """ """ software_version: Optional[Union[None, SoftwareVersion]] = None - display_software_version: Optional[Union[None, DisplaySoftwareVersion]] = None + display_software_version: Optional[Union[None, SoftwareVersionDisplay]] = None @dataclass diff --git a/netspresso/clients/launcher/v2/schemas/common.py b/netspresso/clients/launcher/v2/schemas/common.py index 35cd2be9..f06b643a 100644 --- a/netspresso/clients/launcher/v2/schemas/common.py +++ b/netspresso/clients/launcher/v2/schemas/common.py @@ -6,10 +6,10 @@ from netspresso.clients.utils.system import ENV_STR from netspresso.enums import ( DataType, - DisplaySoftwareVersion, Framework, HardwareType, SoftwareVersion, + SoftwareVersionDisplay, ) from netspresso.metadata import common from netspresso.metadata.common import AvailableOption, SoftwareVersions @@ -91,7 +91,7 @@ class SoftwareVersionInfo: """ """ software_version: Optional[Union[None, SoftwareVersion]] = None - display_software_version: Optional[Union[None, DisplaySoftwareVersion]] = None + display_software_version: Optional[Union[None, SoftwareVersionDisplay]] = None def to(self) -> SoftwareVersions: software_version = SoftwareVersions() @@ -109,7 +109,7 @@ class TaskInfo: display_brand_name: str display_device_name: str software_version: Optional[SoftwareVersion] - display_software_version: Optional[DisplaySoftwareVersion] + display_software_version: Optional[SoftwareVersionDisplay] data_type: DataType hardware_type: Optional[HardwareType] diff --git a/netspresso/constant/project.py b/netspresso/constant/project.py index aa617277..9e0f3979 100644 --- a/netspresso/constant/project.py +++ b/netspresso/constant/project.py @@ -1 +1 @@ -SUB_FOLDERS = ["Trained models", "Compressed models", "Pretrained models"] +SUB_FOLDERS = ["trained_models", "compressed_models", "pretrained_models"] diff --git a/netspresso/converter/v2/converter.py b/netspresso/converter/v2/converter.py index 86192046..268ca072 100644 --- a/netspresso/converter/v2/converter.py +++ b/netspresso/converter/v2/converter.py @@ -1,6 +1,6 @@ import time from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union from urllib import request from loguru import logger @@ -10,19 +10,25 @@ from netspresso.clients.auth.response_body import UserResponse from netspresso.clients.launcher import launcher_client_v2 from netspresso.clients.launcher.v2.schemas import InputLayer -from netspresso.clients.launcher.v2.schemas.common import DeviceInfo +from netspresso.clients.launcher.v2.schemas.common import DeviceInfo, ModelOption from netspresso.clients.launcher.v2.schemas.task.convert.response_body import ConvertTask from netspresso.enums import ( DataType, DeviceName, - Framework, ServiceTask, SoftwareVersion, Status, TaskStatusForDisplay, ) +from netspresso.enums.conversion import SourceFramework, TargetFramework +from netspresso.enums.project import SubFolder from netspresso.metadata.converter import ConverterMetadata from netspresso.utils import FileHandler +from netspresso.utils.db.models.conversion import ConversionTask +from netspresso.utils.db.models.model import Model +from netspresso.utils.db.repositories.conversion import conversion_task_repository +from netspresso.utils.db.repositories.model import model_repository +from netspresso.utils.db.session import get_db_session from netspresso.utils.metadata import MetadataHandler @@ -33,6 +39,22 @@ def __init__(self, token_handler: TokenHandler, user_info: UserResponse): super().__init__(token_handler) self.user_info = user_info + def get_supported_options(self, framework: SourceFramework) -> List[ModelOption]: + options_response = launcher_client_v2.converter.read_framework_options( + access_token=self.token_handler.tokens.access_token, framework=framework, + ) + + supported_options = options_response.data + + # TODO: Will be removed when we support DLC in the future + supported_options = [ + supported_option + for supported_option in supported_options + if supported_option.framework != "dlc" + ] + + return supported_options + def create_available_options(self, target_framework, target_device, target_software_version): def filter_device(device: DeviceInfo, target_software_version: SoftwareVersion): filtered_versions = [ @@ -52,7 +74,7 @@ def filter_device(device: DeviceInfo, target_software_version: SoftwareVersion): framework=target_framework, ) - if target_framework in [Framework.TENSORRT, Framework.DRPAI]: + if target_framework in [TargetFramework.TENSORRT, TargetFramework.DRPAI]: for available_option in available_options.data: if available_option.framework == target_framework: available_option.devices = [ @@ -115,11 +137,57 @@ def _download_converted_model( logger.error(f"Download converted model failed. Error: {e}") raise e + def get_input_model(self, input_model_id: str, user_id: str) -> Model: + with get_db_session() as db: + input_model = model_repository.get_by_model_id(db=db, model_id=input_model_id, user_id=user_id) + return input_model + + def save_model(self, model_name, project_id, user_id, object_path) -> Model: + model = Model( + name=model_name, + type=SubFolder.CONVERTED_MODELS, + is_retrainable=False, + project_id=project_id, + user_id=user_id, + object_path=object_path, + ) + with get_db_session() as db: + model = model_repository.save(db=db, model=model) + return model + + def _save_conversion_task(self, conversion_task: ConversionTask) -> ConversionTask: + with get_db_session() as db: + conversion_task = conversion_task_repository.save(db=db, model=conversion_task) + + return conversion_task + + def create_conversion_task( + self, + framework: Union[str, TargetFramework], + device_name: Union[str, DeviceName], + software_version: Union[str, SoftwareVersion], + data_type: Union[str, DataType], + input_model_id: Optional[str] = None, + model_id: Optional[str] = None, + ) -> ConversionTask: + with get_db_session() as db: + conversion_task = ConversionTask( + framework=framework, + device_name=device_name, + software_version=software_version, + precision=data_type, + status=Status.NOT_STARTED, + input_model_id=input_model_id, + model_id=model_id, + ) + conversion_task = conversion_task_repository.save(db=db, model=conversion_task) + return conversion_task + def convert_model( self, input_model_path: str, output_dir: str, - target_framework: Union[str, Framework], + target_framework: Union[str, TargetFramework], target_device_name: Union[str, DeviceName], target_data_type: Union[str, DataType] = DataType.FP16, target_software_version: Optional[Union[str, SoftwareVersion]] = None, @@ -127,7 +195,8 @@ def convert_model( dataset_path: Optional[str] = None, wait_until_done: bool = True, sleep_interval: int = 30, - ) -> ConverterMetadata: + input_model_id: Optional[str] = None, + ) -> str: """Convert a model to the specified framework. Args: @@ -152,18 +221,30 @@ def convert_model( FileHandler.check_input_model_path(input_model_path) output_dir = FileHandler.create_unique_folder(folder_path=output_dir) - metadata = self.initialize_metadata( - output_dir=output_dir, - input_model_path=input_model_path, - target_framework=target_framework, - target_device=target_device_name, - target_software_version=target_software_version, + + if input_model_id: + input_model = self.get_input_model(input_model_id, self.user_info.user_id) + input_model.user_id = self.user_info.user_id + + default_model_path = FileHandler.get_default_model_path(folder_path=output_dir) + extension = FileHandler.get_extension(framework=target_framework) + converted_model_path = default_model_path.with_suffix(extension).as_posix() + model = self.save_model( + model_name=f"{input_model.name}_converted", + project_id=input_model.project_id, + user_id=self.user_info.user_id, + object_path=converted_model_path, + ) + conversion_task = self.create_conversion_task( + framework=target_framework, + device_name=target_device_name, + software_version=target_software_version, + data_type=target_data_type, + input_model_id=input_model_id, + model_id=model.model_id, ) try: - if metadata.status in [Status.ERROR, Status.STOPPED]: - return metadata - self.validate_token_and_check_credit(service_task=ServiceTask.MODEL_CONVERT) # Get presigned_model_upload_url @@ -198,9 +279,8 @@ def convert_model( dataset_path=dataset_path, ) - metadata.model_info = validate_model_response.data.to() - metadata.convert_task_info = convert_response.data.to(validate_model_response.data.uploaded_file_name) - MetadataHandler.save_metadata(data=metadata, folder_path=output_dir) + conversion_task.convert_task_uuid = convert_response.data.convert_task_id + conversion_task = self._save_conversion_task(conversion_task) if wait_until_done: while True: @@ -213,12 +293,16 @@ def convert_model( TaskStatusForDisplay.FINISHED, TaskStatusForDisplay.ERROR, TaskStatusForDisplay.TIMEOUT, + TaskStatusForDisplay.USER_CANCEL, ]: break time.sleep(sleep_interval) - if convert_response.data.status == TaskStatusForDisplay.FINISHED: + if convert_response.data.status in [TaskStatusForDisplay.IN_PROGRESS, TaskStatusForDisplay.IN_QUEUE]: + conversion_task.status = Status.IN_PROGRESS + logger.info(f"Conversion task was running. Status: {convert_response.data.status}") + elif convert_response.data.status == TaskStatusForDisplay.FINISHED: default_model_path = FileHandler.get_default_model_path(folder_path=output_dir) extension = FileHandler.get_extension(framework=target_framework) self._download_converted_model( @@ -226,20 +310,23 @@ def convert_model( local_path=str(default_model_path.with_suffix(extension)), ) self.print_remaining_credit(service_task=ServiceTask.MODEL_CONVERT) - metadata.status = Status.COMPLETED - metadata.converted_model_path = default_model_path.with_suffix(extension).as_posix() + conversion_task.status = Status.COMPLETED logger.info("Conversion task was completed successfully.") - else: - metadata = self.handle_error(metadata, ServiceTask.MODEL_CONVERT, convert_response.data.error_log) + elif convert_response.data.status in [TaskStatusForDisplay.ERROR, TaskStatusForDisplay.USER_CANCEL, TaskStatusForDisplay.TIMEOUT]: + conversion_task.status = Status.ERROR + conversion_task.error_detail = convert_response.data.error_log + conversion_task = self._save_conversion_task(conversion_task) + logger.error(f"Conversion task was failed. Error: {convert_response.data.error_log}") except Exception as e: - metadata = self.handle_error(metadata, ServiceTask.MODEL_CONVERT, e.args[0]) + conversion_task.status = Status.ERROR + conversion_task.error_detail = e.args[0] except KeyboardInterrupt: - metadata = self.handle_stop(metadata, ServiceTask.MODEL_CONVERT) + conversion_task.status = Status.STOPPED finally: - MetadataHandler.save_metadata(data=metadata, folder_path=output_dir) + conversion_task = self._save_conversion_task(conversion_task) - return metadata + return conversion_task.task_id def get_conversion_task(self, conversion_task_id: str) -> ConvertTask: """Get the conversion task information with given conversion task uuid. diff --git a/netspresso/enums/__init__.py b/netspresso/enums/__init__.py index ce666a2c..fb06dbd8 100644 --- a/netspresso/enums/__init__.py +++ b/netspresso/enums/__init__.py @@ -1,7 +1,7 @@ from .compression import CompressionMethod, GroupPolicy, LayerNorm, Policy, RecommendationMethod, StepOp from .config import EndPointProperty, EnvironmentType, ServiceModule, ServiceName from .credit import MembershipType, ServiceCredit, ServiceTask -from .device import DeviceName, DisplaySoftwareVersion, HardwareType, SoftwareVersion, TaskStatus +from .device import DeviceName, HardwareType, SoftwareVersion, SoftwareVersionDisplay, TaskStatus from .inference import Runtime from .metadata import Status, TaskType from .model import DataType, Extension, Framework, OriginFrom @@ -29,7 +29,6 @@ "DataType", "DeviceName", "SoftwareVersion", - "DisplaySoftwareVersion", "HardwareType", "TaskStatus", "Module", @@ -37,7 +36,7 @@ "ExperimentAction", "StepOp", "MembershipType", - "DisplaySoftwareVersion", + "SoftwareVersionDisplay", "LauncherTask", "TaskStatusForDisplay", "EnvironmentType", diff --git a/netspresso/enums/conversion.py b/netspresso/enums/conversion.py new file mode 100644 index 00000000..48886469 --- /dev/null +++ b/netspresso/enums/conversion.py @@ -0,0 +1,53 @@ +from enum import Enum + + +class SourceFramework(str, Enum): + ONNX = "onnx" + + +class SourceFrameworkDisplay(str, Enum): + ONNX = "ONNX" + + +SOURCE_FRAMEWORK_DISPLAY_MAP = { + SourceFramework.ONNX: SourceFrameworkDisplay.ONNX, +} + + +class TargetFramework(str, Enum): + TENSORRT = "tensorrt" + TENSORFLOW_LITE = "tensorflow_lite" + OPENVINO = "openvino" + DRPAI = "drpai" + + +class TargetFrameworkDisplay(str, Enum): + TENSORRT = "TensorRT" + TENSORFLOW_LITE = "TensorFlow Lite" + OPENVINO = "OpenVINO" + DRPAI = "DRPAI" + + +TARGET_FRAMEWORK_DISPLAY_MAP = { + TargetFramework.TENSORRT: TargetFrameworkDisplay.TENSORRT, + TargetFramework.TENSORFLOW_LITE: TargetFrameworkDisplay.TENSORFLOW_LITE, + TargetFramework.OPENVINO: TargetFrameworkDisplay.OPENVINO, + TargetFramework.DRPAI: TargetFrameworkDisplay.DRPAI, +} + + +class Precision(str, Enum): + FP16 = "FP16" + INT8 = "INT8" + + +class PrecisionDisplay(str, Enum): + FP16 = "FP16" + INT8 = "INT8" + + +PRECISION_DISPLAY_MAP = { + Precision.FP16: PrecisionDisplay.FP16, + Precision.INT8: PrecisionDisplay.INT8, +} + diff --git a/netspresso/enums/device.py b/netspresso/enums/device.py index b8c87c52..f6d94aba 100644 --- a/netspresso/enums/device.py +++ b/netspresso/enums/device.py @@ -2,6 +2,17 @@ from typing import Literal +class DeviceBrand(str, Enum): + NVIDIA = "NVIDIA" + ARM = "Arm" + RASPBERRY_PI = "RaspberryPi" + SAMSUNG = "Samsung" + ST_MICROELECTRONICS = "STMicroelectronics" + INTEL = "Intel" + RENESAS = "Renesas" + NXP = "NXP" + + class DeviceName(str, Enum): RASPBERRY_PI_5 = "RaspberryPi5" RASPBERRY_PI_4B = "RaspberryPi4B" @@ -98,6 +109,84 @@ def create_literal(cls): ] +class DeviceDisplay(str, Enum): + RASPBERRY_PI_5 = "Raspberry Pi 5 (Arm Cortex-A76)" + RASPBERRY_PI_4B = "Raspberry Pi 4B" + RASPBERRY_PI_3B_PLUS = "Raspberry Pi 3B+" + RASPBERRY_PI_3B = "Raspberry Pi 3B" + RASPBERRY_PI_2B = "Raspberry Pi 2B" + RASPBERRY_PI_ZERO_W = "Raspberry Pi Zero W" + RASPBERRY_PI_ZERO_2W = "Raspberry Pi Zero 2 W" + RENESAS_RZ_V2L = "Renesas RZ/V2L" + RENESAS_RZ_V2M = "Renesas RZ/V2M" + RENESAS_RA8D1 = "Renesas RA8D1 (Arm Cortex-M85)" + + JETSON_NANO = "NVIDIA Jetson Nano" + JETSON_TX2 = "NVIDIA Jetson TX2" + JETSON_XAVIER = "NVIDIA Jetson Xavier" + JETSON_NX = "NVIDIA Jetson Xavier NX" + JETSON_AGX_ORIN = "NVIDIA Jetson AGX Orin" + JETSON_ORIN_NANO = "NVIDIA Jetson Orin Nano" + AWS_T4 = "NVIDIA AWS T4" + INTEL_XEON_W_2233 = "Intel Xeon W-2233" + ALIF_ENSEMBLE_E7_DEVKIT_GEN2 = "Alif Ensemble DevKit-E7 Gen2 (Arm Cortex-M55+Ethos-U55)" + + ARM_ETHOS_U_SERIES = "Arm Virtual Hardware Corstone-300 (Ethos-U55/U65)" + NXP_iMX93 = "NXP i.MX 93(Arm Cortex-A55/M33+Ethos-U65)" + ARDUINO_NICLA_VISION = "Arduino Nicla Vision(Arm Cortex-M7/M4)" + + +DEVICE_DISPLAY_MAP = { + DeviceName.RASPBERRY_PI_5: DeviceDisplay.RASPBERRY_PI_5, + DeviceName.RASPBERRY_PI_4B: DeviceDisplay.RASPBERRY_PI_4B, + DeviceName.RASPBERRY_PI_3B_PLUS: DeviceDisplay.RASPBERRY_PI_3B_PLUS, + DeviceName.RASPBERRY_PI_3B: DeviceDisplay.RASPBERRY_PI_3B, + DeviceName.RASPBERRY_PI_2B: DeviceDisplay.RASPBERRY_PI_2B, + DeviceName.RASPBERRY_PI_ZERO_W: DeviceDisplay.RASPBERRY_PI_ZERO_W, + DeviceName.RASPBERRY_PI_ZERO_2W: DeviceDisplay.RASPBERRY_PI_ZERO_2W, + DeviceName.RENESAS_RZ_V2L: DeviceDisplay.RENESAS_RZ_V2L, + DeviceName.RENESAS_RZ_V2M: DeviceDisplay.RENESAS_RZ_V2M, + DeviceName.RENESAS_RA8D1: DeviceDisplay.RENESAS_RA8D1, + DeviceName.JETSON_NANO: DeviceDisplay.JETSON_NANO, + DeviceName.JETSON_TX2: DeviceDisplay.JETSON_TX2, + DeviceName.JETSON_XAVIER: DeviceDisplay.JETSON_XAVIER, + DeviceName.JETSON_NX: DeviceDisplay.JETSON_NX, + DeviceName.JETSON_AGX_ORIN: DeviceDisplay.JETSON_AGX_ORIN, + DeviceName.JETSON_ORIN_NANO: DeviceDisplay.JETSON_ORIN_NANO, + DeviceName.AWS_T4: DeviceDisplay.AWS_T4, + DeviceName.INTEL_XEON_W_2233: DeviceDisplay.INTEL_XEON_W_2233, + DeviceName.ALIF_ENSEMBLE_E7_DEVKIT_GEN2: DeviceDisplay.ALIF_ENSEMBLE_E7_DEVKIT_GEN2, + DeviceName.ARM_ETHOS_U_SERIES: DeviceDisplay.ARM_ETHOS_U_SERIES, + DeviceName.NXP_iMX93: DeviceDisplay.NXP_iMX93, + DeviceName.ARDUINO_NICLA_VISION: DeviceDisplay.ARDUINO_NICLA_VISION, +} + +DEVICE_BRAND_MAP = { + DeviceName.RASPBERRY_PI_5: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_4B: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_3B_PLUS: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_3B: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_2B: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_ZERO_W: DeviceBrand.RASPBERRY_PI, + DeviceName.RASPBERRY_PI_ZERO_2W: DeviceBrand.RASPBERRY_PI, + DeviceName.RENESAS_RZ_V2L: DeviceBrand.RENESAS, + DeviceName.RENESAS_RZ_V2M: DeviceBrand.RENESAS, + DeviceName.JETSON_NANO: DeviceBrand.NVIDIA, + DeviceName.JETSON_TX2: DeviceBrand.NVIDIA, + DeviceName.JETSON_XAVIER: DeviceBrand.NVIDIA, + DeviceName.JETSON_NX: DeviceBrand.NVIDIA, + DeviceName.JETSON_AGX_ORIN: DeviceBrand.NVIDIA, + DeviceName.JETSON_ORIN_NANO: DeviceBrand.NVIDIA, + DeviceName.AWS_T4: DeviceBrand.NVIDIA, + DeviceName.INTEL_XEON_W_2233: DeviceBrand.INTEL, + DeviceName.RENESAS_RA8D1: DeviceBrand.ARM, + DeviceName.ARDUINO_NICLA_VISION: DeviceBrand.ARM, + DeviceName.ALIF_ENSEMBLE_E7_DEVKIT_GEN2: DeviceBrand.ARM, + DeviceName.ARM_ETHOS_U_SERIES: DeviceBrand.ARM, + DeviceName.NXP_iMX93: DeviceBrand.NXP, +} + + class SoftwareVersion(str, Enum): JETPACK_4_4_1 = "4.4.1-b50" JETPACK_4_6 = "4.6-b199" @@ -110,7 +199,7 @@ def create_literal(cls): return Literal["4.4.1-b50", "4.6-b199", "5.0.1-b118", "5.0.2-b231", "6.0-b52"] -class DisplaySoftwareVersion(str, Enum): +class SoftwareVersionDisplay(str, Enum): JETPACK_4_4_1 = "Jetpack 4.4.1" JETPACK_4_6 = "Jetpack 4.6" JETPACK_5_0_1 = "Jetpack 5.0.1" @@ -118,6 +207,15 @@ class DisplaySoftwareVersion(str, Enum): JETPACK_6_0 = "Jetpack 6.0" +SOFTWARE_VERSION_DISPLAY_MAP = { + SoftwareVersion.JETPACK_4_4_1: SoftwareVersionDisplay.JETPACK_4_4_1, + SoftwareVersion.JETPACK_4_6: SoftwareVersionDisplay.JETPACK_4_6, + SoftwareVersion.JETPACK_5_0_1: SoftwareVersionDisplay.JETPACK_5_0_1, + SoftwareVersion.JETPACK_5_0_2: SoftwareVersionDisplay.JETPACK_5_0_2, + SoftwareVersion.JETPACK_6_0: SoftwareVersionDisplay.JETPACK_6_0, +} + + class HardwareType(str, Enum): HELIUM = "helium" @@ -126,6 +224,15 @@ def create_literal(cls): return Literal["helium"] +class HardwareTypeDisplay(str, Enum): + HELIUM = "Helium" + + +HARDWARE_TYPE_DISPLAY_MAP = { + HardwareType.HELIUM: HardwareTypeDisplay.HELIUM, +} + + class TaskStatus(str, Enum): IN_QUEUE = "IN_QUEUE" IN_PROGRESS = "IN_PROGRESS" diff --git a/netspresso/enums/model.py b/netspresso/enums/model.py index 2ac039a7..c0400ac5 100644 --- a/netspresso/enums/model.py +++ b/netspresso/enums/model.py @@ -60,8 +60,23 @@ def create_literal(cls): return Literal["FP32", "FP16", "INT8", ""] +class DataTypeDisplay(str, Enum): + FP32 = "FP32" + FP16 = "FP16" + INT8 = "INT8" + NONE = "" + + +DATA_TYPE_DISPLAY_MAP = { + DataType.FP32: DataTypeDisplay.FP32, + DataType.FP16: DataTypeDisplay.FP16, + DataType.INT8: DataTypeDisplay.INT8, + DataType.NONE: DataTypeDisplay.NONE, +} + compressor_framework_literal = Framework.create_compressor_literal() launcher_framework_literal = Framework.create_launcher_literal() extension_literal = Extension.create_literal() originfrom_literal = OriginFrom.create_literal() datatype_literal = DataType.create_literal() + diff --git a/netspresso/enums/project.py b/netspresso/enums/project.py index 33ce3585..272f5df1 100644 --- a/netspresso/enums/project.py +++ b/netspresso/enums/project.py @@ -5,3 +5,4 @@ class SubFolder(str, Enum): TRAINED_MODELS = "trained_models" COMPRESSED_MODELS = "compressed_models" PRETRAINED_MODELS = "pretrained_models" + CONVERTED_MODELS = "converted_models" diff --git a/netspresso/netspresso.py b/netspresso/netspresso.py index 8efa8ee4..28a42b59 100644 --- a/netspresso/netspresso.py +++ b/netspresso/netspresso.py @@ -143,6 +143,32 @@ def get_projects(self) -> List[Project]: finally: db and db.close() + def get_project(self, project_id: str) -> Project: + """ + Retrieve all projects associated with the current user. + + This method fetches project information from the database for + the user identified by `self.user_info.user_id`. + + Returns: + List[Project]: A list of projects associated with the current user. + + Raises: + Exception: If an error occurs while querying the database. + """ + db = None + try: + db = SessionLocal() + project = project_repository.get_by_project_id(db=db, project_id=project_id) + + return project + + except Exception as e: + logger.error(f"Failed to get project list from the database: {e}") + raise + finally: + db and db.close() + def trainer( self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None ) -> Trainer: diff --git a/netspresso/trainer/trainer.py b/netspresso/trainer/trainer.py index ba8015d4..b955435f 100644 --- a/netspresso/trainer/trainer.py +++ b/netspresso/trainer/trainer.py @@ -8,7 +8,7 @@ from netspresso.base import NetsPressoBase from netspresso.clients.auth import TokenHandler from netspresso.clients.launcher import launcher_client_v2 -from netspresso.enums import Framework, Optimizer, Scheduler, ServiceTask, Status, Task +from netspresso.enums import Framework, ServiceTask, Status, Task from netspresso.enums.project import SubFolder from netspresso.enums.train import StorageLocation from netspresso.exceptions.trainer import ( @@ -38,12 +38,18 @@ from netspresso.trainer.trainer_configs import TrainerConfigs from netspresso.trainer.training import TRAINING_CONFIG_TYPE, EnvironmentConfig, LoggingConfig, ScheduleConfig from netspresso.utils import FileHandler -from netspresso.utils.db.models.model import TrainedModel -from netspresso.utils.db.models.train import Augmentation, Dataset, Environment, Hyperparameter, Performance, TrainTask -from netspresso.utils.db.repositories.model import trained_model_repository -from netspresso.utils.db.repositories.task import train_task_repository +from netspresso.utils.db.models.model import Model +from netspresso.utils.db.models.training import ( + Augmentation, + Dataset, + Environment, + Hyperparameter, + Performance, + TrainingTask, +) +from netspresso.utils.db.repositories.model import model_repository +from netspresso.utils.db.repositories.training import training_task_repository from netspresso.utils.db.session import get_db_session -from netspresso.utils.metadata import MetadataHandler class Trainer(NetsPressoBase): @@ -315,7 +321,11 @@ def set_model_config( self.logging.sample_input_size = [img_size, img_size] if model is None: - raise NotSupportedModelException() + raise NotSupportedModelException( + available_models=self._get_available_models_w_deprecated_names(), + model_name=model_name, + task=self.task, + ) self.model = model( checkpoint=CheckpointConfig( @@ -539,24 +549,24 @@ def create_runtime_config(self, yaml_path): def _save_train_task(self, train_task): with get_db_session() as db: - train_task = train_task_repository.save(db=db, task=train_task) + train_task = training_task_repository.save(db=db, task=train_task) return train_task - def save_trained_model(self, model_name, train_task, project_id, user_id): - trained_model = TrainedModel( + def save_trained_model(self, model_name, project_id, user_id, object_path) -> Model: + model = Model( name=model_name, type=SubFolder.TRAINED_MODELS, is_retrainable=True, project_id=project_id, user_id=user_id, - train_task=train_task, + object_path=object_path, ) with get_db_session() as db: - model = trained_model_repository.save(db=db, model=trained_model) + model = model_repository.save(db=db, model=model) return model - def create_training_task(self): + def create_training_task(self, model_id) -> TrainingTask: with get_db_session() as db: dataset = Dataset( train_path="train", @@ -593,7 +603,7 @@ def create_training_task(self): num_workers=self.environment.num_workers, gpus=self.environment.gpus, ) - task = TrainTask( + task = TrainingTask( pretrained_model=self.model_name, task=self.task, framework=Framework.PYTORCH, @@ -602,12 +612,13 @@ def create_training_task(self): dataset=dataset, hyperparameter=hyperparameter, environment=environment, + model_id=model_id, ) - task = train_task_repository.save(db=db, task=task) + task = training_task_repository.save(db=db, task=task) return task - def create_performance(self, task: TrainTask, training_summary): + def create_performance(self, task: TrainingTask, training_summary): performance = Performance( train_losses=training_summary["train_losses"], valid_losses=training_summary["valid_losses"], @@ -630,7 +641,7 @@ def create_performance(self, task: TrainTask, training_summary): return task - def train(self, gpus: str, model_name: str, project_id: str, output_dir: Optional[str] = "./outputs") -> TrainTask: + def train(self, gpus: str, model_name: str, project_id: str, output_dir: Optional[str] = "./outputs") -> TrainingTask: """Train the model with the specified configuration. Args: @@ -652,9 +663,15 @@ def train(self, gpus: str, model_name: str, project_id: str, output_dir: Optiona destination_folder = Path(project_abs_path) / SubFolder.TRAINED_MODELS.value / model_name destination_folder = FileHandler.create_unique_folder(folder_path=destination_folder) + object_path = Path(SubFolder.TRAINED_MODELS.value) / model_name - train_task = self.create_training_task() - trained_model = self.save_trained_model(model_name=model_name, train_task=train_task, project_id=project.project_id, user_id=project.user_id) + model = self.save_trained_model( + model_name=model_name, + project_id=project.project_id, + user_id=project.user_id, + object_path=object_path, + ) + train_task = self.create_training_task(model_id=model.model_id) try: self.logging.output_dir = output_dir diff --git a/netspresso/utils/db/models/__init__.py b/netspresso/utils/db/models/__init__.py index 1be77b23..f932eb14 100644 --- a/netspresso/utils/db/models/__init__.py +++ b/netspresso/utils/db/models/__init__.py @@ -1,6 +1,7 @@ -from netspresso.utils.db.models.model import TrainedModel +from netspresso.utils.db.models.conversion import ConversionTask +from netspresso.utils.db.models.model import Model from netspresso.utils.db.models.project import Project -from netspresso.utils.db.models.train import TrainTask +from netspresso.utils.db.models.training import TrainingTask from netspresso.utils.db.models.user import User from netspresso.utils.db.session import Base, engine @@ -8,8 +9,9 @@ __all__ = [ - "TrainedModel", + "Model", "Project", - "TrainTask", + "TrainingTask", "User", + "ConversionTask", ] diff --git a/netspresso/utils/db/models/conversion.py b/netspresso/utils/db/models/conversion.py new file mode 100644 index 00000000..bc1ea0d9 --- /dev/null +++ b/netspresso/utils/db/models/conversion.py @@ -0,0 +1,44 @@ +from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from netspresso.utils.db.generate_uuid import generate_uuid +from netspresso.utils.db.mixins import TimestampMixin +from netspresso.utils.db.session import Base + + +class ConversionTask(Base, TimestampMixin): + __tablename__ = "conversion_task" + + id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) + task_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="task")) + + # Conversion settings + framework = Column(String(30), nullable=False) + device_name = Column(String(30), nullable=False) + software_version = Column(String(30), nullable=True) + precision = Column(String(30), nullable=False) + + # Task information + convert_task_uuid = Column(String(36), nullable=True) + status = Column(String(30), nullable=False) + error_detail = Column(JSON, nullable=True) + + is_deleted = Column(Boolean, nullable=False, default=False) + + # Relationship to Model (source model) + input_model_id = Column(String(36), ForeignKey("model.model_id"), nullable=True) + input_model = relationship( + "Model", + uselist=False, + lazy='joined', + foreign_keys=[input_model_id], + ) + + # Relationship to Model (converted model) + model_id = Column(String(36), ForeignKey("model.model_id"), nullable=True) + model = relationship( + "Model", + uselist=False, + lazy='joined', + foreign_keys=[model_id], + ) diff --git a/netspresso/utils/db/models/model.py b/netspresso/utils/db/models/model.py index a1b92f2e..ececea20 100644 --- a/netspresso/utils/db/models/model.py +++ b/netspresso/utils/db/models/model.py @@ -6,31 +6,19 @@ from netspresso.utils.db.session import Base -class TrainedModel(Base, TimestampMixin): - __tablename__ = "trained_model" +class Model(Base, TimestampMixin): + __tablename__ = "model" id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) model_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="model")) name = Column(String(100), nullable=False) type = Column(String(30), nullable=False) is_retrainable = Column(Boolean, nullable=False, default=False) - project_id = Column(String(36), ForeignKey("project.project_id", ondelete="CASCADE"), nullable=False) - user_id = Column(String(36), nullable=False) is_deleted = Column(Boolean, nullable=False, default=False) + object_path = Column(String(255), nullable=False) - # Foreign key for 1:1 relationship - train_task_id = Column( - String(36), - ForeignKey("train_task.task_id", ondelete="CASCADE"), - unique=True, - nullable=False, - ) - train_task = relationship( - "TrainTask", - back_populates="model", - cascade="all", - uselist=False, - ) + project_id = Column(String(36), ForeignKey("project.project_id", ondelete="CASCADE"), nullable=False) + user_id = Column(String(36), nullable=False) # Back-reference to Project project = relationship("Project", back_populates="models") diff --git a/netspresso/utils/db/models/project.py b/netspresso/utils/db/models/project.py index 54a83e06..b20277a7 100644 --- a/netspresso/utils/db/models/project.py +++ b/netspresso/utils/db/models/project.py @@ -17,15 +17,15 @@ class Project(Base, TimestampMixin): project_abs_path = Column(String(500), nullable=False) is_deleted = Column(Boolean, nullable=False, default=False) - # Relationship to TrainedModel + # Relationship to Model models = relationship( - "TrainedModel", + "Model", back_populates="project", - cascade="all, delete-orphan", # Cascade options for deletion - lazy="joined", # Eager loading to fetch models with the project + cascade="all", + lazy="joined", ) # Property to get model IDs @hybrid_property def model_ids(self): - return [model.model_id for model in self.models] + return [model.model_id for model in self.models] if self.models else [] diff --git a/netspresso/utils/db/models/train.py b/netspresso/utils/db/models/training.py similarity index 79% rename from netspresso/utils/db/models/train.py rename to netspresso/utils/db/models/training.py index 76f00872..ed3a5902 100644 --- a/netspresso/utils/db/models/train.py +++ b/netspresso/utils/db/models/training.py @@ -1,4 +1,4 @@ -from sqlalchemy import JSON, Boolean, Column, Float, ForeignKey, Integer, String, BigInteger +from sqlalchemy import JSON, Boolean, Column, Float, ForeignKey, Integer, String from sqlalchemy.orm import relationship from netspresso.utils.db.generate_uuid import generate_uuid @@ -18,8 +18,8 @@ class Augmentation(Base): hyperparameter = relationship("Hyperparameter", back_populates="augmentations", lazy='joined') -class TrainTask(Base, TimestampMixin): - __tablename__ = "train_task" +class TrainingTask(Base, TimestampMixin): + __tablename__ = "training_task" id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) task_id = Column(String(36), index=True, unique=True, nullable=False, default=lambda: generate_uuid(entity="task")) @@ -37,10 +37,10 @@ class TrainTask(Base, TimestampMixin): environment = relationship("Environment", back_populates="task", uselist=False, cascade="all, delete-orphan", lazy='joined') performance = relationship("Performance", back_populates="task", uselist=False, cascade="all, delete-orphan", lazy='joined') - # Relationship to TrainedModel + # Relationship to Model + model_id = Column(String(36), ForeignKey("model.model_id"), nullable=True) model = relationship( - "TrainedModel", - back_populates="train_task", + "Model", uselist=False, lazy='joined', ) @@ -57,9 +57,9 @@ class Dataset(Base, TimestampMixin): id_mapping = Column(JSON, nullable=True) palette = Column(JSON, nullable=True) - # Relationship to TrainTask - task_id = Column(String(36), ForeignKey("train_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) - task = relationship("TrainTask", back_populates="dataset") + # Relationship to TrainingTask + task_id = Column(String(36), ForeignKey("training_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) + task = relationship("TrainingTask", back_populates="dataset") class Hyperparameter(Base, TimestampMixin): @@ -73,9 +73,9 @@ class Hyperparameter(Base, TimestampMixin): augmentations = relationship("Augmentation", back_populates="hyperparameter", cascade="all, delete-orphan", lazy='joined') - # Relationship to TrainTask - task_id = Column(String(36), ForeignKey("train_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) - task = relationship("TrainTask", back_populates="hyperparameter") + # Relationship to TrainingTask + task_id = Column(String(36), ForeignKey("training_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) + task = relationship("TrainingTask", back_populates="hyperparameter") class Environment(Base, TimestampMixin): @@ -86,9 +86,9 @@ class Environment(Base, TimestampMixin): num_workers = Column(Integer, nullable=False) gpus = Column(String(30), nullable=False) # GPUs (예: "1, 0") - # Relationship to TrainTask - task_id = Column(String(36), ForeignKey("train_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) - task = relationship("TrainTask", back_populates="environment") + # Relationship to TrainingTask + task_id = Column(String(36), ForeignKey("training_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) + task = relationship("TrainingTask", back_populates="environment") class Performance(Base, TimestampMixin): @@ -110,5 +110,5 @@ class Performance(Base, TimestampMixin): status = Column(String(36), nullable=True) # Relationship to TrainTask - task_id = Column(String(36), ForeignKey("train_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) - task = relationship("TrainTask", back_populates="performance") + task_id = Column(String(36), ForeignKey("training_task.task_id", ondelete="CASCADE"), unique=True, nullable=False) + task = relationship("TrainingTask", back_populates="performance") diff --git a/netspresso/utils/db/repositories/conversion.py b/netspresso/utils/db/repositories/conversion.py new file mode 100644 index 00000000..c4cf6df2 --- /dev/null +++ b/netspresso/utils/db/repositories/conversion.py @@ -0,0 +1,55 @@ +from typing import List, Optional + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from netspresso.utils.db.models.conversion import ConversionTask +from netspresso.utils.db.repositories.base import BaseRepository, Order + + +class ConversionTaskRepository(BaseRepository[ConversionTask]): + def get_by_task_id(self, db: Session, task_id: str) -> Optional[ConversionTask]: + task = db.query(self.model).filter( + self.model.task_id == task_id, + ).first() + + return task + + def _get_tasks( + self, + db: Session, + condition, + start: Optional[int] = None, + size: Optional[int] = None, + order: Optional[Order] = Order.DESC, + ) -> Optional[List[ConversionTask]]: + ordering_func = self.choose_order_func(order) + query = db.query(self.model).filter(condition) + + if order: + query = query.order_by(ordering_func(self.model.updated_at)) + + if start is not None and size is not None: + query = query.offset(start).limit(size) + + models = query.all() + + return models + + def get_all_by_model_id( + self, + db: Session, + model_id: str, + start: Optional[int] = None, + size: Optional[int] = None, + order: Optional[Order] = Order.DESC, + ) -> Optional[List[ConversionTask]]: + return self._get_tasks( + db=db, + condition=self.model.input_model_id == model_id, + start=start, + size=size, + order=order, + ) + +conversion_task_repository = ConversionTaskRepository(ConversionTask) diff --git a/netspresso/utils/db/repositories/model.py b/netspresso/utils/db/repositories/model.py index 0a87cf6c..585bc7ed 100644 --- a/netspresso/utils/db/repositories/model.py +++ b/netspresso/utils/db/repositories/model.py @@ -3,12 +3,12 @@ from sqlalchemy import func from sqlalchemy.orm import Session -from netspresso.utils.db.models.model import TrainedModel +from netspresso.utils.db.models.model import Model from netspresso.utils.db.repositories.base import BaseRepository, Order -class TrainedModelRepository(BaseRepository[TrainedModel]): - def get_by_model_id(self, db: Session, model_id: str, user_id: str) -> Optional[TrainedModel]: +class ModelRepository(BaseRepository[Model]): + def get_by_model_id(self, db: Session, model_id: str, user_id: str) -> Optional[Model]: model = db.query(self.model).filter( self.model.model_id == model_id, self.model.user_id == user_id, @@ -23,7 +23,7 @@ def _get_models( start: Optional[int] = None, size: Optional[int] = None, order: Optional[Order] = None, - ) -> Optional[List[TrainedModel]]: + ) -> Optional[List[Model]]: ordering_func = self.choose_order_func(order) query = db.query(self.model).filter(condition) @@ -44,7 +44,7 @@ def get_all_by_user_id( start: Optional[int] = None, size: Optional[int] = None, order: Optional[Order] = None, - ) -> Optional[List[TrainedModel]]: + ) -> Optional[List[Model]]: return self._get_models( db=db, condition=self.model.user_id == user_id, @@ -60,7 +60,7 @@ def get_all_by_project_id( start: Optional[int] = None, size: Optional[int] = None, order: Optional[Order] = Order.DESC, - ) -> Optional[List[TrainedModel]]: + ) -> Optional[List[Model]]: return self._get_models( db=db, condition=self.model.project_id == project_id, @@ -77,4 +77,4 @@ def count_by_user_id(self, db: Session, user_id: str) -> int: ) -trained_model_repository = TrainedModelRepository(TrainedModel) +model_repository = ModelRepository(Model) diff --git a/netspresso/utils/db/repositories/task.py b/netspresso/utils/db/repositories/training.py similarity index 53% rename from netspresso/utils/db/repositories/task.py rename to netspresso/utils/db/repositories/training.py index 91525fe0..b38febe1 100644 --- a/netspresso/utils/db/repositories/task.py +++ b/netspresso/utils/db/repositories/training.py @@ -3,18 +3,25 @@ from sqlalchemy import func from sqlalchemy.orm import Session -from netspresso.utils.db.models.train import TrainTask +from netspresso.utils.db.models.training import TrainingTask from netspresso.utils.db.repositories.base import BaseRepository, Order -class TrainTaskRepository(BaseRepository[TrainTask]): - def get_by_task_id(self, db: Session, task_id: str) -> Optional[TrainTask]: +class TrainingTaskRepository(BaseRepository[TrainingTask]): + def get_by_task_id(self, db: Session, task_id: str) -> Optional[TrainingTask]: task = db.query(self.model).filter( self.model.task_id == task_id, ).first() return task + def get_by_model_id(self, db: Session, model_id: str) -> Optional[TrainingTask]: + task = db.query(self.model).filter( + self.model.model_id == model_id, + ).first() + + return task + def save(self, db, task): db.add(task) db.commit() @@ -22,4 +29,4 @@ def save(self, db, task): return task -train_task_repository = TrainTaskRepository(TrainTask) +training_task_repository = TrainingTaskRepository(TrainingTask)