From a54843f0546933f824366f81000aa342df4d3994 Mon Sep 17 00:00:00 2001 From: Thomas Bernard Date: Thu, 14 Mar 2024 23:24:00 +0100 Subject: [PATCH] Reorganize backend files to make it futurproof --- .github/workflows/pr.yml | 2 - backend/Dockerfile | 10 +- backend/requirements.txt | 5 +- backend/src/config.py | 114 +++++++ backend/src/constants.py | 24 -- backend/src/main.py | 309 +----------------- .../{model.pt => src/ml/models/typology.pt} | Bin .../src/{model.py => ml/utils/typology.py} | 17 +- backend/src/router.py | 144 ++++++++ backend/src/utils.py | 25 ++ backend/tests/test_api.py | 4 +- backend/tests/test_model.py | 8 +- docker-compose.yml | 1 - 13 files changed, 301 insertions(+), 362 deletions(-) create mode 100644 backend/src/config.py delete mode 100644 backend/src/constants.py rename backend/{model.pt => src/ml/models/typology.pt} (100%) rename backend/src/{model.py => ml/utils/typology.py} (68%) create mode 100644 backend/src/router.py create mode 100644 backend/src/utils.py diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index cbe0cc61..51d53f7b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -35,7 +35,6 @@ jobs: context: ./backend push: true tags: ghcr.io/datalab-mi/basegun/basegun-backend:${{ github.head_ref }} - target: dev build-frontend: name: Build Frontend @@ -55,7 +54,6 @@ jobs: context: ./frontend push: true tags: ghcr.io/datalab-mi/basegun/basegun-frontend:${{ github.head_ref }} - target: prod test-backend: name: Test Backend diff --git a/backend/Dockerfile b/backend/Dockerfile index cbe18cbb..2c65f120 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -22,12 +22,4 @@ RUN pip --default-timeout=300 install --upgrade pip \ ARG VERSION ENV SSL_CERT_FILE=$CACERT_LOCATION -COPY src/ src/ -COPY model.pt . - -FROM base as dev -COPY tests/ tests/ - -FROM base as prod -RUN pip install --extra-index-url https://download.pytorch.org/whl/cpu \ - torch==2.1.1+cpu torchvision==0.16.1+cpu && rm -r /root/.cache \ No newline at end of file +COPY . . \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 89afa3cf..8da5c3f7 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,10 +8,9 @@ gelf-formatter==0.2.1 pyyaml>=5.4.1 user-agents==2.2.0 boto3==1.28.39 -torch==2.1.1 -torchvision==0.16.1 -ultralytics==8.1.2 autodynatrace==2.0.0 +# ML +ultralytics==8.1.2 # Dev pytest==7.4.3 coverage==7.3.2 \ No newline at end of file diff --git a/backend/src/config.py b/backend/src/config.py new file mode 100644 index 00000000..cdd3755a --- /dev/null +++ b/backend/src/config.py @@ -0,0 +1,114 @@ +import os +from datetime import datetime + +import boto3 +from gelfformatter import GelfFormatter + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + +PATH_LOGS = os.environ.get("PATH_LOGS", "/tmp/logs") + +LOGS_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"standard": {"()": lambda: GelfFormatter()}}, + "handlers": { + "default": { + "class": "logging.StreamHandler", + "formatter": "standard", + "level": "INFO", + "stream": "ext://sys.stdout", + }, + "file": { + "class": "logging.handlers.TimedRotatingFileHandler", + "when": "midnight", + "utc": True, + "backupCount": 5, + "level": "INFO", + "filename": f"{PATH_LOGS}/log.json", + "formatter": "standard", + }, + }, + "loggers": {"": {"handlers": ["default", "file"], "level": "DEBUG"}}, +} + +HEADERS = [ + {"name": "Cache-Control", "value": "no-store, max-age=0"}, + {"name": "Clear-Site-Data", "value": '"cache","cookies","storage"'}, + { + "name": "Content-Security-Policy", + "value": "default-src 'self'; form-action 'self'; object-src 'none'; frame-ancestors 'none'; upgrade-insecure-requests; block-all-mixed-content", + }, + {"name": "Cross-Origin-Embedder-Policy", "value": "require-corp"}, + {"name": "Cross-Origin-Opener-Policy", "value": "same-origin"}, + {"name": "Cross-Origin-Resource-Policy", "value": "same-origin"}, + { + "name": "Permissions-Policy", + "value": "accelerometer=(),ambient-light-sensor=(),autoplay=(),battery=(),camera=(),display-capture=(),document-domain=(),encrypted-media=(),fullscreen=(),gamepad=(),geolocation=(),gyroscope=(),layout-animations=(self),legacy-image-formats=(self),magnetometer=(),microphone=(),midi=(),oversized-images=(self),payment=(),picture-in-picture=(),publickey-credentials-get=(),speaker-selection=(),sync-xhr=(self),unoptimized-images=(self),unsized-media=(self),usb=(),screen-wake-lock=(),web-share=(),xr-spatial-tracking=()", + }, + {"name": "Pragma", "value": "no-cache"}, + {"name": "Referrer-Policy", "value": "no-referrer"}, + { + "name": "Strict-Transport-Security", + "value": "max-age=31536000 ; includeSubDomains", + }, + {"name": "X-Content-Type-Options", "value": "nosniff"}, + {"name": "X-Frame-Options", "value": "deny"}, + {"name": "X-Permitted-Cross-Domain-Policies", "value": "none"}, +] + + +def get_device(user_agent) -> str: + """Explicitly gives the device of a user-agent object + + Args: + user_agent: info given by the user browser + + Returns: + str: mobile, pc, tablet or other + """ + if user_agent.is_mobile: + return "mobile" + elif user_agent.is_pc: + return "pc" + elif user_agent.is_tablet: + return "tablet" + else: + return "other" + + +def get_base_logs(user_agent, user_id: str) -> dict: + """Generates the common information for custom logs in basegun. + Each function can add some info specific to the current process, + then we insert these custom logs as extra + + Args: + user_agent: user agent object + user_id (str): UUID identifying a unique user + + Returns: + dict: the base custom information + """ + extras_logging = { + "bg_date": datetime.now().isoformat(), + "bg_user_id": user_id, + "bg_version": APP_VERSION, + "bg_model": MODEL_VERSION, + "bg_device": get_device(user_agent), + "bg_device_family": user_agent.device.family, + "bg_device_os": user_agent.os.family, + "bg_device_browser": user_agent.browser.family, + } + return extras_logging + + +# Object storage +S3_URL_ENDPOINT = os.environ["S3_URL_ENDPOINT"] +S3_BUCKET_NAME = os.environ["S3_BUCKET_NAME"] +S3_PREFIX = os.path.join("uploaded-images/", os.environ["WORKSPACE"]) + +S3 = boto3.resource("s3", endpoint_url=S3_URL_ENDPOINT, verify=False) + +# Versions +APP_VERSION = "-1" +MODEL_VERSION = "-1" diff --git a/backend/src/constants.py b/backend/src/constants.py deleted file mode 100644 index 20cc489e..00000000 --- a/backend/src/constants.py +++ /dev/null @@ -1,24 +0,0 @@ -HEADERS = [ - {"name": "Cache-Control", "value": "no-store, max-age=0"}, - {"name": "Clear-Site-Data", "value": '"cache","cookies","storage"'}, - { - "name": "Content-Security-Policy", - "value": "default-src 'self'; form-action 'self'; object-src 'none'; frame-ancestors 'none'; upgrade-insecure-requests; block-all-mixed-content", - }, - {"name": "Cross-Origin-Embedder-Policy", "value": "require-corp"}, - {"name": "Cross-Origin-Opener-Policy", "value": "same-origin"}, - {"name": "Cross-Origin-Resource-Policy", "value": "same-origin"}, - { - "name": "Permissions-Policy", - "value": "accelerometer=(),ambient-light-sensor=(),autoplay=(),battery=(),camera=(),display-capture=(),document-domain=(),encrypted-media=(),fullscreen=(),gamepad=(),geolocation=(),gyroscope=(),layout-animations=(self),legacy-image-formats=(self),magnetometer=(),microphone=(),midi=(),oversized-images=(self),payment=(),picture-in-picture=(),publickey-credentials-get=(),speaker-selection=(),sync-xhr=(self),unoptimized-images=(self),unsized-media=(self),usb=(),screen-wake-lock=(),web-share=(),xr-spatial-tracking=()", - }, - {"name": "Pragma", "value": "no-cache"}, - {"name": "Referrer-Policy", "value": "no-referrer"}, - { - "name": "Strict-Transport-Security", - "value": "max-age=31536000 ; includeSubDomains", - }, - {"name": "X-Content-Type-Options", "value": "nosniff"}, - {"name": "X-Frame-Options", "value": "deny"}, - {"name": "X-Permitted-Cross-Domain-Policies", "value": "none"}, -] diff --git a/backend/src/main.py b/backend/src/main.py index cc207791..26b9dd27 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,156 +1,14 @@ -import json -import sys import logging import os -import time -from datetime import datetime -from contextlib import asynccontextmanager -from typing import Union -from uuid import uuid4 -import boto3 -from fastapi import ( - APIRouter, - BackgroundTasks, - Cookie, - FastAPI, - File, - Form, - HTTPException, - Request, - Response, - UploadFile, -) +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import PlainTextResponse -from gelfformatter import GelfFormatter -from src.constants import HEADERS -from src.model import load_model_inference, predict_image -from user_agents import parse - -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def setup_logs(log_dir: str): - os.makedirs(log_dir, exist_ok=True) - - logging_config = { - "version": 1, - 'disable_existing_loggers': False, - "formatters": { - 'standard': { - '()': lambda: GelfFormatter() - } - }, - "handlers": { - 'default': { - 'class': 'logging.StreamHandler', - 'formatter': 'standard', - 'level': "INFO", - 'stream': 'ext://sys.stdout' - }, - 'file': { - 'class': 'logging.handlers.TimedRotatingFileHandler', - 'when': 'midnight', - 'utc': True, - 'backupCount': 5, - 'level': "INFO", - 'filename': f'{log_dir}/log.json', - 'formatter': 'standard', - }, - }, - "loggers": { - "": { - 'handlers': ['default', 'file'], - 'level': "DEBUG" - } - } - } - - logging.config.dictConfig(logging_config) - - -def get_device(user_agent) -> str: - """Explicitly gives the device of a user-agent object - - Args: - user_agent: info given by the user browser - - Returns: - str: mobile, pc, tablet or other - """ - if user_agent.is_mobile: - return "mobile" - elif user_agent.is_pc: - return "pc" - elif user_agent.is_tablet: - return "tablet" - else: - return "other" - - -def get_base_logs(user_agent, user_id: str) -> dict: - """Generates the common information for custom logs in basegun. - Each function can add some info specific to the current process, - then we insert these custom logs as extra - - Args: - user_agent: user agent object - user_id (str): UUID identifying a unique user - Returns: - dict: the base custom information - """ - extras_logging = { - "bg_date": datetime.now().isoformat(), - "bg_user_id": user_id, - "bg_version": APP_VERSION, - "bg_model": MODEL_VERSION, - "bg_device": get_device(user_agent), - "bg_device_family": user_agent.device.family, - "bg_device_os": user_agent.os.family, - "bg_device_browser": user_agent.browser.family, - } - return extras_logging +from .config import HEADERS, LOGS_CONFIG, PATH_LOGS +from .router import router - -def upload_image(content: bytes, image_key: str): - """Uploads an image to s3 bucket - path uploaded-images/WORKSPACE/img_name - where WORKSPACE is dev, preprod or prod - - Args: - content (bytes): file content - image_key (str): path we want to have - """ - start = time.time() - object = s3.Object(S3_BUCKET_NAME, image_key) - object.put(Body=content) - extras_logging = { - "bg_date": datetime.now().isoformat(), - "bg_upload_time": time.time() - start, - "bg_image_url": image_key, - } - logging.info("Upload successful", extra=extras_logging) - - -#################### -# SETUP # -#################### - -# FastAPI Setup app = FastAPI(docs_url="/api/docs") -router = APIRouter(prefix="/api") -origins = [ # allow requests from front-end - "http://basegun.fr", - "https://basegun.fr", - "http://preprod.basegun.fr", - "https://preprod.basegun.fr", - "http://localhost", - "http://localhost:8080", - "http://localhost:3000", -] app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -169,164 +27,7 @@ async def add_owasp_middleware(request: Request, call_next): # Logs -PATH_LOGS = os.environ.get("PATH_LOGS", "/tmp/logs") -setup_logs(PATH_LOGS) - -# Load model -app.model = load_model_inference("./model.pt") - -# Object storage -S3_URL_ENDPOINT = os.environ["S3_URL_ENDPOINT"] -S3_BUCKET_NAME = os.environ["S3_BUCKET_NAME"] -S3_PREFIX = os.path.join("uploaded-images/", os.environ["WORKSPACE"]) - -s3 = boto3.resource("s3", endpoint_url=S3_URL_ENDPOINT, verify=False) - -# Versions -if "versions.json" in os.listdir(os.path.dirname(CURRENT_DIR)): - with open("versions.json", "r") as f: - versions = json.load(f) - APP_VERSION = versions["app"] - MODEL_VERSION = versions["model"] -else: - logging.warn("File versions.json not found") - APP_VERSION = "-1" - MODEL_VERSION = "-1" - - -#################### -# ROUTES # -#################### -@router.get("/", response_class=PlainTextResponse) -def home(): - return "Basegun backend" - - -@router.get("/version", response_class=PlainTextResponse) -def version(): - return APP_VERSION - - -@router.post("/upload") -async def imageupload( - request: Request, - response: Response, - background_tasks: BackgroundTasks, - image: UploadFile = File(...), - date: float = Form(...), - user_id: Union[str, None] = Cookie(None), -): - - # prepare content logs - user_agent = parse(request.headers.get("user-agent")) - extras_logging = get_base_logs(user_agent, user_id) - extras_logging["bg_upload_time"] = round(time.time() - date, 2) - - try: - img_key = os.path.join( - S3_PREFIX, str(uuid4()) + os.path.splitext(image.filename)[1].lower() - ) - img_bytes = image.file.read() - - # upload image to OVH Cloud - background_tasks.add_task(upload_image, img_bytes, img_key) - extras_logging["bg_image_url"] = img_key - - # set user id - if not user_id: - user_id = uuid4() - response.set_cookie(key="user_id", value=user_id) - extras_logging["bg_user_id"] = user_id - - # send image to model for prediction - start = time.time() - label, confidence = predict_image(app.model, img_bytes) - extras_logging["bg_label"] = label - extras_logging["bg_confidence"] = confidence - extras_logging["bg_model_time"] = round(time.time() - start, 2) - if confidence < 0.76: - extras_logging["bg_confidence_level"] = "low" - elif confidence < 0.98: - extras_logging["bg_confidence_level"] = "medium" - else: - extras_logging["bg_confidence_level"] = "high" - - logging.info("Identification request", extra=extras_logging) - - return { - "path": img_key, - "label": label, - "confidence": confidence, - "confidence_level": extras_logging["bg_confidence_level"], - } - - except Exception as e: - extras_logging["bg_error_type"] = e.__class__.__name__ - logging.exception(e, extra=extras_logging) - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/identification-feedback") -async def log_feedback(request: Request, user_id: Union[str, None] = Cookie(None)): - res = await request.json() - - user_agent = parse(request.headers.get("user-agent")) - extras_logging = get_base_logs(user_agent, user_id) - - extras_logging["bg_feedback_bool"] = res["feedback"] - for key in ["image_url", "label", "confidence", "confidence_level"]: - extras_logging["bg_" + key] = res[key] - - logging.info("Identification feedback", extra=extras_logging) - return - - -@router.post("/tutorial-feedback") -async def log_tutorial_feedback( - request: Request, user_id: Union[str, None] = Cookie(None) -): - res = await request.json() - - user_agent = parse(request.headers.get("user-agent")) - extras_logging = get_base_logs(user_agent, user_id) - - for key in [ - "image_url", - "label", - "confidence", - "confidence_level", - "tutorial_feedback", - "tutorial_option", - "route_name", - ]: - extras_logging["bg_" + key] = res[key] - - logging.info("Tutorial feedback", extra=extras_logging) - return - - -@router.post("/identification-dummy") -async def log_identification_dummy( - request: Request, user_id: Union[str, None] = Cookie(None) -): - res = await request.json() - - user_agent = parse(request.headers.get("user-agent")) - extras_logging = get_base_logs(user_agent, user_id) - - # to know if the firearm is dummy or real - extras_logging["bg_dummy_bool"] = res["is_dummy"] - for key in [ - "image_url", - "label", - "confidence", - "confidence_level", - "tutorial_option", - ]: - extras_logging["bg_" + key] = res[key] - - logging.info("Identification dummy", extra=extras_logging) - return - +os.makedirs(PATH_LOGS, exist_ok=True) +logging.config.dictConfig(LOGS_CONFIG) app.include_router(router) diff --git a/backend/model.pt b/backend/src/ml/models/typology.pt similarity index 100% rename from backend/model.pt rename to backend/src/ml/models/typology.pt diff --git a/backend/src/model.py b/backend/src/ml/utils/typology.py similarity index 68% rename from backend/src/model.py rename to backend/src/ml/utils/typology.py index cdc418e9..b90a214d 100644 --- a/backend/src/model.py +++ b/backend/src/ml/utils/typology.py @@ -1,7 +1,6 @@ from io import BytesIO from typing import Union -import numpy as np from PIL import Image from ultralytics import YOLO @@ -20,20 +19,10 @@ "semi_auto_style_militaire_autre", ] +MODEL = YOLO("./src/ml/models/typology.pt") -def load_model_inference(model_path: str): - """Load model structure and weights - Args: - model_path (str): path to model (.pt file) - - Returns: - Model: loaded model ready for prediction and Warm-up - """ - return YOLO(model_path) - - -def predict_image(model, img: bytes) -> Union[str, float]: +def get_typology_from_image(img: bytes) -> Union[str, float]: """Run the model prediction on an image Args: @@ -44,7 +33,7 @@ def predict_image(model, img: bytes) -> Union[str, float]: Union[str, float]: (label, confidence) of best class predicted """ im = Image.open(BytesIO(img)) - results = model(im, verbose=False) + results = MODEL(im, verbose=False) predicted_class = results[0].probs.top5[0] label = CLASSES[predicted_class] confidence = float(results[0].probs.top5conf[0]) diff --git a/backend/src/router.py b/backend/src/router.py new file mode 100644 index 00000000..9728cf52 --- /dev/null +++ b/backend/src/router.py @@ -0,0 +1,144 @@ +import logging +import os +import time +from typing import Union +from uuid import uuid4 + +from fastapi import (APIRouter, BackgroundTasks, Cookie, File, Form, + HTTPException, Request, Response, UploadFile) +from fastapi.responses import PlainTextResponse +from user_agents import parse + +from .config import APP_VERSION, S3_PREFIX, get_base_logs +from .ml.utils.typology import get_typology_from_image +from .utils import upload_image + +router = APIRouter(prefix="/api") + + +@router.get("/", response_class=PlainTextResponse) +def home(): + return "Basegun backend" + + +@router.get("/version", response_class=PlainTextResponse) +def version(): + return APP_VERSION + + +@router.post("/upload") +async def imageupload( + request: Request, + response: Response, + background_tasks: BackgroundTasks, + image: UploadFile = File(...), + date: float = Form(...), + user_id: Union[str, None] = Cookie(None), +): + # prepare content logs + user_agent = parse(request.headers.get("user-agent")) + extras_logging = get_base_logs(user_agent, user_id) + extras_logging["bg_upload_time"] = round(time.time() - date, 2) + + try: + img_key = os.path.join( + S3_PREFIX, str(uuid4()) + os.path.splitext(image.filename)[1].lower() + ) + img_bytes = image.file.read() + + # upload image to OVH Cloud + background_tasks.add_task(upload_image, img_bytes, img_key) + extras_logging["bg_image_url"] = img_key + + # set user id + if not user_id: + user_id = uuid4() + response.set_cookie(key="user_id", value=user_id) + extras_logging["bg_user_id"] = user_id + + # send image to model for prediction + start = time.time() + label, confidence = get_typology_from_image(img_bytes) + extras_logging["bg_label"] = label + extras_logging["bg_confidence"] = confidence + extras_logging["bg_model_time"] = round(time.time() - start, 2) + if confidence < 0.76: + extras_logging["bg_confidence_level"] = "low" + elif confidence < 0.98: + extras_logging["bg_confidence_level"] = "medium" + else: + extras_logging["bg_confidence_level"] = "high" + + logging.info("Identification request", extra=extras_logging) + + return { + "path": img_key, + "label": label, + "confidence": confidence, + "confidence_level": extras_logging["bg_confidence_level"], + } + + except Exception as e: + extras_logging["bg_error_type"] = e.__class__.__name__ + logging.exception(e, extra=extras_logging) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/identification-feedback") +async def log_feedback(request: Request, user_id: Union[str, None] = Cookie(None)): + res = await request.json() + + user_agent = parse(request.headers.get("user-agent")) + extras_logging = get_base_logs(user_agent, user_id) + + extras_logging["bg_feedback_bool"] = res["feedback"] + for key in ["image_url", "label", "confidence", "confidence_level"]: + extras_logging["bg_" + key] = res[key] + + logging.info("Identification feedback", extra=extras_logging) + + +@router.post("/tutorial-feedback") +async def log_tutorial_feedback( + request: Request, user_id: Union[str, None] = Cookie(None) +): + res = await request.json() + + user_agent = parse(request.headers.get("user-agent")) + extras_logging = get_base_logs(user_agent, user_id) + + for key in [ + "image_url", + "label", + "confidence", + "confidence_level", + "tutorial_feedback", + "tutorial_option", + "route_name", + ]: + extras_logging["bg_" + key] = res[key] + + logging.info("Tutorial feedback", extra=extras_logging) + + +@router.post("/identification-dummy") +async def log_identification_dummy( + request: Request, user_id: Union[str, None] = Cookie(None) +): + res = await request.json() + + user_agent = parse(request.headers.get("user-agent")) + extras_logging = get_base_logs(user_agent, user_id) + + # to know if the firearm is dummy or real + extras_logging["bg_dummy_bool"] = res["is_dummy"] + for key in [ + "image_url", + "label", + "confidence", + "confidence_level", + "tutorial_option", + ]: + extras_logging["bg_" + key] = res[key] + + logging.info("Identification dummy", extra=extras_logging) diff --git a/backend/src/utils.py b/backend/src/utils.py new file mode 100644 index 00000000..fb0460fc --- /dev/null +++ b/backend/src/utils.py @@ -0,0 +1,25 @@ +import logging +import time +from datetime import datetime + +from .config import S3, S3_BUCKET_NAME + + +def upload_image(content: bytes, image_key: str): + """Uploads an image to s3 bucket + path uploaded-images/WORKSPACE/img_name + where WORKSPACE is dev, preprod or prod + + Args: + content (bytes): file content + image_key (str): path we want to have + """ + start = time.time() + object = S3.Object(S3_BUCKET_NAME, image_key) + object.put(Body=content) + extras_logging = { + "bg_date": datetime.now().isoformat(), + "bg_upload_time": time.time() - start, + "bg_image_url": image_key, + } + logging.info("Upload successful", extra=extras_logging) diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 97ac06c4..28d05b4e 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -7,7 +7,9 @@ import pytest import requests from fastapi.testclient import TestClient -from src.main import S3_BUCKET_NAME, S3_URL_ENDPOINT, app + +from src.config import S3_BUCKET_NAME, S3_URL_ENDPOINT +from src.main import app client = TestClient(app) diff --git a/backend/tests/test_model.py b/backend/tests/test_model.py index c425e29c..243f8406 100644 --- a/backend/tests/test_model.py +++ b/backend/tests/test_model.py @@ -1,13 +1,13 @@ -import os - import pytest -from src.model import CLASSES, load_model_inference, predict_image + from src.main import app +from src.ml.utils.typology import get_typology_from_image + class TestModel: def test_predict_image(self): """Checks the prediction of an image by the model""" with open("./tests/revolver.jpg", "rb") as f: - res = predict_image(app.model, f.read()) + res = get_typology_from_image(f.read()) assert res[0] == "revolver" assert res[1] == pytest.approx(1, 0.1) diff --git a/docker-compose.yml b/docker-compose.yml index 61c43970..99337d1b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,7 +9,6 @@ services: - VERSION=${TAG:-latest} - CACERT_LOCATION context: ./backend - target: ${BUILD_TARGET:-dev} command: uvicorn src.main:app --reload --host 0.0.0.0 --port 5000 --no-server-header container_name: basegun-backend environment: