Skip to content

Commit

Permalink
Add create convert task api (#444)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Feb 10, 2025
1 parent 59cf646 commit 9d73241
Show file tree
Hide file tree
Showing 34 changed files with 1,111 additions and 170 deletions.
5 changes: 3 additions & 2 deletions app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import APIRouter

from app.api.v1.endpoints import model, project, system, train_task, user
from app.api.v1.endpoints import conversion_task, model, project, system, training_task, user

api_router = APIRouter()
api_router.include_router(user.router, prefix="/users", tags=["user"])
api_router.include_router(project.router, prefix="/projects", tags=["project"])
api_router.include_router(model.router, prefix="/models", tags=["model"])
api_router.include_router(train_task.router, prefix="/tasks", tags=["task"])
api_router.include_router(training_task.router, prefix="/tasks", tags=["training"])
api_router.include_router(conversion_task.router, prefix="/tasks", tags=["conversion"])
api_router.include_router(system.router, prefix="/system", tags=["system"])
63 changes: 63 additions & 0 deletions app/api/v1/endpoints/conversion_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session

from app.api.deps import api_key_header
from app.api.v1.schemas.task.conversion.conversion_task import (
ConversionCreate,
ConversionCreateResponse,
ConversionResponse,
SupportedDevicesResponse,
)
from app.services.conversion_task import conversion_task_service
from netspresso.enums.conversion import SourceFramework
from netspresso.utils.db.session import get_db

router = APIRouter()


@router.get(
"/conversions/configuration/devices",
response_model=SupportedDevicesResponse,
description="Get supported devices and frameworks for model conversion based on the source framework.",
)
def get_supported_conversion_devices(
framework: SourceFramework = Query(..., description="Source framework of the model to be converted."),
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> SupportedDevicesResponse:
supported_devices = conversion_task_service.get_supported_devices(db=db, framework=framework, api_key=api_key)

return SupportedDevicesResponse(data=supported_devices)


@router.post("/conversions", response_model=ConversionCreateResponse, status_code=201)
def create_conversions_task(
request_body: ConversionCreate,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> ConversionCreateResponse:
conversion_task = conversion_task_service.create_conversion_task(db=db, conversion_in=request_body, api_key=api_key)

return ConversionCreateResponse(data=conversion_task)


@router.get("/conversions/{task_id}", response_model=ConversionResponse)
def get_conversions_task(
task_id: str,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> ConversionResponse:
conversion_task = conversion_task_service.get_conversion_task(db=db, task_id=task_id, api_key=api_key)

return ConversionResponse(data=conversion_task)


@router.post("/conversions/{task_id}/cancel", response_model=ConversionResponse)
def cancel_conversion_task(
task_id: str,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
) -> ConversionResponse:
conversion_task = conversion_task_service.cancel_conversion_task(db=db, task_id=task_id, api_key=api_key)

return ConversionResponse(data=conversion_task)
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

Expand All @@ -10,7 +8,7 @@
SupportedSchedulersResponse,
)
from app.api.v1.schemas.task.train.train_task import TrainingCreate, TrainingResponse
from app.services.train_task import train_task_service
from app.services.training_task import train_task_service
from netspresso.utils.db.session import get_db

router = APIRouter()
Expand Down
1 change: 1 addition & 0 deletions app/api/v1/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ProjectPayload(ProjectCreate):
project_id: str = Field(..., description="The unique identifier for the project.")
model_ids: List[str] = Field(default_factory=list, description="The list of models associated with the project.")
user_id: str = Field(..., description="The unique identifier for the user associated with the project.")
project_abs_path: str = Field(..., description="The absolute path of the project.")
created_at: datetime = Field(..., description="The timestamp when the project was created.")
updated_at: datetime = Field(..., description="The timestamp when the project was last updated.")

Expand Down
Empty file.
73 changes: 73 additions & 0 deletions app/api/v1/schemas/task/conversion/conversion_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from datetime import datetime
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator

from app.api.v1.schemas.base import ResponseItem
from app.api.v1.schemas.task.conversion.device import (
PrecisionPayload,
SoftwareVersionPayload,
SupportedDevicePayload,
TargetDevicePayload,
)
from netspresso.enums.conversion import TARGET_FRAMEWORK_DISPLAY_MAP, Precision, TargetFramework, TargetFrameworkDisplay
from netspresso.enums.device import DeviceName, SoftwareVersion
from netspresso.enums.model import DataType


class TargetFrameworkPayload(BaseModel):
name: TargetFramework = Field(description="Framework name")
display_name: Optional[TargetFrameworkDisplay] = Field(default=None, description="Framework display name")

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = TARGET_FRAMEWORK_DISPLAY_MAP.get(self.name)

return self


class SupportedDeviceResponse(BaseModel):
framework: TargetFrameworkPayload
devices: List[SupportedDevicePayload]


class SupportedDevicesResponse(BaseModel):
data: List[SupportedDeviceResponse]


class ConversionCreate(BaseModel):
input_model_id: str = Field(description="Input model ID")
framework: TargetFramework = Field(description="Framework name")
device_name: DeviceName = Field(description="Device name")
software_version: Optional[SoftwareVersion] = Field(default=None, description="Software version")
precision: Precision = Field(description="Precision")
calibration_dataset_path: Optional[str] = Field(default=None, description="Path to the calibration dataset")


class ConversionPayload(BaseModel):
model_config = ConfigDict(from_attributes=True)

task_id: str
model_id: Optional[str] = None
framework: TargetFrameworkPayload
device: TargetDevicePayload
software_version: Optional[SoftwareVersionPayload] = None
precision: PrecisionPayload
status: str
is_deleted: bool
error_detail: Optional[Dict] = None
input_model_id: str
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None


class ConversionCreatePayload(BaseModel):
task_id: str


class ConversionCreateResponse(ResponseItem):
data: ConversionCreatePayload


class ConversionResponse(ResponseItem):
data: ConversionPayload
81 changes: 81 additions & 0 deletions app/api/v1/schemas/task/conversion/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List, Optional

from pydantic import BaseModel, Field, model_validator

from netspresso.enums.conversion import PRECISION_DISPLAY_MAP, Precision, PrecisionDisplay
from netspresso.enums.device import (
DEVICE_BRAND_MAP,
DEVICE_DISPLAY_MAP,
HARDWARE_TYPE_DISPLAY_MAP,
SOFTWARE_VERSION_DISPLAY_MAP,
DeviceBrand,
DeviceDisplay,
DeviceName,
HardwareType,
HardwareTypeDisplay,
SoftwareVersion,
SoftwareVersionDisplay,
)


class SoftwareVersionPayload(BaseModel):
name: SoftwareVersion
display_name: Optional[SoftwareVersionDisplay] = Field(default=None, description="Software version display name")

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = SOFTWARE_VERSION_DISPLAY_MAP.get(self.name)

return self


class PrecisionPayload(BaseModel):
name: Precision
display_name: Optional[PrecisionDisplay] = Field(default=None, description="Precision display name")

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = PRECISION_DISPLAY_MAP.get(self.name)

return self


class HardwareTypePayload(BaseModel):
name: HardwareType
display_name: Optional[HardwareTypeDisplay] = Field(default=None, description="Hardware type display name")

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = HARDWARE_TYPE_DISPLAY_MAP.get(self.name)

return self


class SupportedDevicePayload(BaseModel):
name: DeviceName
display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name")
brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name")
software_versions: List[SoftwareVersionPayload]
precisions: List[PrecisionPayload]
hardware_types: List[HardwareTypePayload]

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = DEVICE_DISPLAY_MAP.get(self.name)
self.brand_name = DEVICE_BRAND_MAP.get(self.name)

return self



class TargetDevicePayload(BaseModel):
name: DeviceName
display_name: Optional[DeviceDisplay] = Field(default=None, description="Device display name")
brand_name: Optional[DeviceBrand] = Field(default=None, description="Device brand name")

@model_validator(mode="after")
def set_display_name(self) -> str:
self.display_name = DEVICE_DISPLAY_MAP.get(self.name)
self.brand_name = DEVICE_BRAND_MAP.get(self.name)

return self
9 changes: 8 additions & 1 deletion app/api/v1/schemas/task/train/hyperparameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from pydantic import BaseModel, ConfigDict, Field, model_validator

from netspresso.enums.train import OPTIMIZER_DISPLAY_MAP, SCHEDULER_DISPLAY_MAP, Optimizer, OptimizerDisplay, Scheduler, SchedulerDisplay
from netspresso.enums.train import (
OPTIMIZER_DISPLAY_MAP,
SCHEDULER_DISPLAY_MAP,
Optimizer,
OptimizerDisplay,
Scheduler,
SchedulerDisplay,
)


class TrainerModel(BaseModel):
Expand Down
16 changes: 7 additions & 9 deletions app/api/v1/schemas/task/train/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@

from app.api.v1.schemas.base import ResponseItem
from netspresso.enums.train import (
Framework,
FrameworkDisplay,
Task,
TaskDisplay,
PretrainedModel,
PretrainedModelDisplay,
PretrainedModelGroup,
TASK_DISPLAY_MAP,
FRAMEWORK_DISPLAY_MAP,
MODEL_DISPLAY_MAP,
MODEL_GROUP_MAP,
TASK_DISPLAY_MAP,
FRAMEWORK_DISPLAY_MAP,
Framework,
FrameworkDisplay,
PretrainedModel,
PretrainedModelDisplay,
PretrainedModelGroup,
Task,
TaskDisplay,
)

from .dataset import DatasetCreate, DatasetPayload
Expand Down
Loading

0 comments on commit 9d73241

Please sign in to comment.