Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve search #462

Merged
merged 14 commits into from
Nov 11, 2024
50 changes: 27 additions & 23 deletions backend/src/api/endpoints/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,14 @@

from api.dependencies import get_current_user, get_db_session
from app.core.analysis.analysis_service import AnalysisService
from app.core.analysis.annotated_images import (
AnnotatedImagesColumns,
find_annotated_images,
find_annotated_images_info,
)
from app.core.analysis.annotated_segments import (
AnnotatedSegmentsColumns,
find_annotated_segments,
find_annotated_segments_info,
)
from app.core.analysis.word_frequency import (
WordFrequencyColumns,
from app.core.analysis.word_frequency_analysis.word_frequency import (
word_frequency,
word_frequency_export,
word_frequency_info,
)
from app.core.analysis.word_frequency_analysis.word_frequency_columns import (
WordFrequencyColumns,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.doc_type import DocType
from app.core.data.dto.analysis import (
Expand All @@ -32,9 +24,21 @@
SampledSdocsResults,
WordFrequencyResult,
)
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.bbox_search.bbox_search import (
find_annotated_images,
find_annotated_images_info,
)
from app.core.search.bbox_search.bbox_search_columns import BBoxColumns
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.sorting import Sort
from app.core.search.span_search.span_search import (
find_annotated_segments,
find_annotated_segments_info,
)
from app.core.search.span_search.span_search_columns import (
SpanColumns,
)

router = APIRouter(
prefix="/analysis", dependencies=[Depends(get_current_user)], tags=["analysis"]
Expand Down Expand Up @@ -101,14 +105,14 @@ def annotation_occurrences(

@router.post(
"/annotated_segments_info",
response_model=List[ColumnInfo[AnnotatedSegmentsColumns]],
response_model=List[ColumnInfo[SpanColumns]],
summary="Returns AnnotationSegments Info.",
)
def annotated_segments_info(
*,
project_id: int,
authz_user: AuthzUser = Depends(),
) -> List[ColumnInfo[AnnotatedSegmentsColumns]]:
) -> List[ColumnInfo[SpanColumns]]:
authz_user.assert_in_project(project_id)
return find_annotated_segments_info(
project_id=project_id,
Expand All @@ -124,10 +128,10 @@ def annotated_segments(
*,
project_id: int,
user_id: int,
filter: Filter[AnnotatedSegmentsColumns],
filter: Filter[SpanColumns],
page: Optional[int] = None,
page_size: Optional[int] = None,
sorts: List[Sort[AnnotatedSegmentsColumns]],
sorts: List[Sort[SpanColumns]],
authz_user: AuthzUser = Depends(),
) -> AnnotatedSegmentResult:
authz_user.assert_in_project(project_id)
Expand All @@ -144,14 +148,14 @@ def annotated_segments(

@router.post(
"/annotated_images_info",
response_model=List[ColumnInfo[AnnotatedImagesColumns]],
response_model=List[ColumnInfo[BBoxColumns]],
summary="Returns AnnotationSegments Info.",
)
def annotated_images_info(
*,
project_id: int,
authz_user: AuthzUser = Depends(),
) -> List[ColumnInfo[AnnotatedImagesColumns]]:
) -> List[ColumnInfo[BBoxColumns]]:
authz_user.assert_in_project(project_id)
return find_annotated_images_info(
project_id=project_id,
Expand All @@ -167,10 +171,10 @@ def annotated_images(
*,
project_id: int,
user_id: int,
filter: Filter[AnnotatedImagesColumns],
filter: Filter[BBoxColumns],
page: Optional[int] = None,
page_size: Optional[int] = None,
sorts: List[Sort[AnnotatedImagesColumns]],
sorts: List[Sort[BBoxColumns]],
authz_user: AuthzUser = Depends(),
) -> AnnotatedImageResult:
authz_user.assert_in_project(project_id)
Expand Down
9 changes: 5 additions & 4 deletions backend/src/api/endpoints/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from api.dependencies import get_current_user, get_db_session
from api.validation import Validate
from app.core.analysis.memo import MemoColumns, memo_info, memo_search
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.crud.memo import crud_memo
Expand All @@ -19,9 +18,11 @@
)
from app.core.data.dto.search import PaginatedElasticSearchDocumentHits
from app.core.data.orm.util import get_parent_project_id
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.memo_search.memo_search import memo_info, memo_search
from app.core.search.memo_search.memo_search_columns import MemoColumns
from app.core.search.sorting import Sort

router = APIRouter(
prefix="/memo", dependencies=[Depends(get_current_user)], tags=["memo"]
Expand Down
2 changes: 1 addition & 1 deletion backend/src/api/endpoints/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from app.core.data.dto.project_metadata import ProjectMetadataRead
from app.core.data.dto.user import UserRead
from app.core.data.orm.source_document import SourceDocumentORM
from app.core.search.elasticsearch_service import ElasticSearchService
from app.core.db.elasticsearch_service import ElasticSearchService
from app.preprocessing.preprocessing_service import PreprocessingService

router = APIRouter(
Expand Down
93 changes: 53 additions & 40 deletions backend/src/api/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,46 @@
from typing import List, Optional
from typing import List, Optional, Union

from fastapi import APIRouter, Depends

import app.core.search.sdoc_search.sdoc_search as sdoc_search
from api.dependencies import get_current_user
from app.core.analysis.statistics import (
compute_code_statistics,
compute_keyword_statistics,
compute_tag_statistics,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.dto.search import (
PaginatedElasticSearchDocumentHits,
SearchColumns,
SimSearchImageHit,
SimSearchQuery,
SimSearchSentenceHit,
)
from app.core.data.dto.search_stats import KeywordStat, SpanEntityStat, TagStat
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.elasticsearch_service import ElasticSearchService
from app.core.search.search_service import SearchService
from app.core.db.elasticsearch_service import ElasticSearchService
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.sdoc_search.sdoc_search_columns import SdocColumns
from app.core.search.sorting import Sort

router = APIRouter(
prefix="/search", dependencies=[Depends(get_current_user)], tags=["search"]
)

ss = SearchService()
es = ElasticSearchService()


@router.post(
"/sdoc_info",
response_model=List[ColumnInfo[SearchColumns]],
response_model=List[ColumnInfo[SdocColumns]],
summary="Returns Search Info.",
)
def search_sdocs_info(
*, project_id: int, authz_user: AuthzUser = Depends()
) -> List[ColumnInfo[SearchColumns]]:
) -> List[ColumnInfo[SdocColumns]]:
authz_user.assert_in_project(project_id)

return SearchService().search_info(project_id=project_id)
return sdoc_search.search_info(project_id=project_id)


@router.post(
Expand All @@ -50,15 +53,15 @@ def search_sdocs(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
highlight: bool,
page_number: Optional[int] = None,
page_size: Optional[int] = None,
authz_user: AuthzUser = Depends(),
) -> PaginatedElasticSearchDocumentHits:
authz_user.assert_in_project(project_id)
return SearchService().search(
return sdoc_search.search(
search_query=search_query,
expert_mode=expert_mode,
highlight=highlight,
Expand All @@ -85,12 +88,12 @@ def search_code_stats(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[SpanEntityStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -104,9 +107,7 @@ def search_code_stats(

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
code_stats = compute_code_statistics(code_id=code_id, sdoc_ids=set(sdoc_ids))
if sort_by_global:
code_stats.sort(key=lambda x: x.global_count, reverse=True)
return code_stats
Expand All @@ -131,9 +132,7 @@ def filter_code_stats(

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
code_stats = compute_code_statistics(code_id=code_id, sdoc_ids=set(sdoc_ids))
if sort_by_global:
code_stats.sort(key=lambda x: x.global_count, reverse=True)
return code_stats
Expand All @@ -154,12 +153,12 @@ def search_keyword_stats(
# search params
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[KeywordStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -172,7 +171,7 @@ def search_keyword_stats(
return []

# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
keyword_stats = compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
if sort_by_global:
Expand All @@ -199,7 +198,7 @@ def filter_keyword_stats(
return []

# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
keyword_stats = compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
if sort_by_global:
Expand All @@ -221,12 +220,12 @@ def search_tag_stats(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[TagStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -239,7 +238,7 @@ def search_tag_stats(
return []

# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
tag_stats = compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
return tag_stats
Expand All @@ -262,7 +261,7 @@ def filter_tag_stats(
return []

# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
tag_stats = compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
return tag_stats
Expand All @@ -274,11 +273,18 @@ def filter_tag_stats(
summary="Returns similar sentences according to a textual or visual query.",
)
def find_similar_sentences(
query: SimSearchQuery, authz_user: AuthzUser = Depends()
proj_id: int,
query: Union[str, List[str], int],
top_k: int,
threshold: float,
filter: Filter[SdocColumns],
authz_user: AuthzUser = Depends(),
) -> List[SimSearchSentenceHit]:
authz_user.assert_in_project(query.proj_id)
authz_user.assert_in_project(proj_id)

return ss.find_similar_sentences(query=query)
return sdoc_search.find_similar_sentences(
proj_id=proj_id, query=query, top_k=top_k, threshold=threshold, filter=filter
)


@router.post(
Expand All @@ -287,8 +293,15 @@ def find_similar_sentences(
summary="Returns similar images according to a textual or visual query.",
)
def find_similar_images(
query: SimSearchQuery, authz_user: AuthzUser = Depends()
proj_id: int,
query: Union[str, List[str], int],
top_k: int,
threshold: float,
filter: Filter[SdocColumns],
authz_user: AuthzUser = Depends(),
) -> List[SimSearchImageHit]:
authz_user.assert_in_project(query.proj_id)
authz_user.assert_in_project(proj_id)

return ss.find_similar_images(query=query)
return sdoc_search.find_similar_images(
proj_id=proj_id, query=query, top_k=top_k, threshold=threshold, filter=filter
)
10 changes: 6 additions & 4 deletions backend/src/api/endpoints/timeline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from sqlalchemy.orm import Session

from api.dependencies import get_current_user, get_db_session
from app.core.analysis.timeline import (
TimelineAnalysisColumns,
from app.core.analysis.timeline_analysis.timeline import (
timeline_analysis,
timeline_analysis_info,
)
from app.core.analysis.timeline_analysis.timeline_analysis_columns import (
TimelineAnalysisColumns,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.crud.timeline_analysis import crud_timeline_analysis
Expand All @@ -19,8 +21,8 @@
TimelineAnalysisRead,
TimelineAnalysisUpdate,
)
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter

router = APIRouter(
prefix="/timelineAnalysis",
Expand Down
Loading
Loading