Skip to content

Commit 2d9fb0c

Browse files
author
Jesse Whitehouse
committed
Black SQLAlchemy tests
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 74c2eb0 commit 2d9fb0c

File tree

1 file changed

+66
-40
lines changed

1 file changed

+66
-40
lines changed

tests/e2e/sqlalchemy/test_basic.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
USER_AGENT_TOKEN = "PySQL e2e Tests"
1010

11+
1112
@pytest.fixture
1213
def db_engine():
1314

@@ -19,7 +20,10 @@ def db_engine():
1920

2021
connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}
2122

22-
engine = create_engine(f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}", connect_args=connect_args)
23+
engine = create_engine(
24+
f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}",
25+
connect_args=connect_args,
26+
)
2327
return engine
2428

2529

@@ -32,6 +36,7 @@ def base(db_engine):
3236
def session(db_engine):
3337
return Session(bind=db_engine)
3438

39+
3540
@pytest.fixture()
3641
def metadata_obj(db_engine):
3742
return MetaData(bind=db_engine)
@@ -42,6 +47,7 @@ def test_can_connect(db_engine):
4247
result = db_engine.execute(simple_query).fetchall()
4348
assert len(result) == 1
4449

50+
4551
def test_connect_args(db_engine):
4652
"""Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI
4753
@@ -63,7 +69,14 @@ def test_pandas_upload(db_engine, metadata_obj):
6369
SCHEMA = os.environ.get("schema")
6470
try:
6571
df = pd.read_excel("tests/sqlalchemy/demo_data/MOCK_DATA.xlsx")
66-
df.to_sql("mock_data", db_engine, schema=SCHEMA, index=False, method="multi", if_exists="replace")
72+
df.to_sql(
73+
"mock_data",
74+
db_engine,
75+
schema=SCHEMA,
76+
index=False,
77+
method="multi",
78+
if_exists="replace",
79+
)
6780

6881
df_after = pd.read_sql_table("mock_data", db_engine, schema=SCHEMA)
6982
assert len(df) == len(df_after)
@@ -72,21 +85,24 @@ def test_pandas_upload(db_engine, metadata_obj):
7285
finally:
7386
db_engine.execute("DROP TABLE mock_data")
7487

88+
7589
def test_create_table_not_null(db_engine, metadata_obj):
7690

7791
table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))
7892

7993
SampleTable = Table(
80-
table_name,
81-
metadata_obj,
82-
Column("name", String(255)),
83-
Column("episodes", Integer),
84-
Column("some_bool", BOOLEAN, nullable=False)
94+
table_name,
95+
metadata_obj,
96+
Column("name", String(255)),
97+
Column("episodes", Integer),
98+
Column("some_bool", BOOLEAN, nullable=False),
8599
)
86100

87101
metadata_obj.create_all()
88102

89-
columns = db_engine.dialect.get_columns(connection=db_engine.connect(), table_name=table_name)
103+
columns = db_engine.dialect.get_columns(
104+
connection=db_engine.connect(), table_name=table_name
105+
)
90106

91107
name_column_description = columns[0]
92108
some_bool_column_description = columns[2]
@@ -96,6 +112,7 @@ def test_create_table_not_null(db_engine, metadata_obj):
96112

97113
metadata_obj.drop_all()
98114

115+
99116
def test_bulk_insert_with_core(db_engine, metadata_obj, session):
100117

101118
import random
@@ -105,39 +122,41 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
105122
table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))
106123

107124
names = ["Bim", "Miki", "Sarah", "Ira"]
108-
125+
109126
SampleTable = Table(
110-
table_name,
111-
metadata_obj,
112-
Column("name", String(255)),
113-
Column("number", Integer)
127+
table_name, metadata_obj, Column("name", String(255)), Column("number", Integer)
114128
)
115129

116-
rows = [{"name": names[i%3], "number": random.choice(range(10000))} for i in range(num_to_insert)]
130+
rows = [
131+
{"name": names[i % 3], "number": random.choice(range(10000))}
132+
for i in range(num_to_insert)
133+
]
117134

118135
metadata_obj.create_all()
119136
db_engine.execute(insert(SampleTable).values(rows))
120137

121138
rows = db_engine.execute(select(SampleTable)).fetchall()
122139

123140
assert len(rows) == num_to_insert
124-
141+
142+
125143
def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
126-
"""
127-
"""
144+
""" """
128145

129146
SampleTable = Table(
130-
"PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")),
131-
metadata_obj,
132-
Column("name", String(255)),
133-
Column("episodes", Integer),
134-
Column("some_bool", BOOLEAN),
135-
Column("dollars", DECIMAL(10,2))
147+
"PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")),
148+
metadata_obj,
149+
Column("name", String(255)),
150+
Column("episodes", Integer),
151+
Column("some_bool", BOOLEAN),
152+
Column("dollars", DECIMAL(10, 2)),
136153
)
137154

138155
metadata_obj.create_all()
139156

140-
insert_stmt = insert(SampleTable).values(name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125))
157+
insert_stmt = insert(SampleTable).values(
158+
name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125)
159+
)
141160

142161
with db_engine.connect() as conn:
143162
conn.execute(insert_stmt)
@@ -156,9 +175,9 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
156175
# https://docs.sqlalchemy.org/en/14/orm/quickstart.html
157176

158177

159-
@skipIf(False, 'Unity catalog must be supported')
178+
@skipIf(False, "Unity catalog must be supported")
160179
def test_create_insert_drop_table_orm(base, session: Session):
161-
"""ORM classes built on the declarative base class must have a primary key.
180+
"""ORM classes built on the declarative base class must have a primary key.
162181
This is restricted to Unity Catalog.
163182
"""
164183

@@ -169,7 +188,7 @@ class SampleObject(base):
169188
name = Column(String(255), primary_key=True)
170189
episodes = Column(Integer)
171190
some_bool = Column(BOOLEAN)
172-
191+
173192
base.metadata.create_all()
174193

175194
sample_object_1 = SampleObject(name="Bim Adewunmi", episodes=6, some_bool=True)
@@ -178,7 +197,9 @@ class SampleObject(base):
178197
session.add(sample_object_2)
179198
session.commit()
180199

181-
stmt = select(SampleObject).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
200+
stmt = select(SampleObject).where(
201+
SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"])
202+
)
182203

183204
output = [i for i in session.scalars(stmt)]
184205
assert len(output) == 2
@@ -187,28 +208,33 @@ class SampleObject(base):
187208

188209

189210
def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
190-
"""Confirms that we get back the same time we declared in a model and inserted using Core
191-
"""
211+
"""Confirms that we get back the same time we declared in a model and inserted using Core"""
192212

193213
SampleTable = Table(
194-
"PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")),
195-
metadata_obj,
196-
Column("string_example", String(255)),
197-
Column("integer_example", Integer),
198-
Column("boolean_example", BOOLEAN),
199-
Column("decimal_example", DECIMAL(10,2)),
200-
Column("date_example", Date)
214+
"PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")),
215+
metadata_obj,
216+
Column("string_example", String(255)),
217+
Column("integer_example", Integer),
218+
Column("boolean_example", BOOLEAN),
219+
Column("decimal_example", DECIMAL(10, 2)),
220+
Column("date_example", Date),
201221
)
202222

203223
string_example = ""
204224
integer_example = 100
205225
boolean_example = True
206226
decimal_example = decimal.Decimal(125)
207-
date_example = datetime.date(2013,1,1)
227+
date_example = datetime.date(2013, 1, 1)
208228

209229
metadata_obj.create_all()
210230

211-
insert_stmt = insert(SampleTable).values(string_example=string_example, integer_example=integer_example, boolean_example=boolean_example, decimal_example=decimal_example, date_example=date_example)
231+
insert_stmt = insert(SampleTable).values(
232+
string_example=string_example,
233+
integer_example=integer_example,
234+
boolean_example=boolean_example,
235+
decimal_example=decimal_example,
236+
date_example=date_example,
237+
)
212238

213239
with db_engine.connect() as conn:
214240
conn.execute(insert_stmt)
@@ -225,4 +251,4 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
225251
assert this_row["decimal_example"] == decimal_example
226252
assert this_row["date_example"] == date_example
227253

228-
metadata_obj.drop_all()
254+
metadata_obj.drop_all()

0 commit comments

Comments
 (0)