Skip to content

Commit

Permalink
added scoped session
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason Davis committed Aug 7, 2020
1 parent 4a4a04b commit a838c7d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
13 changes: 6 additions & 7 deletions superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from sqlalchemy.orm import Session
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
from werkzeug.http import parse_cookie

Expand Down Expand Up @@ -550,7 +549,7 @@ def schedule_alert_query( # pylint: disable=unused-argument
model_cls = get_scheduler_model(report_type)

try:
schedule = db.session.query(model_cls).get(schedule_id)
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)

# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
Expand Down Expand Up @@ -584,7 +583,7 @@ class AlertState:


def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None:
alert = db.session.query(Alert).get(alert_id)
alert = db.create_scoped_session().query(Alert).get(alert_id)

logging.info("Triggering alert: %s", alert)
img_data = None
Expand Down Expand Up @@ -639,9 +638,9 @@ def run_alert_query(
"""
Execute alert.sql and return value if any rows are returned
"""

dbsession = db.create_scoped_session()
logger.info("Processing alert ID: %i", alert_id)
database = db.session.query(Database).get(database_id)
database = dbsession.query(Database).get(database_id)
if not database:
logger.error("Alert database not preset")
return None
Expand Down Expand Up @@ -679,8 +678,8 @@ def run_alert_query(
if not state:
state = AlertState.PASS

db.session.commit()
alert = db.session.query(Alert).get(alert_id)
dbsession.commit()
alert = dbsession.query(Alert).get(alert_id)
if state != AlertState.ERROR:
alert.last_eval_dttm = last_eval_dttm
alert.last_state = state
Expand Down
24 changes: 12 additions & 12 deletions tests/alerts_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,38 +84,38 @@ def setup_database():
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.logging.Logger.error")
def test_run_alert_query(mock_error, mock_deliver_alert, setup_database):
database = setup_database
alert1 = db.session.query(Alert).filter_by(id=1).one()
dbsession = setup_database
alert1 = dbsession.query(Alert).filter_by(id=1).one()
run_alert_query(alert1.id, alert1.database_id, alert1.sql, alert1.label)
alert1 = db.session.query(Alert).filter_by(id=1).one()
alert1 = dbsession.query(Alert).filter_by(id=1).one()
assert mock_deliver_alert.call_count == 0
assert len(alert1.logs) == 1
assert alert1.logs[0].alert_id == 1
assert alert1.logs[0].state == "pass"

alert2 = db.session.query(Alert).filter_by(id=2).one()
alert2 = dbsession.query(Alert).filter_by(id=2).one()
run_alert_query(alert2.id, alert2.database_id, alert2.sql, alert2.label)
alert2 = db.session.query(Alert).filter_by(id=2).one()
alert2 = dbsession.query(Alert).filter_by(id=2).one()
assert mock_deliver_alert.call_count == 1
assert len(alert2.logs) == 1
assert alert2.logs[0].alert_id == 2
assert alert2.logs[0].state == "trigger"

alert3 = db.session.query(Alert).filter_by(id=3).one()
alert3 = dbsession.query(Alert).filter_by(id=3).one()
run_alert_query(alert3.id, alert3.database_id, alert3.sql, alert3.label)
alert3 = db.session.query(Alert).filter_by(id=3).one()
alert3 = dbsession.query(Alert).filter_by(id=3).one()
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 2
assert len(alert3.logs) == 1
assert alert3.logs[0].alert_id == 3
assert alert3.logs[0].state == "error"

alert4 = db.session.query(Alert).filter_by(id=4).one()
alert4 = dbsession.query(Alert).filter_by(id=4).one()
run_alert_query(alert4.id, alert4.database_id, alert4.sql, alert4.label)
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 3

alert5 = db.session.query(Alert).filter_by(id=5).one()
alert5 = dbsession.query(Alert).filter_by(id=5).one()
run_alert_query(alert5.id, alert5.database_id, alert5.sql, alert5.label)
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 4
Expand All @@ -124,9 +124,9 @@ def test_run_alert_query(mock_error, mock_deliver_alert, setup_database):
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.run_alert_query")
def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database):
database = setup_database
active_alert = database.query(Alert).filter_by(id=1).one()
inactive_alert = database.query(Alert).filter_by(id=3).one()
dbsession = setup_database
active_alert = dbsession.query(Alert).filter_by(id=1).one()
inactive_alert = dbsession.query(Alert).filter_by(id=3).one()

# Test that inactive alerts are no processed
schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id)
Expand Down

0 comments on commit a838c7d

Please sign in to comment.