Skip to content

Commit

Permalink
[FEATURE] SDK - Add support for response status (#4977)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR adds support for providing the status of record responses. This is
quite important for admin users when restoring annotated data from
external sources. The migration script for legacy datasets can also
benefit from this change.

When several statuses are found for the same user before sending data to
the server, a warning is shown to the users and the `draft` status is
selected.



**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jun 12, 2024
1 parent aa2cbe6 commit 279ff4e
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 35 deletions.
28 changes: 14 additions & 14 deletions argilla-sdk/src/argilla_sdk/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ def __init__(self, responses: List[Response], record: Record) -> None:
response.record = self.record
self.__responses_by_question_name[response.question_name].append(response)

def __iter__(self):
    """Return an iterator over the responses stored in ``self.__responses``."""
    return iter(self.__responses)

def __getitem__(self, index: int):
    """Return the response stored at position ``index`` of ``self.__responses``."""
    return self.__responses[index]

def __getattr__(self, name) -> List[Response]:
    """Return the responses registered under ``name`` in the per-question mapping.

    Invoked only when normal attribute lookup fails, which enables attribute-style
    access by question name (e.g. ``record.responses.label``).
    """
    # NOTE(review): the mapping appears to be a defaultdict (__init__ appends to entries
    # without creating them), so unknown names would yield an empty list - confirm.
    return self.__responses_by_question_name[name]

def __repr__(self) -> str:
    """Return the repr of a dict mapping each ``to_dict()`` key to its list of ``{"value": ...}`` entries."""
    # Only the "value" entry of each response dict is kept; any other key is dropped.
    return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()

def api_models(self) -> List[UserResponseModel]:
"""Returns a list of ResponseModel objects."""

Expand All @@ -305,19 +317,10 @@ def api_models(self) -> List[UserResponseModel]:
responses_by_user_id[response.user_id].append(response)

return [
UserResponse(user_id=user_id, answers=responses, _record=self.record).api_model()
for user_id, responses in responses_by_user_id.items()
UserResponse(answers=responses, _record=self.record).api_model()
for responses in responses_by_user_id.values()
]

def __iter__(self):
return iter(self.__responses)

def __getitem__(self, index: int):
return self.__responses[index]

def __getattr__(self, name) -> List[Response]:
return self.__responses_by_question_name[name]

def to_dict(self) -> Dict[str, List[Dict]]:
"""Converts the responses to a dictionary.
Returns:
Expand All @@ -328,9 +331,6 @@ def to_dict(self) -> Dict[str, List[Dict]]:
response_dict[response.question_name].append({"value": response.value, "user_id": response.user_id})
return response_dict

def __repr__(self) -> str:
return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()


class RecordSuggestions(Iterable[Suggestion]):
"""This is a container class for the suggestions of a Record.
Expand Down
63 changes: 46 additions & 17 deletions argilla-sdk/src/argilla_sdk/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, TYPE_CHECKING, List, Dict, Optional, Iterable
import warnings
from enum import Enum
from typing import Any, TYPE_CHECKING, List, Dict, Optional, Iterable, Union
from uuid import UUID

from argilla_sdk._models import UserResponseModel, ResponseStatus
from argilla_sdk._models import UserResponseModel, ResponseStatus as ResponseStatusModel
from argilla_sdk._resource import Resource
from argilla_sdk.settings import RankingQuestion

if TYPE_CHECKING:
from argilla_sdk import Argilla, Dataset, Record

__all__ = ["Response", "UserResponse"]
__all__ = ["Response", "UserResponse", "ResponseStatus"]


class ResponseStatus(str, Enum):
    """Enum for the status of a response.

    Members also subclass ``str``, so each one compares equal to its plain string
    value (e.g. ``ResponseStatus.draft == "draft"``).
    """

    draft = "draft"
    submitted = "submitted"
    discarded = "discarded"


class Response:
"""Class for interacting with Argilla Responses of records. Responses are answers to questions by a user.
Therefore, a recod question can have multiple responses, one for each user that has answered the question.
Therefore, a record question can have multiple responses, one for each user that has answered the question.
A `Response` is typically created by a user in the UI or consumed from a data source as a label,
unlike a `Suggestion` which is typically created by a model prediction.
Expand All @@ -38,6 +47,7 @@ def __init__(
question_name: str,
value: Any,
user_id: UUID,
status: Optional[Union[ResponseStatus, str]] = None,
_record: Optional["Record"] = None,
) -> None:
"""Initializes a `Response` for a `Record` with a user_id and value"""
Expand All @@ -49,10 +59,14 @@ def __init__(
if user_id is None:
raise ValueError("user_id is required")

if isinstance(status, str):
status = ResponseStatus(status)

self.record = _record
self.question_name = question_name
self.value = value
self.user_id = user_id
self.status = status

def serialize(self) -> dict[str, Any]:
"""Serializes the Response to a dictionary. This is principally used for sending the response to the API, \
Expand All @@ -72,6 +86,7 @@ def serialize(self) -> dict[str, Any]:
"question_name": self.question_name,
"value": self.value,
"user_id": self.user_id,
"status": self.status,
}

#####################
Expand All @@ -89,19 +104,17 @@ class UserResponse(Resource):
collected from the server or when creating new records.
Attributes:
status (ResponseStatus): The status of the UserResponse (draft, submitted, etc.)
user_id (UUID): The user_id of the UserResponse (the user who answered the questions)
answers (List[Response]): A list of responses to questions for the user
"""

answers: List[Response]

_model: UserResponseModel

def __init__(
self,
user_id: UUID,
answers: List[Response],
status: ResponseStatus = "draft",
client: Optional["Argilla"] = None,
_record: Optional["Record"] = None,
) -> None:
Expand All @@ -112,8 +125,8 @@ def __init__(
self._record = _record
self._model = UserResponseModel(
values=self.__responses_as_model_values(answers),
status=status,
user_id=user_id,
status=self._compute_status_from_answers(answers),
user_id=self._compute_user_id_from_answers(answers),
)

def __iter__(self) -> Iterable[Response]:
Expand Down Expand Up @@ -148,13 +161,13 @@ def answers(self) -> List[Response]:
def from_model(cls, model: UserResponseModel, dataset: "Dataset") -> "UserResponse":
"""Creates a UserResponse from a ResponseModel"""
answers = cls.__model_as_response_list(model)

for answer in answers:
question = dataset.settings.question_by_name(answer.question_name)
# We need to adapt the ranking question value to the expected format
if isinstance(question, RankingQuestion):
answer.value = cls.__ranking_from_model_value(answer.value) # type: ignore

return cls(user_id=model.user_id, answers=answers, status=model.status)
return cls(answers=answers)

def api_model(self):
"""Returns the model that is used to interact with the API"""
Expand All @@ -171,16 +184,32 @@ def to_dict(self) -> Dict[str, Any]:
"""Returns the UserResponse as a dictionary"""
return self._model.model_dump()

def _compute_status_from_answers(self, answers: List[Response]) -> ResponseStatusModel:
    """Computes the status of the UserResponse from the responses.

    Args:
        answers: The responses to inspect. Answers whose ``status`` is ``None`` are ignored.

    Returns:
        The single status shared by the answers, or ``draft`` when no answer carries a
        status or when the answers disagree (a warning is emitted in the latter case).
    """
    statuses = {answer.status for answer in answers if answer.status is not None}

    if len(statuses) == 1:
        return ResponseStatusModel(next(iter(statuses)))

    if statuses:
        # Conflicting statuses cannot be represented by a single UserResponse: warn and fall back.
        # `.value` keeps the message readable ('draft' instead of the enum member repr).
        warnings.warn(f"Multiple status found in user answers. Using {ResponseStatus.draft.value!r} as default.")

    return ResponseStatusModel.draft

def _compute_user_id_from_answers(self, answers: List["Response"]) -> UUID:
    """Computes the user_id of the UserResponse from the responses.

    Args:
        answers: The responses to inspect; all of them must carry the same ``user_id``.

    Returns:
        The user_id shared by every answer.

    Raises:
        ValueError: If ``answers`` is empty, or if the answers carry more than one user_id.
    """
    user_ids = {answer.user_id for answer in answers}

    if not user_ids:
        # Without this guard, `next()` on the empty set would leak a bare StopIteration.
        raise ValueError("No answers provided: unable to compute the user_id.")
    if len(user_ids) > 1:
        raise ValueError("Multiple user_ids found in user answers.")

    return next(iter(user_ids))

@staticmethod
def __responses_as_model_values(answers: List[Response]) -> Dict[str, Dict[str, Any]]:
    """Maps each answered question name to its ``{"value": ...}`` payload."""
    values_by_question: Dict[str, Dict[str, Any]] = {}
    for response in answers:
        question_name = response.question_name
        # A later answer to the same question replaces the earlier one.
        values_by_question[question_name] = {"value": response.value}
    return values_by_question

@staticmethod
def __model_as_response_list(model: UserResponseModel) -> List[Response]:
"""Creates a list of Responses from a UserResponseModel"""
@classmethod
def __model_as_response_list(cls, model: UserResponseModel) -> List[Response]:
"""Creates a list of Responses from a UserResponseModel without changing the format of the values"""

return [
Response(question_name=question_name, value=value["value"], user_id=model.user_id)
Response(question_name=question_name, value=value["value"], user_id=model.user_id, status=model.status)
for question_name, value in model.values.items()
]

Expand Down
19 changes: 15 additions & 4 deletions argilla-sdk/tests/integration/test_add_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import uuid
from datetime import datetime


import argilla_sdk as rg
from argilla_sdk import Argilla

Expand Down Expand Up @@ -584,8 +583,8 @@ def test_add_records_objects_with_responses(client: Argilla):
rg.TextField(name="text"),
],
questions=[
rg.TextQuestion(name="comment", use_markdown=False),
rg.LabelQuestion(name="label", labels=["positive", "negative"]),
rg.TextQuestion(name="comment", use_markdown=False, required=False),
],
)
dataset = rg.Dataset(
Expand All @@ -605,12 +604,17 @@ def test_add_records_objects_with_responses(client: Argilla):
records = [
rg.Record(
fields={"text": "Hello World, how are you?"},
responses=[rg.Response("label", "negative", user_id=user.id)],
responses=[rg.Response("label", "negative", user_id=user.id, status="submitted")],
id=str(uuid.uuid4()),
),
rg.Record(
fields={"text": "Hello World, how are you?"},
responses=[rg.Response("label", "positive", user_id=user.id)],
responses=[rg.Response("label", "positive", user_id=user.id, status="discarded")],
id=str(uuid.uuid4()),
),
rg.Record(
fields={"text": "Hello World, how are you?"},
responses=[rg.Response("comment", "The comment", user_id=user.id, status="draft")],
id=str(uuid.uuid4()),
),
rg.Record(
Expand All @@ -627,9 +631,16 @@ def test_add_records_objects_with_responses(client: Argilla):
assert dataset.name == mock_dataset_name
assert dataset_records[0].id == records[0].id
assert dataset_records[0].responses.label[0].value == "negative"
assert dataset_records[0].responses.label[0].status == "submitted"

assert dataset_records[1].id == records[1].id
assert dataset_records[1].responses.label[0].value == "positive"
assert dataset_records[1].responses.label[0].status == "discarded"

assert dataset_records[2].id == records[2].id
assert dataset_records[2].responses.comment[0].value == "The comment"
assert dataset_records[2].responses.comment[0].status == "draft"

assert dataset_records[3].id == records[3].id
assert dataset_records[3].responses.comment[0].value == "The comment"
assert dataset_records[3].responses.comment[0].status == "draft"
87 changes: 87 additions & 0 deletions argilla-sdk/tests/unit/test_resources/test_responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid

import pytest

from argilla_sdk import UserResponse, Response


class TestResponses:
    """Unit tests for how `UserResponse` derives its status and user_id from its answers."""

    def test_create_user_response(self):
        """Answers without an explicit status produce a `draft` user response."""
        user_id = uuid.uuid4()
        response = UserResponse(
            answers=[
                Response(question_name="question", value="answer", user_id=user_id),
                Response(question_name="other-question", value="answer", user_id=user_id),
            ],
        )

        assert response.to_dict() == {
            "values": {
                "question": {"value": "answer"},
                "other-question": {"value": "answer"},
            },
            "status": "draft",
            "user_id": str(user_id),
        }

    def test_create_submitted_user_responses(self):
        """Answers sharing the same status propagate it to the user response."""
        user_id = uuid.uuid4()
        response = UserResponse(
            answers=[
                Response(question_name="question", value="answer", user_id=user_id, status="submitted"),
                Response(question_name="other-question", value="answer", user_id=user_id, status="submitted"),
            ],
        )

        assert response.to_dict() == {
            "values": {
                "question": {"value": "answer"},
                "other-question": {"value": "answer"},
            },
            "status": "submitted",
            "user_id": str(user_id),
        }

    def test_create_user_response_with_multiple_status(self):
        """Conflicting statuses fall back to `draft` and emit a warning."""
        user_id = uuid.uuid4()

        # The fallback is only acceptable if the user is told about it, so assert the warning too.
        with pytest.warns(UserWarning, match="Multiple status found in user answers"):
            response = UserResponse(
                answers=[
                    Response(question_name="question", value="answer", user_id=user_id, status="draft"),
                    Response(question_name="other-question", value="answer", user_id=user_id, status="submitted"),
                ],
            )

        assert response.to_dict() == {
            "values": {
                "question": {"value": "answer"},
                "other-question": {"value": "answer"},
            },
            "status": "draft",
            "user_id": str(user_id),
        }

    def test_create_user_response_with_multiple_user_id(self):
        """Answers from different users cannot be merged into one user response."""
        user_id = uuid.uuid4()
        other_user_id = uuid.uuid4()

        with pytest.raises(ValueError, match="Multiple user_ids found in user answers"):
            UserResponse(
                answers=[
                    Response(question_name="question", value="answer", user_id=user_id),
                    Response(question_name="other-question", value="answer", user_id=other_user_id),
                ],
            )

0 comments on commit 279ff4e

Please sign in to comment.