Add e2e tests #12


Merged · 9 commits · Jul 15, 2022
4 changes: 2 additions & 2 deletions .github/workflows/code-quality-checks.yml
@@ -1,7 +1,7 @@
name: Code Quality Checks
on: [push]
jobs:
  run-tests:
  run-unit-tests:
Collaborator
What does "unit test" mean in this context?

Is this a mocked unit test that doesn't need a Databricks account, or an e2e integration test that does require one?

Contributor Author

unit means no Databricks account is required.
e2e means a Databricks account is required.

Contributor Author

I'm updating the CONTRIBUTING doc with this info.

Collaborator

Thanks, Jesse.

    runs-on: ubuntu-latest
    steps:
      #----------------------------------------------
@@ -48,7 +48,7 @@ jobs:
      # run test suite
      #----------------------------------------------
      - name: Run tests
        run: poetry run pytest tests/
        run: poetry run python -m pytest tests/unit
  check-linting:
    runs-on: ubuntu-latest
    steps:
46 changes: 38 additions & 8 deletions CONTRIBUTING.md
@@ -9,34 +9,64 @@ This project uses [Poetry](https://python-poetry.org/) for dependency management
1. Clone this repository
2. Run `poetry install`

### Unit Tests
### Run tests

We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run pytest`, all other arguments are passed directly to `pytest`.
We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run python -m pytest`; all other arguments are passed directly to `pytest`.

#### Unit tests

Unit tests do not require a Databricks account.

#### All tests
```bash
poetry run pytest tests
poetry run python -m pytest tests/unit
```

#### Only a specific test file

```bash
poetry run pytest tests/tests.py
poetry run python -m pytest tests/unit/tests.py
```

#### Only a specific method

```bash
poetry run pytest tests/tests.py::ClientTestSuite::test_closing_connection_closes_commands
poetry run python -m pytest tests/unit/tests.py::ClientTestSuite::test_closing_connection_closes_commands
```

#### e2e Tests

End-to-end tests require a Databricks account. Before you can run them, you must set connection details for a Databricks SQL endpoint in your environment:

```bash
export host=""
export http_path=""
export access_token=""
```
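
For illustration, a minimal sketch of how these variables might be consumed, assuming the e2e fixtures read them verbatim from the environment and pass them to the connector's `connect()` entry point (this smoke query is illustrative, not the fixtures' actual code):

```python
import os

from databricks import sql

# Assumed wiring: the three exported variables map directly onto the
# connector's connection parameters.
connection = sql.connect(
    server_hostname=os.environ["host"],
    http_path=os.environ["http_path"],
    access_token=os.environ["access_token"],
)

cursor = connection.cursor()
cursor.execute("SELECT 1")
assert cursor.fetchone()[0] == 1

cursor.close()
connection.close()
```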

There are several e2e test suites available:
- `PySQLCoreTestSuite`
- `PySQLLargeQueriesSuite`
- `PySQLRetryTestSuite.HTTP503Suite` **[not documented]**
- `PySQLRetryTestSuite.HTTP429Suite` **[not documented]**
- `PySQLUnityCatalogTestSuite` **[not documented]**

To execute the core test suite:

```bash
poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite
```

The suites marked `[not documented]` require additional configuration which will be documented at a later time.

### Code formatting

This project uses [Black](https://pypi.org/project/black/).

```
poetry run black src
poetry run python3 -m black src --check
```

Remove the `--check` flag to write reformatted files to disk.

To simplify reviews, you can format your changes in a separate commit.

## Pull Request Process

1. Update the [CHANGELOG.md](README.md) or similar documentation with details of changes you wish to make, if applicable.
Empty file added tests/__init__.py
Empty file.
Empty file added tests/e2e/common/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/e2e/common/core_tests.py
@@ -0,0 +1,131 @@
import decimal
import datetime
from collections import namedtuple

TypeFailure = namedtuple(
    "TypeFailure", "query,columnType,resultType,resultValue,"
    "actualValue,actualType,description,conf")
ResultFailure = namedtuple(
    "ResultFailure", "query,columnType,resultType,resultValue,"
    "actualValue,actualType,description,conf")
ExecFailure = namedtuple(
    "ExecFailure", "query,columnType,resultType,resultValue,"
    "actualValue,actualType,description,conf,error")


class SmokeTestMixin:
    def test_smoke_test(self):
        with self.cursor() as cursor:
            cursor.execute("select 0")
            rows = cursor.fetchall()
            self.assertEqual(len(rows), 1)
            self.assertEqual(rows[0][0], 0)


class CoreTestMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class with the following extra attributes:
    validate_row_value_type: bool
    validate_result: bool
    """

    # A list of (subquery, column_type, python_type, expected_result)
    # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}"
    range_queries = [
        ("TRUE", 'boolean', bool, True),
        ("cast(1 AS TINYINT)", 'byte', int, 1),
        ("cast(1000 AS SMALLINT)", 'short', int, 1000),
        ("cast(100000 AS INTEGER)", 'integer', int, 100000),
        ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000),
        ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001),
        ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)),
        ("unhex('f000')", 'binary', bytes, b'\xf0\x00'),  # pyodbc internal mismatch
        ("'foo'", 'string', str, 'foo'),
        # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days"
        # ("interval 30 days", str, str, "interval 4 weeks 2 days"),
        # ("interval 3 days", str, str, "interval 3 days"),
        ("CAST(NULL AS DOUBLE)", 'double', type(None), None),
    ]

    # Full queries, only the first column of the first row is checked
    queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)]

    def run_tests_on_queries(self, default_conf):
        failures = []
        for (query, columnType, rowValueType, answer) in self.range_queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))
                failures.extend(
                    self.run_range_query(cursor, query, columnType, rowValueType, answer,
                                         default_conf))

        for (query, columnType, rowValueType, answer) in self.queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))

        if failures:
            self.fail("Failed testing result set with Arrow. "
                      "Failed queries: {}".format("\n\n".join([str(f) for f in failures])))

    def run_query(self, cursor, query, columnType, rowValueType, answer, conf):
        full_query = "SELECT {}".format(query)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            (result, ) = cursor.fetchone()
            if not all(cursor.description[0][1] == type for type in expected_column_types):
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_row_value_type and type(result) is not rowValueType:
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_result and str(answer) != str(result):
                return [
                    ResultFailure(full_query, expected_column_types, rowValueType, answer,
                                  result, type(result), cursor.description, conf)
                ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]

    def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf):
        full_query = "SELECT {}, id FROM RANGE({})".format(query, 5000)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            while True:
                rows = cursor.fetchmany(1000)
                if len(rows) <= 0:
                    break
                for index, (result, id) in enumerate(rows):
                    if not all(cursor.description[0][1] == type for type in expected_column_types):
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_row_value_type and type(result) is not rowValueType:
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_result and str(expected) != str(result):
                        return [
                            ResultFailure(full_query, expected_column_types, rowValueType, expected,
                                          result, type(result), cursor.description, conf)
                        ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]
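
The mixins in this file carry no connection logic of their own: `CoreTestMixin` calls `self.cursor()` and `self.expected_column_types()` and reads the two `validate_*` flags, all of which must come from the class it is mixed into. A minimal sketch of that wiring, with hypothetical names (`ExampleCoreSuite` and its trivial `expected_column_types` mapping are illustrative, not the PR's actual driver tests):

```python
import unittest

from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin


class ExampleCoreSuite(SmokeTestMixin, CoreTestMixin, unittest.TestCase):
    # Flags read by CoreTestMixin.run_query / run_range_query
    validate_row_value_type = True
    validate_result = True

    def cursor(self, extra_params=None):
        # Hypothetical: return a context-managed cursor opened against the
        # endpoint configured through host/http_path/access_token.
        raise NotImplementedError

    def expected_column_types(self, column_type):
        # Hypothetical: map a column type name to the values acceptable
        # in cursor.description for that column.
        return (column_type,)
```
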
48 changes: 48 additions & 0 deletions tests/e2e/common/decimal_tests.py
@@ -0,0 +1,48 @@
from decimal import Decimal

import pyarrow


class DecimalTestsMixin:
    decimal_and_expected_results = [
        ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
        ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
        ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
        # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False
        # ("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)),
        ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)),
        ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)),
        ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)),
        ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
    ]

    multi_decimals_and_expected_results = [
        (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
         [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)),
        (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"],
         [Decimal('1.000'), Decimal('2.000')], pyarrow.decimal128(6, 3)),
    ]

    def test_decimals(self):
        with self.cursor({}) as cursor:
            for (decimal, expected_value, expected_type) in self.decimal_and_expected_results:
                query = "SELECT CAST ({})".format(decimal)
                with self.subTest(query=query):
                    cursor.execute(query)
                    table = cursor.fetchmany_arrow(1)
                    self.assertEqual(table.field(0).type, expected_type)
                    self.assertEqual(table.to_pydict().popitem()[1][0], expected_value)

    def test_multi_decimals(self):
        with self.cursor({}) as cursor:
            for (decimals, expected_values,
                 expected_type) in self.multi_decimals_and_expected_results:
                union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
                query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str)

                with self.subTest(query=query):
                    cursor.execute(query)
                    table = cursor.fetchall_arrow()
                    self.assertEqual(table.field(0).type, expected_type)
                    self.assertEqual(table.to_pydict().popitem()[1], expected_values)
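
For context on the assertions above: `fetchall_arrow()` is assumed to return a `pyarrow.Table`, and the `to_pydict().popitem()[1]` idiom pulls the value list out of a single-column result. A standalone illustration with a plain in-memory table:

```python
import pyarrow

# Build a one-column table; to_pydict() maps column name -> list of values,
# so popitem() on a single-column table yields (name, values).
table = pyarrow.table({"col": [1, 2, 3]})
name, values = table.to_pydict().popitem()
assert name == "col"
assert values == [1, 2, 3]
assert table.field(0).type == pyarrow.int64()
```
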
100 changes: 100 additions & 0 deletions tests/e2e/common/large_queries_mixin.py
@@ -0,0 +1,100 @@
import logging
import math
import time

log = logging.getLogger(__name__)


class LargeQueriesMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class
    """

    def fetch_rows(self, cursor, row_count, fetchmany_size):
        """
        A generator for rows. Fetches until the end or up to 5 minutes.
        """
        # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone
        # in the Python client
        max_fetch_time = 5 * 60  # Fetch for at most 5 minutes

        rows = self.get_some_rows(cursor, fetchmany_size)
        start_time = time.time()
        n = 0
        while rows:
            for row in rows:
                n += 1
                yield row
                if time.time() - start_time >= max_fetch_time:
                    log.warning("Fetching rows timed out")
                    # Stop the generator entirely; a bare `break` would only
                    # exit the inner loop and fetching would continue.
                    return
            rows = self.get_some_rows(cursor, fetchmany_size)
        if not rows:
            # Read all the rows, row_count should match
            self.assertEqual(n, row_count)

        num_fetches = max(math.ceil(n / 10000), 1)
        latency_ms = int((time.time() - start_time) * 1000 / num_fetches)
        print('Fetched {} rows with an avg latency of {} ms per fetch, '.format(n, latency_ms) +
              'assuming 10K fetch size.')

    def test_query_with_large_wide_result_set(self):
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8192  # B
        rows = resultSize // width
        cols = width // 36

        # Set the fetchmany_size to get 10MB of data at a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 1000
        with self.cursor() as cursor:
            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
                self.assertEqual(len(row[1]), 36)

    def test_query_with_large_narrow_result_set(self):
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8  # sizeof(long)
        rows = resultSize // width

        # Set the fetchmany_size to get 10MB of data at a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 10000000
        with self.cursor() as cursor:
            cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)

    def test_long_running_query(self):
        """
        Incrementally increases the query size until it takes at least 5 minutes,
        and asserts that the query completes successfully.
        """
        minutes = 60
        min_duration = 5 * minutes

        duration = -1
        scale0 = 10000
        scale_factor = 1
        with self.cursor() as cursor:
            while duration < min_duration:
                self.assertLess(scale_factor, 512, msg="Detected infinite loop")
                start = time.time()

                cursor.execute("""SELECT count(*)
                        FROM RANGE({scale}) x
                        JOIN RANGE({scale0}) y
                        ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
                        """.format(scale=scale_factor * scale0, scale0=scale0))

                n, = cursor.fetchone()
                self.assertEqual(n, 0)

                duration = time.time() - start
                current_fraction = duration / min_duration
                print('Took {} s with scale factor={}'.format(duration, scale_factor))
                # Extrapolate linearly to reach 5 min and add 50% padding to push over the
                # limit, e.g. a 30 s run at scale_factor=1 gives current_fraction=0.1 and a
                # next scale_factor of ceil(1.5 * 1 / 0.1) = 15.
                scale_factor = math.ceil(1.5 * scale_factor / current_fraction)