Skip to content

Commit 9d71bcb

Browse files
bare sqlalchemy session + tests (#3522)
* add bare sqlalchemy session, Closes #3512 * expose sqla_session at module level, add tests, improve typing * fix table name * add model_registry fixture, improve typing * did not meant to push this * add docstring to model_registry * do not expose sqla_session in reflex namespace
1 parent e4c17de commit 9d71bcb

File tree

5 files changed

+221
-22
lines changed

5 files changed

+221
-22
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ venv
1212
requirements.txt
1313
.pyi_generator_last_run
1414
.pyi_generator_diff
15+
reflex.db

reflex/model.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from reflex.utils.compat import sqlmodel
2525

2626

27-
def get_engine(url: str | None = None):
27+
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
2828
"""Get the database engine.
2929
3030
Args:
@@ -396,7 +396,7 @@ def select(cls):
396396

397397

398398
def session(url: str | None = None) -> sqlmodel.Session:
399-
"""Get a session to interact with the database.
399+
"""Get a sqlmodel session to interact with the database.
400400
401401
Args:
402402
url: The database url.
@@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
405405
A database session.
406406
"""
407407
return sqlmodel.Session(get_engine(url))
408+
409+
410+
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
411+
"""Get a bare sqlalchemy session to interact with the database.
412+
413+
Args:
414+
url: The database url.
415+
416+
Returns:
417+
A database session.
418+
"""
419+
return sqlalchemy.orm.Session(get_engine(url))

tests/conftest.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import platform
66
import uuid
77
from pathlib import Path
8-
from typing import Dict, Generator
8+
from typing import Dict, Generator, Type
99
from unittest import mock
1010

1111
import pytest
1212

1313
from reflex.app import App
1414
from reflex.event import EventSpec
15+
from reflex.model import ModelRegistry
1516
from reflex.utils import prerequisites
1617

1718
from .states import (
@@ -247,3 +248,14 @@ def token() -> str:
247248
A fresh/unique token string.
248249
"""
249250
return str(uuid.uuid4())
251+
252+
253+
@pytest.fixture
254+
def model_registry() -> Generator[Type[ModelRegistry], None, None]:
255+
"""Create a model registry.
256+
257+
Yields:
258+
A fresh model registry.
259+
"""
260+
yield ModelRegistry
261+
ModelRegistry._metadata = None

tests/test_model.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional
1+
from pathlib import Path
2+
from typing import Optional, Type
23
from unittest import mock
34

45
import pytest
@@ -7,7 +8,7 @@
78

89
import reflex.constants
910
import reflex.model
10-
from reflex.model import Model
11+
from reflex.model import Model, ModelRegistry
1112

1213

1314
@pytest.fixture
@@ -39,7 +40,7 @@ class ChildModel(Model):
3940
return ChildModel(name="name")
4041

4142

42-
def test_default_primary_key(model_default_primary):
43+
def test_default_primary_key(model_default_primary: Model):
4344
"""Test that if a primary key is not defined a default is added.
4445
4546
Args:
@@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
4849
assert "id" in model_default_primary.__class__.__fields__
4950

5051

51-
def test_custom_primary_key(model_custom_primary):
52+
def test_custom_primary_key(model_custom_primary: Model):
5253
"""Test that if a primary key is defined no default key is added.
5354
5455
Args:
@@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
6061
@pytest.mark.filterwarnings(
6162
"ignore:This declarative base already contains a class with the same class name",
6263
)
63-
def test_automigration(tmp_working_dir, monkeypatch):
64+
def test_automigration(
65+
tmp_working_dir: Path,
66+
monkeypatch: pytest.MonkeyPatch,
67+
model_registry: Type[ModelRegistry],
68+
):
6469
"""Test alembic automigration with add and drop table and column.
6570
6671
Args:
6772
tmp_working_dir: directory where database and migrations are stored
6873
monkeypatch: pytest fixture to overwrite attributes
74+
model_registry: clean reflex ModelRegistry
6975
"""
7076
alembic_ini = tmp_working_dir / "alembic.ini"
7177
versions = tmp_working_dir / "alembic" / "versions"
@@ -84,8 +90,10 @@ class AlembicThing(Model, table=True): # type: ignore
8490
t1: str
8591

8692
with Model.get_db_engine().connect() as connection:
87-
Model.alembic_autogenerate(connection=connection, message="Initial Revision")
88-
Model.migrate()
93+
assert Model.alembic_autogenerate(
94+
connection=connection, message="Initial Revision"
95+
)
96+
assert Model.migrate()
8997
version_scripts = list(versions.glob("*.py"))
9098
assert len(version_scripts) == 1
9199
assert version_scripts[0].name.endswith("initial_revision.py")
@@ -94,14 +102,14 @@ class AlembicThing(Model, table=True): # type: ignore
94102
session.add(AlembicThing(id=None, t1="foo"))
95103
session.commit()
96104

97-
sqlmodel.SQLModel.metadata.clear()
105+
model_registry.get_metadata().clear()
98106

99107
# Create column t2, mark t1 as optional with default
100108
class AlembicThing(Model, table=True): # type: ignore
101109
t1: Optional[str] = "default"
102110
t2: str = "bar"
103111

104-
Model.migrate(autogenerate=True)
112+
assert Model.migrate(autogenerate=True)
105113
assert len(list(versions.glob("*.py"))) == 2
106114

107115
with reflex.model.session() as session:
@@ -114,13 +122,13 @@ class AlembicThing(Model, table=True): # type: ignore
114122
assert result[1].t1 == "default"
115123
assert result[1].t2 == "baz"
116124

117-
sqlmodel.SQLModel.metadata.clear()
125+
model_registry.get_metadata().clear()
118126

119127
# Drop column t1
120128
class AlembicThing(Model, table=True): # type: ignore
121129
t2: str = "bar"
122130

123-
Model.migrate(autogenerate=True)
131+
assert Model.migrate(autogenerate=True)
124132
assert len(list(versions.glob("*.py"))) == 3
125133

126134
with reflex.model.session() as session:
@@ -134,7 +142,7 @@ class AlembicSecond(Model, table=True): # type: ignore
134142
a: int = 42
135143
b: float = 4.2
136144

137-
Model.migrate(autogenerate=True)
145+
assert Model.migrate(autogenerate=True)
138146
assert len(list(versions.glob("*.py"))) == 4
139147

140148
with reflex.model.session() as session:
@@ -146,16 +154,16 @@ class AlembicSecond(Model, table=True): # type: ignore
146154
assert result[0].b == 4.2
147155

148156
# No-op
149-
Model.migrate(autogenerate=True)
157+
assert Model.migrate(autogenerate=True)
150158
assert len(list(versions.glob("*.py"))) == 4
151159

152160
# drop table (AlembicSecond)
153-
sqlmodel.SQLModel.metadata.clear()
161+
model_registry.get_metadata().clear()
154162

155163
class AlembicThing(Model, table=True): # type: ignore
156164
t2: str = "bar"
157165

158-
Model.migrate(autogenerate=True)
166+
assert Model.migrate(autogenerate=True)
159167
assert len(list(versions.glob("*.py"))) == 5
160168

161169
with reflex.model.session() as session:
@@ -168,18 +176,18 @@ class AlembicThing(Model, table=True): # type: ignore
168176
assert result[0].t2 == "bar"
169177
assert result[1].t2 == "baz"
170178

171-
sqlmodel.SQLModel.metadata.clear()
179+
model_registry.get_metadata().clear()
172180

173181
class AlembicThing(Model, table=True): # type: ignore
174182
# changing column type not supported by default
175183
t2: int = 42
176184

177-
Model.migrate(autogenerate=True)
185+
assert Model.migrate(autogenerate=True)
178186
assert len(list(versions.glob("*.py"))) == 5
179187

180188
# clear all metadata to avoid influencing subsequent tests
181-
sqlmodel.SQLModel.metadata.clear()
189+
model_registry.get_metadata().clear()
182190

183191
# drop remaining tables
184-
Model.migrate(autogenerate=True)
192+
assert Model.migrate(autogenerate=True)
185193
assert len(list(versions.glob("*.py"))) == 6

tests/test_sqlalchemy.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from pathlib import Path
2+
from typing import Optional, Type
3+
from unittest import mock
4+
5+
import pytest
6+
from sqlalchemy import select
7+
from sqlalchemy.exc import OperationalError
8+
from sqlalchemy.orm import (
9+
DeclarativeBase,
10+
Mapped,
11+
MappedAsDataclass,
12+
declared_attr,
13+
mapped_column,
14+
)
15+
16+
import reflex.constants
17+
import reflex.model
18+
from reflex.model import Model, ModelRegistry, sqla_session
19+
20+
21+
@pytest.mark.filterwarnings(
22+
"ignore:This declarative base already contains a class with the same class name",
23+
)
24+
def test_automigration(
25+
tmp_working_dir: Path,
26+
monkeypatch: pytest.MonkeyPatch,
27+
model_registry: Type[ModelRegistry],
28+
):
29+
"""Test alembic automigration with add and drop table and column.
30+
31+
Args:
32+
tmp_working_dir: directory where database and migrations are stored
33+
monkeypatch: pytest fixture to overwrite attributes
34+
model_registry: clean reflex ModelRegistry
35+
"""
36+
alembic_ini = tmp_working_dir / "alembic.ini"
37+
versions = tmp_working_dir / "alembic" / "versions"
38+
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
39+
40+
config_mock = mock.Mock()
41+
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
42+
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
43+
44+
assert alembic_ini.exists() is False
45+
assert versions.exists() is False
46+
Model.alembic_init()
47+
assert alembic_ini.exists()
48+
assert versions.exists()
49+
50+
class Base(DeclarativeBase):
51+
@declared_attr.directive
52+
def __tablename__(cls) -> str:
53+
return cls.__name__.lower()
54+
55+
assert model_registry.register(Base)
56+
57+
class ModelBase(Base, MappedAsDataclass):
58+
__abstract__ = True
59+
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
60+
61+
# initial table
62+
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
63+
t1: Mapped[str] = mapped_column(default="")
64+
65+
with Model.get_db_engine().connect() as connection:
66+
assert Model.alembic_autogenerate(
67+
connection=connection, message="Initial Revision"
68+
)
69+
assert Model.migrate()
70+
version_scripts = list(versions.glob("*.py"))
71+
assert len(version_scripts) == 1
72+
assert version_scripts[0].name.endswith("initial_revision.py")
73+
74+
with sqla_session() as session:
75+
session.add(AlembicThing(t1="foo"))
76+
session.commit()
77+
78+
model_registry.get_metadata().clear()
79+
80+
# Create column t2, mark t1 as optional with default
81+
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
82+
t1: Mapped[Optional[str]] = mapped_column(default="default")
83+
t2: Mapped[str] = mapped_column(default="bar")
84+
85+
assert Model.migrate(autogenerate=True)
86+
assert len(list(versions.glob("*.py"))) == 2
87+
88+
with sqla_session() as session:
89+
session.add(AlembicThing(t2="baz"))
90+
session.commit()
91+
result = session.scalars(select(AlembicThing)).all()
92+
assert len(result) == 2
93+
assert result[0].t1 == "foo"
94+
assert result[0].t2 == "bar"
95+
assert result[1].t1 == "default"
96+
assert result[1].t2 == "baz"
97+
98+
model_registry.get_metadata().clear()
99+
100+
# Drop column t1
101+
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
102+
t2: Mapped[str] = mapped_column(default="bar")
103+
104+
assert Model.migrate(autogenerate=True)
105+
assert len(list(versions.glob("*.py"))) == 3
106+
107+
with sqla_session() as session:
108+
result = session.scalars(select(AlembicThing)).all()
109+
assert len(result) == 2
110+
assert result[0].t2 == "bar"
111+
assert result[1].t2 == "baz"
112+
113+
# Add table
114+
class AlembicSecond(ModelBase):
115+
a: Mapped[int] = mapped_column(default=42)
116+
b: Mapped[float] = mapped_column(default=4.2)
117+
118+
assert Model.migrate(autogenerate=True)
119+
assert len(list(versions.glob("*.py"))) == 4
120+
121+
with reflex.model.session() as session:
122+
session.add(AlembicSecond(id=None))
123+
session.commit()
124+
result = session.scalars(select(AlembicSecond)).all()
125+
assert len(result) == 1
126+
assert result[0].a == 42
127+
assert result[0].b == 4.2
128+
129+
# No-op
130+
# assert Model.migrate(autogenerate=True)
131+
# assert len(list(versions.glob("*.py"))) == 4
132+
133+
# drop table (AlembicSecond)
134+
model_registry.get_metadata().clear()
135+
136+
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
137+
t2: Mapped[str] = mapped_column(default="bar")
138+
139+
assert Model.migrate(autogenerate=True)
140+
assert len(list(versions.glob("*.py"))) == 5
141+
142+
with reflex.model.session() as session:
143+
with pytest.raises(OperationalError) as errctx:
144+
_ = session.scalars(select(AlembicSecond)).all()
145+
assert errctx.match(r"no such table: alembicsecond")
146+
# first table should still exist
147+
result = session.scalars(select(AlembicThing)).all()
148+
assert len(result) == 2
149+
assert result[0].t2 == "bar"
150+
assert result[1].t2 == "baz"
151+
152+
model_registry.get_metadata().clear()
153+
154+
class AlembicThing(ModelBase):
155+
# changing column type not supported by default
156+
t2: Mapped[int] = mapped_column(default=42)
157+
158+
assert Model.migrate(autogenerate=True)
159+
assert len(list(versions.glob("*.py"))) == 5
160+
161+
# clear all metadata to avoid influencing subsequent tests
162+
model_registry.get_metadata().clear()
163+
164+
# drop remaining tables
165+
assert Model.migrate(autogenerate=True)
166+
assert len(list(versions.glob("*.py"))) == 6

0 commit comments

Comments
 (0)