Skip to content

Commit

Permalink
Merge pull request #462 from uhh-lt/improve-search
Browse files Browse the repository at this point in the history
Improve search
  • Loading branch information
bigabig authored Nov 11, 2024
2 parents adcf43f + 25b8ea5 commit 4712931
Show file tree
Hide file tree
Showing 101 changed files with 2,651 additions and 2,439 deletions.
50 changes: 27 additions & 23 deletions backend/src/api/endpoints/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,14 @@

from api.dependencies import get_current_user, get_db_session
from app.core.analysis.analysis_service import AnalysisService
from app.core.analysis.annotated_images import (
AnnotatedImagesColumns,
find_annotated_images,
find_annotated_images_info,
)
from app.core.analysis.annotated_segments import (
AnnotatedSegmentsColumns,
find_annotated_segments,
find_annotated_segments_info,
)
from app.core.analysis.word_frequency import (
WordFrequencyColumns,
from app.core.analysis.word_frequency_analysis.word_frequency import (
word_frequency,
word_frequency_export,
word_frequency_info,
)
from app.core.analysis.word_frequency_analysis.word_frequency_columns import (
WordFrequencyColumns,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.doc_type import DocType
from app.core.data.dto.analysis import (
Expand All @@ -32,9 +24,21 @@
SampledSdocsResults,
WordFrequencyResult,
)
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.bbox_search.bbox_search import (
find_annotated_images,
find_annotated_images_info,
)
from app.core.search.bbox_search.bbox_search_columns import BBoxColumns
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.sorting import Sort
from app.core.search.span_search.span_search import (
find_annotated_segments,
find_annotated_segments_info,
)
from app.core.search.span_search.span_search_columns import (
SpanColumns,
)

router = APIRouter(
prefix="/analysis", dependencies=[Depends(get_current_user)], tags=["analysis"]
Expand Down Expand Up @@ -101,14 +105,14 @@ def annotation_occurrences(

@router.post(
"/annotated_segments_info",
response_model=List[ColumnInfo[AnnotatedSegmentsColumns]],
response_model=List[ColumnInfo[SpanColumns]],
summary="Returns AnnotationSegments Info.",
)
def annotated_segments_info(
*,
project_id: int,
authz_user: AuthzUser = Depends(),
) -> List[ColumnInfo[AnnotatedSegmentsColumns]]:
) -> List[ColumnInfo[SpanColumns]]:
authz_user.assert_in_project(project_id)
return find_annotated_segments_info(
project_id=project_id,
Expand All @@ -124,10 +128,10 @@ def annotated_segments(
*,
project_id: int,
user_id: int,
filter: Filter[AnnotatedSegmentsColumns],
filter: Filter[SpanColumns],
page: Optional[int] = None,
page_size: Optional[int] = None,
sorts: List[Sort[AnnotatedSegmentsColumns]],
sorts: List[Sort[SpanColumns]],
authz_user: AuthzUser = Depends(),
) -> AnnotatedSegmentResult:
authz_user.assert_in_project(project_id)
Expand All @@ -144,14 +148,14 @@ def annotated_segments(

@router.post(
"/annotated_images_info",
response_model=List[ColumnInfo[AnnotatedImagesColumns]],
response_model=List[ColumnInfo[BBoxColumns]],
summary="Returns AnnotationSegments Info.",
)
def annotated_images_info(
*,
project_id: int,
authz_user: AuthzUser = Depends(),
) -> List[ColumnInfo[AnnotatedImagesColumns]]:
) -> List[ColumnInfo[BBoxColumns]]:
authz_user.assert_in_project(project_id)
return find_annotated_images_info(
project_id=project_id,
Expand All @@ -167,10 +171,10 @@ def annotated_images(
*,
project_id: int,
user_id: int,
filter: Filter[AnnotatedImagesColumns],
filter: Filter[BBoxColumns],
page: Optional[int] = None,
page_size: Optional[int] = None,
sorts: List[Sort[AnnotatedImagesColumns]],
sorts: List[Sort[BBoxColumns]],
authz_user: AuthzUser = Depends(),
) -> AnnotatedImageResult:
authz_user.assert_in_project(project_id)
Expand Down
9 changes: 5 additions & 4 deletions backend/src/api/endpoints/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from api.dependencies import get_current_user, get_db_session
from api.validation import Validate
from app.core.analysis.memo import MemoColumns, memo_info, memo_search
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.crud.memo import crud_memo
Expand All @@ -19,9 +18,11 @@
)
from app.core.data.dto.search import PaginatedElasticSearchDocumentHits
from app.core.data.orm.util import get_parent_project_id
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.memo_search.memo_search import memo_info, memo_search
from app.core.search.memo_search.memo_search_columns import MemoColumns
from app.core.search.sorting import Sort

router = APIRouter(
prefix="/memo", dependencies=[Depends(get_current_user)], tags=["memo"]
Expand Down
2 changes: 1 addition & 1 deletion backend/src/api/endpoints/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from app.core.data.dto.project_metadata import ProjectMetadataRead
from app.core.data.dto.user import UserRead
from app.core.data.orm.source_document import SourceDocumentORM
from app.core.search.elasticsearch_service import ElasticSearchService
from app.core.db.elasticsearch_service import ElasticSearchService
from app.preprocessing.preprocessing_service import PreprocessingService

router = APIRouter(
Expand Down
93 changes: 53 additions & 40 deletions backend/src/api/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,46 @@
from typing import List, Optional
from typing import List, Optional, Union

from fastapi import APIRouter, Depends

import app.core.search.sdoc_search.sdoc_search as sdoc_search
from api.dependencies import get_current_user
from app.core.analysis.statistics import (
compute_code_statistics,
compute_keyword_statistics,
compute_tag_statistics,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.dto.search import (
PaginatedElasticSearchDocumentHits,
SearchColumns,
SimSearchImageHit,
SimSearchQuery,
SimSearchSentenceHit,
)
from app.core.data.dto.search_stats import KeywordStat, SpanEntityStat, TagStat
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.filters.sorting import Sort
from app.core.search.elasticsearch_service import ElasticSearchService
from app.core.search.search_service import SearchService
from app.core.db.elasticsearch_service import ElasticSearchService
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter
from app.core.search.sdoc_search.sdoc_search_columns import SdocColumns
from app.core.search.sorting import Sort

router = APIRouter(
prefix="/search", dependencies=[Depends(get_current_user)], tags=["search"]
)

ss = SearchService()
es = ElasticSearchService()


@router.post(
"/sdoc_info",
response_model=List[ColumnInfo[SearchColumns]],
response_model=List[ColumnInfo[SdocColumns]],
summary="Returns Search Info.",
)
def search_sdocs_info(
*, project_id: int, authz_user: AuthzUser = Depends()
) -> List[ColumnInfo[SearchColumns]]:
) -> List[ColumnInfo[SdocColumns]]:
authz_user.assert_in_project(project_id)

return SearchService().search_info(project_id=project_id)
return sdoc_search.search_info(project_id=project_id)


@router.post(
Expand All @@ -50,15 +53,15 @@ def search_sdocs(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
highlight: bool,
page_number: Optional[int] = None,
page_size: Optional[int] = None,
authz_user: AuthzUser = Depends(),
) -> PaginatedElasticSearchDocumentHits:
authz_user.assert_in_project(project_id)
return SearchService().search(
return sdoc_search.search(
search_query=search_query,
expert_mode=expert_mode,
highlight=highlight,
Expand All @@ -85,12 +88,12 @@ def search_code_stats(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[SpanEntityStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -104,9 +107,7 @@ def search_code_stats(

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
code_stats = compute_code_statistics(code_id=code_id, sdoc_ids=set(sdoc_ids))
if sort_by_global:
code_stats.sort(key=lambda x: x.global_count, reverse=True)
return code_stats
Expand All @@ -131,9 +132,7 @@ def filter_code_stats(

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
code_stats = compute_code_statistics(code_id=code_id, sdoc_ids=set(sdoc_ids))
if sort_by_global:
code_stats.sort(key=lambda x: x.global_count, reverse=True)
return code_stats
Expand All @@ -154,12 +153,12 @@ def search_keyword_stats(
# search params
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[KeywordStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -172,7 +171,7 @@ def search_keyword_stats(
return []

# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
keyword_stats = compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
if sort_by_global:
Expand All @@ -199,7 +198,7 @@ def filter_keyword_stats(
return []

# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
keyword_stats = compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
if sort_by_global:
Expand All @@ -221,12 +220,12 @@ def search_tag_stats(
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
filter: Filter[SdocColumns],
sorts: List[Sort[SdocColumns]],
) -> List[TagStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
search_result = sdoc_search.search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
Expand All @@ -239,7 +238,7 @@ def search_tag_stats(
return []

# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
tag_stats = compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
return tag_stats
Expand All @@ -262,7 +261,7 @@ def filter_tag_stats(
return []

# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
tag_stats = compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
return tag_stats
Expand All @@ -274,11 +273,18 @@ def filter_tag_stats(
summary="Returns similar sentences according to a textual or visual query.",
)
def find_similar_sentences(
query: SimSearchQuery, authz_user: AuthzUser = Depends()
proj_id: int,
query: Union[str, List[str], int],
top_k: int,
threshold: float,
filter: Filter[SdocColumns],
authz_user: AuthzUser = Depends(),
) -> List[SimSearchSentenceHit]:
authz_user.assert_in_project(query.proj_id)
authz_user.assert_in_project(proj_id)

return ss.find_similar_sentences(query=query)
return sdoc_search.find_similar_sentences(
proj_id=proj_id, query=query, top_k=top_k, threshold=threshold, filter=filter
)


@router.post(
Expand All @@ -287,8 +293,15 @@ def find_similar_sentences(
summary="Returns similar images according to a textual or visual query.",
)
def find_similar_images(
query: SimSearchQuery, authz_user: AuthzUser = Depends()
proj_id: int,
query: Union[str, List[str], int],
top_k: int,
threshold: float,
filter: Filter[SdocColumns],
authz_user: AuthzUser = Depends(),
) -> List[SimSearchImageHit]:
authz_user.assert_in_project(query.proj_id)
authz_user.assert_in_project(proj_id)

return ss.find_similar_images(query=query)
return sdoc_search.find_similar_images(
proj_id=proj_id, query=query, top_k=top_k, threshold=threshold, filter=filter
)
10 changes: 6 additions & 4 deletions backend/src/api/endpoints/timeline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from sqlalchemy.orm import Session

from api.dependencies import get_current_user, get_db_session
from app.core.analysis.timeline import (
TimelineAnalysisColumns,
from app.core.analysis.timeline_analysis.timeline import (
timeline_analysis,
timeline_analysis_info,
)
from app.core.analysis.timeline_analysis.timeline_analysis_columns import (
TimelineAnalysisColumns,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.crud.timeline_analysis import crud_timeline_analysis
Expand All @@ -19,8 +21,8 @@
TimelineAnalysisRead,
TimelineAnalysisUpdate,
)
from app.core.filters.columns import ColumnInfo
from app.core.filters.filtering import Filter
from app.core.search.column_info import ColumnInfo
from app.core.search.filtering import Filter

router = APIRouter(
prefix="/timelineAnalysis",
Expand Down
Loading

0 comments on commit 4712931

Please sign in to comment.