diff --git a/common/models/voucher.py b/common/models/voucher.py index bff755df..a3485fd2 100644 --- a/common/models/voucher.py +++ b/common/models/voucher.py @@ -1,16 +1,19 @@ import uuid from sqlalchemy import Column, Text, Integer, ForeignKey, UUID, LargeBinary -from sqlalchemy.orm import relationship, backref +from sqlalchemy.orm import relationship, backref, Mapped from common.models.base import Base, TimeStampMixin, AutoIdMixin +from common.models.product import Product +from common.models.receipt import Receipt from common.utils.receipt import PlanetID class VoucherRequest(AutoIdMixin, TimeStampMixin, Base): __tablename__ = "voucher_request" receipt_id = Column(Integer, ForeignKey("receipt.id"), nullable=False) - receipt = relationship("Receipt", foreign_keys=[receipt_id], uselist=False, backref=backref("voucher_request")) + receipt: Mapped["Receipt"] = relationship("Receipt", foreign_keys=[receipt_id], uselist=False, + backref=backref("voucher_request")) # Copy all required data to view all info solely with this table uuid = Column(UUID(as_uuid=True), nullable=False, index=True, default=uuid.uuid4, @@ -20,7 +23,7 @@ class VoucherRequest(AutoIdMixin, TimeStampMixin, Base): planet_id = Column(LargeBinary(length=12), nullable=False, default=PlanetID.ODIN.value, doc="An identifier of planets") product_id = Column(Integer, ForeignKey("product.id"), nullable=False) - product = relationship("Product", foreign_keys=[product_id], uselist=False) + product: Mapped["Product"] = relationship("Product", foreign_keys=[product_id], uselist=False) product_name = Column(Text, nullable=False) # Voucher request result diff --git a/iap/api/purchase.py b/iap/api/purchase.py index 17ca9a7b..1ddc627d 100644 --- a/iap/api/purchase.py +++ b/iap/api/purchase.py @@ -167,6 +167,7 @@ def request_product(receipt_data: ReceiptSchema, sess=Depends(session)): MessageBody=json.dumps({ "id": receipt.id, "uuid": str(receipt.uuid), + "receipt_id": receipt.id, "product_id": receipt.product_id, "product_name": receipt.product.name, "agent_addr": receipt.agent_addr, diff --git a/worker/worker/voucher.py b/worker/worker/voucher.py index 1913da8a..c1822028 100644 --- a/worker/worker/voucher.py +++ b/worker/worker/voucher.py @@ -9,6 +9,7 @@ from common import logger from common.models.voucher import VoucherRequest from common.utils.aws import fetch_secrets, fetch_parameter +from common.utils.receipt import PlanetID from schemas.aws import SQSMessage DB_URI = os.environ.get("DB_URI") @@ -66,7 +67,8 @@ def handle(event, context): message = SQSMessage(Records=event.get("Records", {})) logger.info(f"SQS Message: {message}") - with scoped_session(sessionmaker(bind=engine)) as sess: + sess = scoped_session(sessionmaker(bind=engine)) + try: uuid_list = [x.body.get("uuid") for x in message.Records if x.body.get("uuid")] voucher_list = sess.scalars(select(VoucherRequest.uuid).where(VoucherRequest.uuid.in_(uuid_list))).fetchall() target_message_list = [x.body for x in message.Records if @@ -74,7 +76,11 @@ def handle(event, context): for msg in target_message_list: voucher = VoucherRequest(**msg) + voucher.planet_id = PlanetID(voucher.planet_id.encode()) sess.add(voucher) sess.commit() sess.refresh(voucher) request(sess, voucher) + finally: + if sess is not None: + sess.close()