Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CVAT] Adapt exchange/recording oracles for honeypots #2720

Merged
merged 83 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
eee13bc
Update exchange oracle
Marishka17 Oct 29, 2024
4fa8d23
Add wait timeout when importing GT annotations
Marishka17 Oct 29, 2024
3ce5d69
Extract a base class for task creation
zhiltsov-max Oct 30, 2024
6795109
Move gt setup into the base class
zhiltsov-max Oct 30, 2024
621af5b
Add draft implementation for points task creation
zhiltsov-max Oct 30, 2024
e8727db
[draft] Update recording oracle
Marishka17 Oct 30, 2024
cb28f92
Upgrade cvat-sdk dep
zhiltsov-max Oct 31, 2024
f6daa82
[Exchnage oracle] apply some comments && small fixes
Marishka17 Oct 31, 2024
8f00430
Refactor some code, fix errors
zhiltsov-max Oct 31, 2024
619ae54
Add quality settings setup
zhiltsov-max Oct 31, 2024
c1e5e93
Fix update quality settings
zhiltsov-max Oct 31, 2024
05b5442
Use inbound bbox circle radius for point validation
zhiltsov-max Oct 31, 2024
d2f3b0b
Merge remote-tracking branch 'upstream/mk/update_cvat_oracles' into z…
zhiltsov-max Oct 31, 2024
0428f03
Fix linter errors
zhiltsov-max Oct 31, 2024
8ef59b4
Merge pull request #2730 from humanprotocol/zm/change_point_validation
zhiltsov-max Nov 1, 2024
e145b7d
Fix quality settings update call
zhiltsov-max Nov 1, 2024
aa3d642
Move common function to the base class
zhiltsov-max Nov 1, 2024
d42f0b6
Expect point group annotations in points annotation task
zhiltsov-max Nov 1, 2024
a4b6505
Fix response check
zhiltsov-max Nov 1, 2024
f0258c4
Improve formatting in the log message
zhiltsov-max Nov 1, 2024
a3fbb84
Use single shape mode for image_points task
zhiltsov-max Nov 1, 2024
19202c7
Refactor recording oracle updates
Marishka17 Nov 1, 2024
4ae5928
[Recording oracle] update deps
Marishka17 Nov 1, 2024
df99be5
Update assignment urls for skeleton tasks
zhiltsov-max Nov 1, 2024
2b9b5fa
Merge pull request #2743 from humanprotocol/zm/change_point_task_stat…
zhiltsov-max Nov 1, 2024
3b70e40
Rename quality parameter
zhiltsov-max Nov 1, 2024
2f62b70
Fix linter error
zhiltsov-max Nov 1, 2024
ebdb341
[Ex oracle] Move gt dataset preparation into separate method for skel…
Marishka17 Nov 3, 2024
79a7c0a
[Ex oracle] update deps
Marishka17 Nov 3, 2024
93284c9
Resolve conflicts
Marishka17 Nov 3, 2024
6e286f5
[Ex oracle] Improve handling oracle mode(dev/prod/test)
Marishka17 Nov 4, 2024
30f84dc
[Ex oracle] Fix test
Marishka17 Nov 4, 2024
24bbc5a
[Recording oracle] Apply comments && small fixes && remove unused code
Marishka17 Nov 4, 2024
95a80c1
[Exchnage oracle] Pass job start/stop frame from Ex oracle to Rec oracle
Marishka17 Nov 6, 2024
916cb0d
Update recording oracle
Marishka17 Nov 6, 2024
c00825c
t
Marishka17 Nov 6, 2024
5bc041f
[Exchange oracle] Fix tests
Marishka17 Nov 7, 2024
6bca715
[Exchange oracle] Add migration
Marishka17 Nov 7, 2024
75db488
Fix test
Marishka17 Nov 7, 2024
41bbcf5
[Exchange oracle] mark job start/stop frame as not nullable
Marishka17 Nov 7, 2024
d70a63b
Fix tests
Marishka17 Nov 7, 2024
a96d1e7
[Recording oracle] Clean up the code
Marishka17 Nov 7, 2024
8853208
Update packages/examples/cvat/exchange-oracle/src/handlers/job_creati…
Marishka17 Nov 7, 2024
1963855
Merge develop
Marishka17 Nov 7, 2024
a933ea6
[Recording oracle] Apply comments
Marishka17 Nov 7, 2024
7189d81
[Exchnage oracle] use_bbox_size_for_points -> point_size_base
Marishka17 Nov 7, 2024
34bbdf3
[Rec oracle] Use MT19937 generator
Marishka17 Nov 8, 2024
3f536a5
[Exchange oracle] Move BoxesFromPointsTaskBuilder::_prepare_gt_roi_da…
Marishka17 Nov 8, 2024
1b6e13a
[Exchange oracle] Fix checking which files should be uploaded to the …
Marishka17 Nov 8, 2024
a5399bd
[Exchange oracle] Update down_revision
Marishka17 Nov 8, 2024
f4c5900
fix typo
Marishka17 Nov 8, 2024
2e71542
[Exchange oracle] Include val_size into chunk_size
Marishka17 Nov 8, 2024
69fc2e3
Fix some errors
zhiltsov-max Nov 8, 2024
2d1af93
Enable empty frame matching
zhiltsov-max Nov 8, 2024
78daf11
Merge remote-tracking branch 'upstream/mk/update_cvat_oracles' into m…
zhiltsov-max Nov 8, 2024
b0603dc
Use the added parameter
zhiltsov-max Nov 8, 2024
8d56fc8
[Recording orcale] Fix get_task_quality_report
Marishka17 Nov 8, 2024
e4ce52d
[Ex oracle] Fix missing segment_size
Marishka17 Nov 11, 2024
fa5cf7a
[Rec oracle] Small fixes
Marishka17 Nov 11, 2024
0fec721
Update packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py
Marishka17 Nov 11, 2024
907235c
[Exchange oracle] Bump cvat-sdk version
Marishka17 Nov 11, 2024
3847298
Fix roi GT dataset in boxes_from_points tasks
zhiltsov-max Nov 11, 2024
eb83293
Add more clever default for sort_images
zhiltsov-max Nov 11, 2024
6a94f45
Fix linter error, remove gt image data callback
zhiltsov-max Nov 11, 2024
a543c24
Fix type annotation
zhiltsov-max Nov 11, 2024
1f43589
Fix quality settings for skeletons_from_boxes
zhiltsov-max Nov 11, 2024
06166f9
Fix GT datasets for points in skeletons_from_boxes
zhiltsov-max Nov 11, 2024
e67625a
Simplify roi info id
zhiltsov-max Nov 11, 2024
481f251
Remove unused field from skeleton roi info
zhiltsov-max Nov 12, 2024
5a08e4d
Allow optional joints in skeleton task manifest
zhiltsov-max Nov 12, 2024
ba9228f
Add points task meta
zhiltsov-max Nov 12, 2024
16c67b0
Basic fix for merged dataset annotations
zhiltsov-max Nov 12, 2024
6e07890
Basic fix for premature escrow validation requests
zhiltsov-max Nov 12, 2024
203ddee
Fix incorrect GT preparation in boxes_from_points
zhiltsov-max Nov 12, 2024
dd43bfa
Refactor some code
zhiltsov-max Nov 12, 2024
a097c3e
Fix linter problem
zhiltsov-max Nov 12, 2024
e2cd19e
Use the original GT for final annotation merging
zhiltsov-max Nov 13, 2024
5f4a616
Resolve conflicts
Marishka17 Nov 13, 2024
2d70afb
Fix dataset merging for the points task
zhiltsov-max Nov 13, 2024
4d7683d
Update comment
zhiltsov-max Nov 13, 2024
7122809
[Exchange Oracle] Move cvat timeout settings to cvat config, update .…
zhiltsov-max Nov 13, 2024
3746126
[Recording Oracle] Update .env template, add some variables
zhiltsov-max Nov 13, 2024
5865287
Update packages/examples/cvat/exchange-oracle/src/core/config.py
Marishka17 Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion packages/examples/cvat/exchange-oracle/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,24 @@ class CronConfig:


class CvatConfig:
# TODO: it looks odd to use cvat_ prefix in class attributes inside CvatConfig
cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080")
cvat_admin = os.environ.get("CVAT_ADMIN", "admin")
cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin")
cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "")

cvat_job_overlap = int(os.environ.get("CVAT_JOB_OVERLAP", 0))
cvat_job_segment_size = int(os.environ.get("CVAT_JOB_SEGMENT_SIZE", 150))
cvat_task_segment_size = int(os.environ.get("CVAT_TASK_SEGMENT_SIZE", 150))
cvat_default_image_quality = int(os.environ.get("CVAT_DEFAULT_IMAGE_QUALITY", 70))
cvat_max_jobs_per_task = int(os.environ.get("CVAT_MAX_JOBS_PER_TASK", 10 * 1000))

# quality control settings
cvat_val_frames_per_job_count = int(os.environ.get("CVAT_VAL_FRAMES_PER_JOB_COUNT", 2))
cvat_max_validation_checks = int(os.environ.get("CVAT_MAX_VALIDATION_CHECKS", 3))
cvat_iou_threshold = float(os.environ.get("CVAT_IOU_THRESHOLD", 0.5))
cvat_low_overlap_threshold = float(os.environ.get("CVAT_LOW_OVERLAP_THRESHOLD", 0.8))
cvat_target_metric_threshold = cvat_low_overlap_threshold
cvat_oks_sigma = float(os.environ.get("CVAT_OKS_SIGMA", 0.1))

cvat_incoming_webhooks_url = os.environ.get("CVAT_INCOMING_WEBHOOKS_URL")
cvat_webhook_secret = os.environ.get("CVAT_WEBHOOK_SECRET", "thisisasamplesecret")
Expand Down Expand Up @@ -223,6 +233,9 @@ class FeaturesConfig:
default_export_timeout = int(os.environ.get("DEFAULT_EXPORT_TIMEOUT", 60))
"Timeout, in seconds, for annotations or dataset export waiting"

default_import_timeout = int(os.environ.get("DEFAULT_IMPORT_TIMEOUT", 60))
"Timeout, in seconds, for waiting on GT annotations import"

request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0"))
"Allow to log request details for each request"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
@frozen(kw_only=True)
class RoiInfo:
original_image_key: int
original_image_id: str
bbox_id: int
bbox_x: int
bbox_y: int
Expand Down
235 changes: 219 additions & 16 deletions packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import timedelta
from datetime import datetime, timedelta, timezone
from enum import Enum
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
from time import sleep
from typing import Any

from cvat_sdk import Client, make_client
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.uploading import AnnotationUploader

from src.core.config import Config
from src.utils.enums import BetterEnumMeta
Expand Down Expand Up @@ -122,6 +125,16 @@ def get_api_client() -> ApiClient:
return api_client


def get_sdk_client() -> Client:
client = make_client(
host=Config.cvat_config.cvat_url,
credentials=(Config.cvat_config.cvat_admin, Config.cvat_config.cvat_admin_pass),
)
client.organization_slug = Config.cvat_config.cvat_org_slug

return client


def create_cloudstorage(
provider: str,
bucket_name: str,
Expand Down Expand Up @@ -297,14 +310,19 @@ def create_cvat_webhook(project_id: int) -> models.WebhookRead:
raise


def create_task(project_id: int, name: str) -> models.TaskRead:
def create_task(
project_id: int,
name: str,
*,
segment_size: int = Config.cvat_config.cvat_task_segment_size,
) -> models.TaskRead:
logger = logging.getLogger("app")
with get_api_client() as api_client:
task_write_request = models.TaskWriteRequest(
name=name,
project_id=project_id,
overlap=0,
segment_size=Config.cvat_config.cvat_job_segment_size,
segment_size=segment_size,
)
try:
(task_info, response) = api_client.tasks_api.create(task_write_request)
Expand Down Expand Up @@ -335,8 +353,14 @@ def put_task_data(
*,
filenames: list[str] | None = None,
sort_images: bool = True,
validation_params: dict[str, str | float | list[str]] | None = None,
) -> None:
logger = logging.getLogger("app")
sorting_method = (
models.SortingMethod("lexicographical")
if sort_images
else models.SortingMethod("predefined")
)

with get_api_client() as api_client:
kwargs = {}
Expand All @@ -345,21 +369,42 @@ def put_task_data(
else:
kwargs["filename_pattern"] = "*"

if validation_params:
logger.info(
f"The {sorting_method} is ignored."
'Only "random" sorting can be used when validation parameters passed.'
)
sorting_method = models.SortingMethod("random")

gt_filenames = validation_params["gt_filenames"]
if missed_filenames := set(gt_filenames) - set(filenames):
filenames.extend(missed_filenames)

kwargs["validation_params"] = models.DataRequestValidationParams(
mode=models.ValidationMode("gt_pool"),
frames=gt_filenames,
frame_selection_method=models.FrameSelectionMethod("manual"),
frames_per_job_count=validation_params.get(
"gt_frames_per_job_count",
Config.cvat_config.cvat_val_frames_per_job_count,
),
)

data_request = models.DataRequest(
chunk_size=Config.cvat_config.cvat_job_segment_size,
chunk_size=Config.cvat_config.cvat_task_segment_size,
cloud_storage_id=cloudstorage_id,
image_quality=Config.cvat_config.cvat_default_image_quality,
use_cache=True,
use_zip_chunks=True,
sorting_method="lexicographical" if sort_images else "predefined",
sorting_method=sorting_method,
**kwargs,
)
try:
(_, response) = api_client.tasks_api.create_data(task_id, data_request=data_request)
return

except exceptions.ApiException as e:
logger.exception(f"Exception when calling ProjectsApi.put_task_data: {e}\n")
logger.exception(f"Exception when calling tasks_api.create_data: {e}\n")
raise


Expand Down Expand Up @@ -563,36 +608,194 @@ def clear_job_annotations(job_id: int) -> None:
raise


def update_job_assignee(id: str, assignee_id: int | None):
def setup_gt_job(task_id: int, dataset_path: Path, format_name: str) -> None:
gt_job = get_gt_job(task_id)
upload_gt_annotations(gt_job.id, dataset_path, format_name=format_name)
finish_gt_job(gt_job.id)
settings = get_quality_control_settings(task_id)
update_quality_control_settings(settings.id)


def get_gt_job(task_id: int) -> models.JobRead:
logger = logging.getLogger("app")

with get_api_client() as api_client:
try:
api_client.jobs_api.partial_update(
id=id,
patched_job_write_request=models.PatchedJobWriteRequest(assignee=assignee_id),
(paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="ground_truth")
assert (
len(paginated_jobs["results"]) == 1
), f'CVAT returned {len(paginated_jobs["results"])} GT jobs'
return paginated_jobs["results"][0]
except (exceptions.ApiException, AssertionError) as ex:
logger.exception(f"Exception when calling JobsApi.list(): {ex}\n")
raise


def upload_gt_annotations(
job_id: int,
dataset_path: Path,
*,
format_name: str,
sleep_interval: int = 5,
timeout: int | None = Config.features.default_import_timeout,
) -> None:
# FUTURE-TODO: use job.import_annotations when CVAT will support waiting timeout
start_time = datetime.now(timezone.utc)
logger = logging.getLogger("app")

with get_sdk_client() as client:
uploader = AnnotationUploader(client)
url = client.api_map.make_endpoint_url(
client.api_client.jobs_api.create_annotations_endpoint.path, kwsub={"id": job_id}
)

try:
response = uploader.upload_file(
url,
dataset_path,
query_params={"format": format_name, "filename": dataset_path.name},
meta={"filename": dataset_path.name},
)
except Exception as ex:
logger.exception(f"Exception occurred while importing GT annotations: {ex}\n")
raise

request_id = json.loads(response.data).get("rq_id")
assert request_id, "CVAT server have not returned rq_id in the response."

while True:
try:
(request_details, _) = client.api_client.requests_api.retrieve(request_id)
except exceptions.ApiException as ex:
logger.exception(f"Exception occurred while importing GT annotations: {ex}\n")
raise

if (
request_details.status.value
== models.RequestStatus.allowed_values[("value",)]["FINISHED"]
):
break

if (
request_details.status.value
== models.RequestStatus.allowed_values[("value",)]["FAILED"]
):
raise Exception(
"Annotations upload failed. "
f"Previous status was: {request_details.status.value}."
)

if timeout is not None and timedelta(seconds=timeout) < (utcnow() - start_time):
raise Exception(
"Failed to upload the GT annotations to CVAT within the timeout interval. "
f"Previous status was: {request_details.status.value}. "
f"Timeout: {timeout} seconds."
)

sleep(sleep_interval)

logger.info(f"GT annotations for the job {job_id} have been uploaded to CVAT.")


def get_quality_control_settings(task_id: int) -> models.QualitySettings:
logger = logging.getLogger("app")

with get_api_client() as api_client:
try:
paginated_data, _ = api_client.quality_api.list_settings(task_id=task_id)
assert len(paginated_data["results"]) == 1, (
f'CVAT returned {len(paginated_data["results"])}'
"quality control settings associated with the task"
)
return paginated_data["results"][0]

except (exceptions.ApiException, AssertionError) as e:
logger.exception(f"Exception when calling QualityApi.list_settings(): {e}\n")
raise


def update_quality_control_settings(
settings_id: int,
*,
max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks,
target_metric: str = "accuracy",
target_metric_threshold: float = Config.cvat_config.cvat_target_metric_threshold,
low_overlap_threshold: float = Config.cvat_config.cvat_low_overlap_threshold,
iou_threshold: float = Config.cvat_config.cvat_iou_threshold,
oks_sigma: float = Config.cvat_config.cvat_oks_sigma,
) -> None:
logger = logging.getLogger("app")

with get_api_client() as api_client:
try:
api_client.quality_api.partial_update_settings(
settings_id,
patched_quality_settings_request=models.PatchedQualitySettingsRequest(
max_validations_per_job=max_validations_per_job,
target_metric=target_metric,
target_metric_threshold=target_metric_threshold,
iou_threshold=iou_threshold,
low_overlap_threshold=low_overlap_threshold,
oks_sigma=oks_sigma,
),
)
except exceptions.ApiException as e:
logger.exception(f"Exception when calling JobsApi.partial_update(): {e}\n")
logger.exception(f"Exception when calling QualityApi.partial_update_settings(): {e}\n")
raise


def restart_job(id: str, *, assignee_id: int | None = None):
def _update_job(
job_id: int,
*,
assignee_id: int | None | object = _NOTSET,
stage: models.JobStage | None = None,
state: models.OperationStatus | None = None,
) -> None:
to_update = {
attr: value
for attr, value in {
"stage": stage,
"state": state,
}.items()
if value
}

if assignee_id is not _NOTSET:
to_update["assignee"] = assignee_id

assert to_update

logger = logging.getLogger("app")

with get_api_client() as api_client:
try:
api_client.jobs_api.partial_update(
id=id,
patched_job_write_request=models.PatchedJobWriteRequest(
stage="annotation", state="new", assignee=assignee_id
),
job_id, patched_job_write_request=models.PatchedJobWriteRequest(**to_update)
)
except exceptions.ApiException as e:
logger.exception(f"Exception when calling JobsApi.partial_update(): {e}\n")
raise


def update_job_assignee(id: int, assignee_id: int | None):
_update_job(id, assignee_id=assignee_id)


def restart_job(id: str, *, assignee_id: int | None = None):
_update_job(
id,
stage=models.JobStage("annotation"),
state=models.OperationStatus("new"),
assignee_id=assignee_id,
)


def finish_gt_job(job_id: int) -> None:
_update_job(
job_id, stage=models.JobStage("acceptance"), state=models.OperationStatus("completed")
)


def get_user_id(user_email: str) -> int:
logger = logging.getLogger("app")

Expand Down
Loading
Loading