-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
59cf646
commit 9d73241
Showing
34 changed files
with
1,111 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.v1.endpoints import model, project, system, train_task, user | ||
from app.api.v1.endpoints import conversion_task, model, project, system, training_task, user | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(user.router, prefix="/users", tags=["user"]) | ||
api_router.include_router(project.router, prefix="/projects", tags=["project"]) | ||
api_router.include_router(model.router, prefix="/models", tags=["model"]) | ||
api_router.include_router(train_task.router, prefix="/tasks", tags=["task"]) | ||
api_router.include_router(training_task.router, prefix="/tasks", tags=["training"]) | ||
api_router.include_router(conversion_task.router, prefix="/tasks", tags=["conversion"]) | ||
api_router.include_router(system.router, prefix="/system", tags=["system"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from fastapi import APIRouter, Depends, Query | ||
from sqlalchemy.orm import Session | ||
|
||
from app.api.deps import api_key_header | ||
from app.api.v1.schemas.task.conversion.conversion_task import ( | ||
ConversionCreate, | ||
ConversionCreateResponse, | ||
ConversionResponse, | ||
SupportedDevicesResponse, | ||
) | ||
from app.services.conversion_task import conversion_task_service | ||
from netspresso.enums.conversion import SourceFramework | ||
from netspresso.utils.db.session import get_db | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get( | ||
"/conversions/configuration/devices", | ||
response_model=SupportedDevicesResponse, | ||
description="Get supported devices and frameworks for model conversion based on the source framework.", | ||
) | ||
def get_supported_conversion_devices( | ||
framework: SourceFramework = Query(..., description="Source framework of the model to be converted."), | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> SupportedDevicesResponse: | ||
supported_devices = conversion_task_service.get_supported_devices(db=db, framework=framework, api_key=api_key) | ||
|
||
return SupportedDevicesResponse(data=supported_devices) | ||
|
||
|
||
@router.post("/conversions", response_model=ConversionCreateResponse, status_code=201) | ||
def create_conversions_task( | ||
request_body: ConversionCreate, | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> ConversionCreateResponse: | ||
conversion_task = conversion_task_service.create_conversion_task(db=db, conversion_in=request_body, api_key=api_key) | ||
|
||
return ConversionCreateResponse(data=conversion_task) | ||
|
||
|
||
@router.get("/conversions/{task_id}", response_model=ConversionResponse) | ||
def get_conversions_task( | ||
task_id: str, | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> ConversionResponse: | ||
conversion_task = conversion_task_service.get_conversion_task(db=db, task_id=task_id, api_key=api_key) | ||
|
||
return ConversionResponse(data=conversion_task) | ||
|
||
|
||
@router.post("/conversions/{task_id}/cancel", response_model=ConversionResponse) | ||
def cancel_conversion_task( | ||
task_id: str, | ||
db: Session = Depends(get_db), | ||
api_key: str = Depends(api_key_header), | ||
) -> ConversionResponse: | ||
conversion_task = conversion_task_service.cancel_conversion_task(db=db, task_id=task_id, api_key=api_key) | ||
|
||
return ConversionResponse(data=conversion_task) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from datetime import datetime | ||
from typing import Dict, List, Optional | ||
|
||
from pydantic import BaseModel, ConfigDict, Field, model_validator | ||
|
||
from app.api.v1.schemas.base import ResponseItem | ||
from app.api.v1.schemas.task.conversion.device import ( | ||
PrecisionPayload, | ||
SoftwareVersionPayload, | ||
SupportedDevicePayload, | ||
TargetDevicePayload, | ||
) | ||
from netspresso.enums.conversion import TARGET_FRAMEWORK_DISPLAY_MAP, Precision, TargetFramework, TargetFrameworkDisplay | ||
from netspresso.enums.device import DeviceName, SoftwareVersion | ||
from netspresso.enums.model import DataType | ||
|
||
|
||
class TargetFrameworkPayload(BaseModel): | ||
name: TargetFramework = Field(description="Framework name") | ||
display_name: Optional[TargetFrameworkDisplay] = Field(default=None, description="Framework display name") | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = TARGET_FRAMEWORK_DISPLAY_MAP.get(self.name) | ||
|
||
return self | ||
|
||
|
||
class SupportedDeviceResponse(BaseModel): | ||
framework: TargetFrameworkPayload | ||
devices: List[SupportedDevicePayload] | ||
|
||
|
||
class SupportedDevicesResponse(BaseModel): | ||
data: List[SupportedDeviceResponse] | ||
|
||
|
||
class ConversionCreate(BaseModel): | ||
input_model_id: str = Field(description="Input model ID") | ||
framework: TargetFramework = Field(description="Framework name") | ||
device_name: DeviceName = Field(description="Device name") | ||
software_version: Optional[SoftwareVersion] = Field(default=None, description="Software version") | ||
precision: Precision = Field(description="Precision") | ||
calibration_dataset_path: Optional[str] = Field(default=None, description="Path to the calibration dataset") | ||
|
||
|
||
class ConversionPayload(BaseModel): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
task_id: str | ||
model_id: Optional[str] = None | ||
framework: TargetFrameworkPayload | ||
device: TargetDevicePayload | ||
software_version: Optional[SoftwareVersionPayload] = None | ||
precision: PrecisionPayload | ||
status: str | ||
is_deleted: bool | ||
error_detail: Optional[Dict] = None | ||
input_model_id: str | ||
created_at: Optional[datetime] = None | ||
updated_at: Optional[datetime] = None | ||
|
||
|
||
class ConversionCreatePayload(BaseModel): | ||
task_id: str | ||
|
||
|
||
class ConversionCreateResponse(ResponseItem): | ||
data: ConversionCreatePayload | ||
|
||
|
||
class ConversionResponse(ResponseItem): | ||
data: ConversionPayload |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel, Field, model_validator | ||
|
||
from netspresso.enums.conversion import PRECISION_DISPLAY_MAP, Precision, PrecisionDisplay | ||
from netspresso.enums.device import ( | ||
DEVICE_BRAND_MAP, | ||
DEVICE_DISPLAY_MAP, | ||
HARDWARE_TYPE_DISPLAY_MAP, | ||
SOFTWARE_VERSION_DISPLAY_MAP, | ||
DeviceBrand, | ||
DeviceDisplay, | ||
DeviceName, | ||
HardwareType, | ||
HardwareTypeDisplay, | ||
SoftwareVersion, | ||
SoftwareVersionDisplay, | ||
) | ||
|
||
|
||
class SoftwareVersionPayload(BaseModel): | ||
name: SoftwareVersion | ||
display_name: Optional[SoftwareVersionDisplay] = Field(default=None, description="Software version display name") | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = SOFTWARE_VERSION_DISPLAY_MAP.get(self.name) | ||
|
||
return self | ||
|
||
|
||
class PrecisionPayload(BaseModel): | ||
name: Precision | ||
display_name: Optional[PrecisionDisplay] = Field(default=None, description="Precision display name") | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = PRECISION_DISPLAY_MAP.get(self.name) | ||
|
||
return self | ||
|
||
|
||
class HardwareTypePayload(BaseModel): | ||
name: HardwareType | ||
display_name: Optional[HardwareTypeDisplay] = Field(default=None, description="Hardware type display name") | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = HARDWARE_TYPE_DISPLAY_MAP.get(self.name) | ||
|
||
return self | ||
|
||
|
||
class SupportedDevicePayload(BaseModel): | ||
name: DeviceName | ||
display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name") | ||
brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name") | ||
software_versions: List[SoftwareVersionPayload] | ||
precisions: List[PrecisionPayload] | ||
hardware_types: List[HardwareTypePayload] | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = DEVICE_DISPLAY_MAP.get(self.name) | ||
self.brand_name = DEVICE_BRAND_MAP.get(self.name) | ||
|
||
return self | ||
|
||
|
||
|
||
class TargetDevicePayload(BaseModel): | ||
name: DeviceName | ||
display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name") | ||
brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name") | ||
|
||
@model_validator(mode="after") | ||
def set_display_name(self) -> str: | ||
self.display_name = DEVICE_DISPLAY_MAP.get(self.name) | ||
self.brand_name = DEVICE_BRAND_MAP.get(self.name) | ||
|
||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.