Skip to content

Commit

Permalink
format code with ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
joshdavham committed Dec 7, 2024
1 parent 59e4894 commit a14b3bf
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 50 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import setuptools

setuptools.setup()
setuptools.setup()
2 changes: 1 addition & 1 deletion src/sm_2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
The classic SM-2 algorithm for spaced repetition scheduling, implemented as a python package.
"""

from .sm_2 import Scheduler, Card, ReviewLog
from .sm_2 import Scheduler, Card, ReviewLog
106 changes: 64 additions & 42 deletions src/sm_2/sm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from copy import deepcopy
from math import ceil


class Card:
"""
Represents a flashcard in the SM-2 scheduling system.
Expand All @@ -34,8 +35,15 @@ class Card:
due: datetime
needs_extra_review: bool

def __init__(self, card_id: int | None = None, n: int=0, EF: float=2.5, I: int=0, due: datetime | None = None, needs_extra_review: bool=False) -> None:

def __init__(
self,
card_id: int | None = None,
n: int = 0,
EF: float = 2.5,
I: int = 0,
due: datetime | None = None,
needs_extra_review: bool = False,
) -> None:
if card_id is None:
# epoch miliseconds of when the card was created
card_id = int(datetime.now(timezone.utc).timestamp() * 1000)
Expand All @@ -52,29 +60,34 @@ def __init__(self, card_id: int | None = None, n: int=0, EF: float=2.5, I: int=0
self.needs_extra_review = needs_extra_review

def to_dict(self) -> dict[str, int | float | str | bool]:

return_dict: dict[str, int | float | str | bool] = {
"card_id": self.card_id,
"n": self.n,
"EF": self.EF,
"I": self.I,
"due": self.due.isoformat(),
"needs_extra_review": self.needs_extra_review
"needs_extra_review": self.needs_extra_review,
}

return return_dict

@staticmethod
def from_dict(source_dict: dict[str, Any]) -> "Card":

card_id = int(source_dict['card_id'])
n = int(source_dict['n'])
EF = float(source_dict['EF'])
I = int(source_dict['I'])
due = datetime.fromisoformat(source_dict['due'])
needs_extra_review = bool(source_dict['needs_extra_review'])

return Card(card_id=card_id, n=n, EF=EF, I=I, due=due, needs_extra_review=needs_extra_review)
card_id = int(source_dict["card_id"])
n = int(source_dict["n"])
EF = float(source_dict["EF"])
I = int(source_dict["I"])
due = datetime.fromisoformat(source_dict["due"])
needs_extra_review = bool(source_dict["needs_extra_review"])

return Card(
card_id=card_id,
n=n,
EF=EF,
I=I,
due=due,
needs_extra_review=needs_extra_review,
)


class ReviewLog:
Expand All @@ -93,33 +106,41 @@ class ReviewLog:
review_datetime: datetime
review_duration: int | None

def __init__(self, card: Card, rating: int, review_datetime: datetime, review_duration: int | None = None) -> None:

def __init__(
self,
card: Card,
rating: int,
review_datetime: datetime,
review_duration: int | None = None,
) -> None:
self.card = deepcopy(card)
self.rating = rating
self.review_datetime = review_datetime
self.review_duration = review_duration

def to_dict(self) -> dict[str, dict | int | str | None]:

return_dict: dict[str, dict | int | str | None] = {
"card": self.card.to_dict(),
"rating": self.rating,
"review_datetime": self.review_datetime.isoformat(),
"review_duration": self.review_duration
"review_duration": self.review_duration,
}

return return_dict

@staticmethod
def from_dict(source_dict: dict[str, Any]) -> "ReviewLog":
card = Card.from_dict(source_dict["card"])
rating = int(source_dict["rating"])
review_datetime = datetime.fromisoformat(source_dict["review_datetime"])
review_duration = source_dict["review_duration"]

card = Card.from_dict(source_dict['card'])
rating = int(source_dict['rating'])
review_datetime = datetime.fromisoformat(source_dict['review_datetime'])
review_duration = source_dict['review_duration']

return ReviewLog(card=card, rating=rating, review_datetime=review_datetime, review_duration=review_duration)
return ReviewLog(
card=card,
rating=rating,
review_datetime=review_datetime,
review_duration=review_duration,
)


class Scheduler:
Expand All @@ -133,7 +154,12 @@ class Scheduler:
"""

@staticmethod
def review_card(card: Card, rating: int, review_datetime: datetime | None = None, review_duration: int | None = None) -> tuple[Card, ReviewLog]:
def review_card(
card: Card,
rating: int,
review_datetime: datetime | None = None,
review_duration: int | None = None,
) -> tuple[Card, ReviewLog]:
"""
Reviews a card with a given rating at a specified time.
Expand All @@ -159,51 +185,47 @@ def review_card(card: Card, rating: int, review_datetime: datetime | None = None

if not card_is_due:
raise RuntimeError(f"Card is not due for review until {card.due}.")

review_log = ReviewLog(card=card, rating=rating, review_datetime=review_datetime, review_duration=review_duration)

if card.needs_extra_review:
review_log = ReviewLog(
card=card,
rating=rating,
review_datetime=review_datetime,
review_duration=review_duration,
)

if card.needs_extra_review:
if rating >= 4:
card.needs_extra_review = False
card.due += timedelta(days=card.I)

else:

if rating >= 3: # correct response

if rating >= 3: # correct response
# note: EF increases when rating = 5, stays the same when rating = 4 and decreases when rating = 3
card.EF = card.EF + (0.1-(5-rating)*(0.08+(5-rating)*0.02))
card.EF = card.EF + (0.1 - (5 - rating) * (0.08 + (5 - rating) * 0.02))
card.EF = max(1.3, card.EF)

if card.n == 0:

if card.n == 0:
card.I = 1

elif card.n == 1:

card.I = 6

else:

card.I = ceil(card.I * card.EF)

card.n += 1

if rating >= 4:

card.due += timedelta(days=card.I)

else:

card.needs_extra_review = True
card.due = review_datetime

else: # incorrect response

else: # incorrect response
card.n = 0
card.I = 0
card.due = review_datetime
# EF doesn't change on incorrect reponses

return card, review_log
return card, review_log
14 changes: 8 additions & 6 deletions tests/test_sm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from copy import deepcopy
from datetime import datetime, timezone

class TestSM2:

class TestSM2:
def test_quickstart(self):

scheduler = Scheduler()

card = Card()
Expand All @@ -28,7 +27,6 @@ def test_quickstart(self):
assert round(time_delta.seconds / 3600) == 24

def test_serialize(self):

scheduler = Scheduler()

card = Card()
Expand All @@ -46,19 +44,23 @@ def test_serialize(self):
# review the card and perform more tests
rating = 5
review_duration = 3000
card, review_log = scheduler.review_card(card=card, rating=rating, review_duration=review_duration)
card, review_log = scheduler.review_card(
card=card, rating=rating, review_duration=review_duration
)

review_log_dict = review_log.to_dict()
copied_review_log = ReviewLog.from_dict(review_log_dict)
assert review_log.to_dict() == copied_review_log.to_dict()
assert copied_review_log.review_duration == review_duration
# can use the review log to recreate the card that was reviewed
assert old_card.to_dict() == Card.from_dict(review_log.to_dict()['card']).to_dict()
assert (
old_card.to_dict() == Card.from_dict(review_log.to_dict()["card"]).to_dict()
)

# the new reviewed card can be serialized and de-serialized while remaining the same
card_dict = card.to_dict()
copied_card = Card.from_dict(card_dict)
assert vars(card) == vars(copied_card)
assert card.to_dict() == copied_card.to_dict()

# TODO: add tests for interval lengths
# TODO: add tests for interval lengths

0 comments on commit a14b3bf

Please sign in to comment.