diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml
index b48a2faf2..a6f44144c 100644
--- a/.github/workflows/code-quality-checks.yml
+++ b/.github/workflows/code-quality-checks.yml
@@ -1,7 +1,7 @@
 name: Code Quality Checks
 on: [push]
 jobs:
-  run-tests:
+  run-unit-tests:
     runs-on: ubuntu-latest
     steps:
       #----------------------------------------------
@@ -48,7 +48,7 @@ jobs:
       #   run test suite
       #----------------------------------------------
       - name: Run tests
-        run: poetry run pytest tests/
+        run: poetry run python -m pytest tests/unit
   check-linting:
     runs-on: ubuntu-latest
     steps:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index bd4886df2..cfc34a320 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -9,34 +9,64 @@ This project uses [Poetry](https://python-poetry.org/) for dependency management
 1. Clone this repository
 2. Run `poetry install`
 
-### Unit Tests
+### Run tests
 
-We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run pytest`, all other arguments are passed directly to `pytest`.
+We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run python -m pytest`; all other arguments are passed directly to `pytest`.
+
+#### Unit tests
+
+Unit tests do not require a Databricks account.
 
-#### All tests
 ```bash
-poetry run pytest tests
+poetry run python -m pytest tests/unit
 ```
-
 #### Only a specific test file
 
 ```bash
-poetry run pytest tests/tests.py
+poetry run python -m pytest tests/unit/tests.py
 ```
 
 #### Only a specific method
 
 ```bash
-poetry run pytest tests/tests.py::ClientTestSuite::test_closing_connection_closes_commands
+poetry run python -m pytest tests/unit/tests.py::ClientTestSuite::test_closing_connection_closes_commands
+```
+
+#### e2e tests
+
+End-to-end tests require a Databricks account. Before you can run them, you must set connection details for a Databricks SQL endpoint in your environment:
+
+```bash
+export host=""
+export http_path=""
+export access_token=""
 ```
+
+There are several e2e test suites available:
+- `PySQLCoreTestSuite`
+- `PySQLLargeQueriesSuite`
+- `PySQLRetryTestSuite.HTTP503Suite` **[not documented]**
+- `PySQLRetryTestSuite.HTTP429Suite` **[not documented]**
+- `PySQLUnityCatalogTestSuite` **[not documented]**
+
+To execute the core test suite:
+
+```bash
+poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite
+```
+
+The suites marked `[not documented]` require additional configuration, which will be documented at a later time.
 
 ### Code formatting
 
 This project uses [Black](https://pypi.org/project/black/).
 
 ```
-poetry run black src
+poetry run python3 -m black src --check
 ```
+
+Remove the `--check` flag to write reformatted files to disk.
+
+To simplify reviews, you can format your changes in a separate commit.
 
 ## Pull Request Process
 
 1. Update the [CHANGELOG.md](CHANGELOG.md) or similar documentation with details of changes you wish to make, if applicable.
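Before running the e2e suites, it can save time to confirm that the three environment variables documented above actually point at a reachable endpoint. The following is a minimal sketch of such a check, not part of the test suite; it assumes `databricks-sql-connector` is installed and uses the same `host`/`http_path`/`access_token` variables, passed to `sql.connect` with the keyword names the test harness itself uses.

```python
import os

import databricks.sql as sql

# Read the same variables the e2e suites use (documented in CONTRIBUTING.md).
connection = sql.connect(
    server_hostname=os.environ["host"],
    http_path=os.environ["http_path"],
    access_token=os.environ["access_token"],
)
try:
    cursor = connection.cursor()
    cursor.execute("SELECT 1")
    assert cursor.fetchone()[0] == 1
    print("Endpoint reachable; e2e prerequisites look good.")
    cursor.close()
finally:
    connection.close()
```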
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/common/__init__.py b/tests/e2e/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py
new file mode 100644
index 000000000..cd325e8d0
--- /dev/null
+++ b/tests/e2e/common/core_tests.py
@@ -0,0 +1,131 @@
+import decimal
+import datetime
+from collections import namedtuple
+
+TypeFailure = namedtuple(
+    "TypeFailure", "query,columnType,resultType,resultValue,"
+    "actualValue,actualType,description,conf")
+ResultFailure = namedtuple(
+    "ResultFailure", "query,columnType,resultType,resultValue,"
+    "actualValue,actualType,description,conf")
+ExecFailure = namedtuple(
+    "ExecFailure", "query,columnType,resultType,resultValue,"
+    "actualValue,actualType,description,conf,error")
+
+
+class SmokeTestMixin:
+    def test_smoke_test(self):
+        with self.cursor() as cursor:
+            cursor.execute("select 0")
+            rows = cursor.fetchall()
+            self.assertEqual(len(rows), 1)
+            self.assertEqual(rows[0][0], 0)
+
+
+class CoreTestMixin:
+    """
+    This mixin expects to be mixed with a CursorTest-like class with the following extra attributes:
+        validate_row_value_type: bool
+        validate_result: bool
+    """
+
+    # A list of (subquery, column_type, python_type, expected_result)
+    # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}"
+    range_queries = [
+        ("TRUE", 'boolean', bool, True),
+        ("cast(1 AS TINYINT)", 'byte', int, 1),
+        ("cast(1000 AS SMALLINT)", 'short', int, 1000),
+        ("cast(100000 AS INTEGER)", 'integer', int, 100000),
+        ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000),
+        ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001),
+        ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)),
+        ("unhex('f000')", 'binary', bytes, b'\xf0\x00'),  # pyodbc internal mismatch
+        ("'foo'", 'string', str, 'foo'),
+        # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days"
+        # ("interval 30 days", str, str, "interval 4 weeks 2 days"),
+        # ("interval 3 days", str, str, "interval 3 days"),
+        ("CAST(NULL AS DOUBLE)", 'double', type(None), None),
+    ]
+
+    # Full queries, only the first column of the first row is checked
+    queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)]
+
+    def run_tests_on_queries(self, default_conf):
+        failures = []
+        for (query, columnType, rowValueType, answer) in self.range_queries:
+            with self.cursor(default_conf) as cursor:
+                failures.extend(
+                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))
+                failures.extend(
+                    self.run_range_query(cursor, query, columnType, rowValueType, answer,
+                                         default_conf))
+
+        for (query, columnType, rowValueType, answer) in self.queries:
+            with self.cursor(default_conf) as cursor:
+                failures.extend(
+                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))
+
+        if failures:
+            self.fail("Failed testing result set with Arrow. "
+                      "Failed queries: {}".format("\n\n".join([str(f) for f in failures])))
" + "Failed queries: {}".format("\n\n".join([str(f) for f in failures]))) + + def run_query(self, cursor, query, columnType, rowValueType, answer, conf): + full_query = "SELECT {}".format(query) + expected_column_types = self.expected_column_types(columnType) + try: + cursor.execute(full_query) + (result, ) = cursor.fetchone() + if not all(cursor.description[0][1] == type for type in expected_column_types): + return [ + TypeFailure(full_query, expected_column_types, rowValueType, answer, result, + type(result), cursor.description, conf) + ] + if self.validate_row_value_type and type(result) is not rowValueType: + return [ + TypeFailure(full_query, expected_column_types, rowValueType, answer, result, + type(result), cursor.description, conf) + ] + if self.validate_result and str(answer) != str(result): + return [ + ResultFailure(full_query, query, expected_column_types, rowValueType, answer, + result, type(result), cursor.description, conf) + ] + return [] + except Exception as e: + return [ + ExecFailure(full_query, columnType, rowValueType, None, None, None, + cursor.description, conf, e) + ] + + def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf): + full_query = "SELECT {}, id FROM RANGE({})".format(query, 5000) + expected_column_types = self.expected_column_types(columnType) + try: + cursor.execute(full_query) + while True: + rows = cursor.fetchmany(1000) + if len(rows) <= 0: + break + for index, (result, id) in enumerate(rows): + if not all(cursor.description[0][1] == type for type in expected_column_types): + return [ + TypeFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + if self.validate_row_value_type and type(result) \ + is not rowValueType: + return [ + TypeFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + if self.validate_result and str(expected) != str(result): + return [ + ResultFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + return [] + except Exception as e: + return [ + ExecFailure(full_query, columnType, rowValueType, None, None, None, + cursor.description, conf, e) + ] diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py new file mode 100644 index 000000000..8051d2a18 --- /dev/null +++ b/tests/e2e/common/decimal_tests.py @@ -0,0 +1,48 @@ +from decimal import Decimal + +import pyarrow + + +class DecimalTestsMixin: + decimal_and_expected_results = [ + ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)), + ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), + ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), + # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False + #("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), + ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)), + ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)), + ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)), + ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)), + ] + + multi_decimals_and_expected_results = [ + (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"], + [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)), + (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"], [Decimal('1.000'), 
diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py
new file mode 100644
index 000000000..8051d2a18
--- /dev/null
+++ b/tests/e2e/common/decimal_tests.py
@@ -0,0 +1,48 @@
+from decimal import Decimal
+
+import pyarrow
+
+
+class DecimalTestsMixin:
+    decimal_and_expected_results = [
+        ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
+        ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
+        ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
+        # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False
+        # ("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)),
+        ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)),
+        ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)),
+        ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)),
+        ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
+    ]
+
+    multi_decimals_and_expected_results = [
+        (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
+         [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)),
+        (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"],
+         [Decimal('1.000'), Decimal('2.000')], pyarrow.decimal128(6, 3)),
+    ]
+
+    def test_decimals(self):
+        with self.cursor({}) as cursor:
+            for (decimal, expected_value, expected_type) in self.decimal_and_expected_results:
+                query = "SELECT CAST ({})".format(decimal)
+                with self.subTest(query=query):
+                    cursor.execute(query)
+                    table = cursor.fetchmany_arrow(1)
+                    self.assertEqual(table.field(0).type, expected_type)
+                    self.assertEqual(table.to_pydict().popitem()[1][0], expected_value)
+
+    def test_multi_decimals(self):
+        with self.cursor({}) as cursor:
+            for (decimals, expected_values,
+                 expected_type) in self.multi_decimals_and_expected_results:
+                union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
+                query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str)
+
+                with self.subTest(query=query):
+                    cursor.execute(query)
+                    table = cursor.fetchall_arrow()
+                    self.assertEqual(table.field(0).type, expected_type)
+                    self.assertEqual(table.to_pydict().popitem()[1], expected_values)
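The decimal mixin asserts at the Arrow schema level (`table.field(0).type`) as well as on Python values. The same assertion pattern can be exercised offline against a locally built Arrow table, which makes the two checks easier to see in isolation; this sketch needs only `pyarrow`, no endpoint.

```python
from decimal import Decimal

import pyarrow

# Build a one-column Arrow table locally to show what the mixin asserts on.
table = pyarrow.table(
    {"col": pyarrow.array([Decimal("100.001")], type=pyarrow.decimal128(6, 3))})

assert table.field(0).type == pyarrow.decimal128(6, 3)          # schema-level check
assert table.to_pydict().popitem()[1][0] == Decimal("100.001")  # value-level check
print(table.field(0).type)  # decimal128(6, 3)
```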
diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py
new file mode 100644
index 000000000..d59e0a9fe
--- /dev/null
+++ b/tests/e2e/common/large_queries_mixin.py
@@ -0,0 +1,100 @@
+import logging
+import math
+import time
+
+log = logging.getLogger(__name__)
+
+
+class LargeQueriesMixin:
+    """
+    This mixin expects to be mixed with a CursorTest-like class
+    """
+
+    def fetch_rows(self, cursor, row_count, fetchmany_size):
+        """
+        A generator for rows. Fetches until the end or up to 5 minutes.
+        """
+        # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone
+        # in the Python client
+        max_fetch_time = 5 * 60  # Fetch for at most 5 minutes
+
+        rows = self.get_some_rows(cursor, fetchmany_size)
+        start_time = time.time()
+        n = 0
+        while rows:
+            for row in rows:
+                n += 1
+                yield row
+            if time.time() - start_time >= max_fetch_time:
+                log.warning("Fetching rows timed out")
+                break
+            rows = self.get_some_rows(cursor, fetchmany_size)
+        if not rows:
+            # Read all the rows, row_count should match
+            self.assertEqual(n, row_count)
+
+        num_fetches = max(math.ceil(n / 10000), 1)
+        latency_ms = int((time.time() - start_time) * 1000 / num_fetches)
+        print('Fetched {} rows with an avg latency of {} ms per fetch, '.format(n, latency_ms) +
+              'assuming 10K fetch size.')
+
+    def test_query_with_large_wide_result_set(self):
+        resultSize = 300 * 1000 * 1000  # 300 MB
+        width = 8192  # B
+        rows = resultSize // width
+        cols = width // 36
+
+        # Set the fetchmany_size to get 10MB of data a go
+        fetchmany_size = 10 * 1024 * 1024 // width
+        # This is used by PyHive tests to determine the buffer size
+        self.arraysize = 1000
+        with self.cursor() as cursor:
+            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
+            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
+            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
+                self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
+                self.assertEqual(len(row[1]), 36)
+
+    def test_query_with_large_narrow_result_set(self):
+        resultSize = 300 * 1000 * 1000  # 300 MB
+        width = 8  # sizeof(long)
+        rows = resultSize // width
+
+        # Set the fetchmany_size to get 10MB of data a go
+        fetchmany_size = 10 * 1024 * 1024 // width
+        # This is used by PyHive tests to determine the buffer size
+        self.arraysize = 10000000
+        with self.cursor() as cursor:
+            cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
+            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
+                self.assertEqual(row[0], row_id)
+
+    def test_long_running_query(self):
+        """Incrementally increase query size until it takes at least 5 minutes,
+        and assert that the query completes successfully.
+        """
+        minutes = 60
+        min_duration = 5 * minutes
+
+        duration = -1
+        scale0 = 10000
+        scale_factor = 1
+        with self.cursor() as cursor:
+            while duration < min_duration:
+                self.assertLess(scale_factor, 512, msg="Detected infinite loop")
+                start = time.time()
+
+                cursor.execute("""SELECT count(*)
+                        FROM RANGE({scale}) x
+                        JOIN RANGE({scale0}) y
+                        ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
+                        """.format(scale=scale_factor * scale0, scale0=scale0))
+
+                n, = cursor.fetchone()
+                self.assertEqual(n, 0)
+
+                duration = time.time() - start
+                current_fraction = duration / min_duration
+                print('Took {} s with scale factor={}'.format(duration, scale_factor))
+                # Extrapolate linearly to reach 5 min and add 50% padding to push over the limit
+                scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
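The scale-factor update at the end of `test_long_running_query` is easiest to follow with concrete numbers. A worked run of the same arithmetic, under an assumed first-pass duration of 30 seconds against the 300-second target:

```python
import math

# Illustrative run of the extrapolation step in test_long_running_query.
min_duration = 300   # 5-minute target, in seconds
scale_factor = 1
duration = 30        # assumed duration of the first query

current_fraction = duration / min_duration  # 0.1 -- we reached 10% of the target
scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
print(scale_factor)  # 15 -- a linear 10x extrapolation plus 50% padding
```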
diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py
new file mode 100644
index 000000000..3450087f5
--- /dev/null
+++ b/tests/e2e/common/predicates.py
@@ -0,0 +1,101 @@
+import functools
+from packaging.version import parse as parse_version
+import unittest
+
+MAJOR_DBR_V_KEY = "major_dbr_version"
+MINOR_DBR_V_KEY = "minor_dbr_version"
+ENDPOINT_TEST_KEY = "is_endpoint_test"
+
+
+def pysql_supports_arrow():
+    """Import databricks.sql and test whether Cursor has fetchall_arrow."""
+    from databricks.sql.client import Cursor
+    return hasattr(Cursor, 'fetchall_arrow')
+
+
+def pysql_has_version(compare, version):
+    """Import databricks.sql, and return compare_module_version(...).
+
+    Expected use:
+        from common.predicates import pysql_has_version
+        from databricks import sql as pysql
+        ...
+        @unittest.skipIf(pysql_has_version('<', '2'))
+        def test_some_pyhive_v1_stuff():
+        ...
+    """
+    from databricks import sql
+    return compare_module_version(sql, compare, version)
+
+
+def is_endpoint_test(cli_args=None):
+    # Currently we only support tests against DBSQL Endpoints,
+    # so we don't read `is_endpoint_test` from the CLI args.
+    return True
+
+
+def compare_dbr_versions(cli_args, compare, major_version, minor_version):
+    if MAJOR_DBR_V_KEY in cli_args and MINOR_DBR_V_KEY in cli_args:
+        if cli_args[MINOR_DBR_V_KEY] == "x":
+            actual_minor_v = float('inf')
+        else:
+            actual_minor_v = int(cli_args[MINOR_DBR_V_KEY])
+        dbr_version = (int(cli_args[MAJOR_DBR_V_KEY]), actual_minor_v)
+        req_version = (major_version, minor_version)
+        return compare_versions(compare, dbr_version, req_version)
+
+    if not is_endpoint_test():
+        raise ValueError(
+            "DBR version not provided for non-endpoint test. Please pass the {} and {} params".
+            format(MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY))
+
+
+def is_thrift_v5_plus(cli_args):
+    return compare_dbr_versions(cli_args, ">=", 10, 2) or is_endpoint_test(cli_args)
+
+
+_compare_fns = {
+    '<': '__lt__',
+    '<=': '__le__',
+    '>': '__gt__',
+    '>=': '__ge__',
+    '==': '__eq__',
+    '!=': '__ne__',
+}
+
+
+def compare_versions(compare, v1_tuple, v2_tuple):
+    compare_fn_name = _compare_fns.get(compare)
+    assert compare_fn_name, 'Received invalid compare string: ' + compare
+    return getattr(v1_tuple, compare_fn_name)(v2_tuple)
+
+
+def compare_module_version(module, compare, version):
+    """Compare `module`'s version as specified, returning True/False.
+
+    Expected use:
+        @unittest.skipIf(compare_module_version(sql, '<', '2'))
+        def test_some_pyhive_v1_stuff():
+        ...
+
+    `module`: the module whose version will be compared
+    `compare`: one of '<', '<=', '>', '>=', '==', '!='
+    `version`: a version string, of the form 'x[.y[.z]]'
+
+    Asserts module and compare to be truthy, and casts version to string.
+
+    NOTE: This comparison leverages packaging.version.parse, and compares _release_ versions,
+    thus ignoring pre/post release tags (eg -rc1, -dev, etc).
+    """
+    assert module, 'Received invalid module: ' + module
+    assert getattr(module, '__version__'), 'Received module with no version: ' + module
+
+    def validate_version(version):
+        v = parse_version(str(version))
+        # assert that we get a PEP-440 Version back -- LegacyVersion doesn't have major/minor.
+        assert hasattr(v, 'major'), 'Module has incompatible "Legacy" version: ' + version
+        return (v.major, v.minor, v.micro)
+
+    mod_version = validate_version(module.__version__)
+    req_version = validate_version(version)
+    return compare_versions(compare, mod_version, req_version)
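`compare_versions` works by dispatching the comparison string to the corresponding dunder method on the version tuple, and tuples already compare lexicographically, so no version-specific logic is needed. A quick usage check (pure Python, no endpoint required; the import path assumes the repo root is on `sys.path`):

```python
from tests.e2e.common.predicates import compare_versions

print(compare_versions('>=', (10, 2), (10, 2)))  # True  -- (10, 2).__ge__((10, 2))
print(compare_versions('<', (7, 3), (10, 2)))    # True  -- tuples compare lexicographically
print(compare_versions('!=', (11, 0), (10, 2)))  # True
```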
diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py
new file mode 100644
index 000000000..a088ba1e3
--- /dev/null
+++ b/tests/e2e/common/retry_test_mixins.py
@@ -0,0 +1,38 @@
+class Client429ResponseMixin:
+    def test_client_should_retry_automatically_when_getting_429(self):
+        with self.cursor() as cursor:
+            for _ in range(10):
+                cursor.execute("SELECT 1")
+                rows = cursor.fetchall()
+                self.assertEqual(len(rows), 1)
+                self.assertEqual(rows[0][0], 1)
+
+    def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self):
+        with self.assertRaises(self.error_type) as cm:
+            with self.cursor(self.conf_to_disable_rate_limit_retries) as cursor:
+                for _ in range(10):
+                    cursor.execute("SELECT 1")
+                    rows = cursor.fetchall()
+                    self.assertEqual(len(rows), 1)
+                    self.assertEqual(rows[0][0], 1)
+        expected = "Maximum rate of 1 requests per SECOND has been exceeded. " \
+                   "Please reduce the rate of requests and try again after 1 seconds."
+        exception_str = str(cm.exception)
+        # FIXME (Ali Smesseim, 7-Jul-2020): ODBC driver does not always return the
+        # X-Thriftserver-Error-Message as-is. Re-enable once Simba resolves this flakiness.
+        # Simba support ticket: https://magnitudesoftware.force.com/support/5001S000018RlaD
+        # self.assertIn(expected, exception_str)
+
+
+class Client503ResponseMixin:
+    def test_wait_cluster_startup(self):
+        with self.cursor() as cursor:
+            cursor.execute("SELECT 1")
+            cursor.fetchall()
+
+    def _test_retry_disabled_with_message(self, error_msg_substring, exception_type):
+        with self.assertRaises(exception_type) as cm:
+            with self.connection(self.conf_to_disable_temporarily_unavailable_retries):
+                pass
+        self.assertIn(error_msg_substring, str(cm.exception))
diff --git a/tests/e2e/common/timestamp_tests.py b/tests/e2e/common/timestamp_tests.py
new file mode 100644
index 000000000..38b14e9e8
--- /dev/null
+++ b/tests/e2e/common/timestamp_tests.py
@@ -0,0 +1,74 @@
+import datetime
+
+from .predicates import compare_dbr_versions, is_thrift_v5_plus, pysql_has_version
+
+
+class TimestampTestsMixin:
+    timestamp_and_expected_results = [
+        ('2021-09-30 11:27:35.123+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)),
+        ('2021-09-30 11:27:35+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35)),
+        ('2021-09-30 11:27:35.123', datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)),
+        ('2021-09-30 11:27:35', datetime.datetime(2021, 9, 30, 11, 27, 35)),
+        ('2021-09-30 11:27', datetime.datetime(2021, 9, 30, 11, 27)),
+        ('2021-09-30 11', datetime.datetime(2021, 9, 30, 11)),
+        ('2021-09-30', datetime.datetime(2021, 9, 30)),
+        ('2021-09', datetime.datetime(2021, 9, 1)),
+        ('2021', datetime.datetime(2021, 1, 1)),
+        ('9999-12-31T15:59:59', datetime.datetime(9999, 12, 31, 15, 59, 59)),
+        ('9999-99-31T15:59:59', None),
+    ]
+
+    date_and_expected_results = [
+        ('2021-09-30', datetime.date(2021, 9, 30)),
+        ('2021-09', datetime.date(2021, 9, 1)),
+        ('2021', datetime.date(2021, 1, 1)),
+        ('9999-12-31', datetime.date(9999, 12, 31)),
+        ('9999-99-31', None),
+    ]
+
+    def should_add_timezone(self):
+        return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)
+
+    def maybe_add_timezone_to_timestamp(self, ts):
+        """If we're using DBR >= 10.2, then we expect back aware timestamps, so add a timezone to `ts`.
+        Otherwise we have naive timestamps, so no change is needed.
+        """
+        if ts and self.should_add_timezone():
+            return ts.replace(tzinfo=datetime.timezone.utc)
+        else:
+            return ts
+
+    def assertTimestampsEqual(self, result, expected):
+        self.assertEqual(result, self.maybe_add_timezone_to_timestamp(expected))
+
+    def multi_query(self, n_rows=10):
+        row_sql = "SELECT " + ", ".join(
+            ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results])
+        query = " UNION ALL ".join([row_sql for _ in range(n_rows)])
+        expected_matrix = [[dt for (_, dt) in self.timestamp_and_expected_results]
+                           for _ in range(n_rows)]
+        return query, expected_matrix
+
+    def test_timestamps(self):
+        with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
+            for (timestamp, expected) in self.timestamp_and_expected_results:
+                cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp))
+                result = cursor.fetchone()[0]
+                self.assertTimestampsEqual(result, expected)
+
+    def test_multi_timestamps(self):
+        with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
+            query, expected = self.multi_query()
+            cursor.execute(query)
+            result = cursor.fetchall()
+            # We list-ify the rows because PyHive will return a tuple for a row
+            self.assertEqual([list(r) for r in result],
+                             [[self.maybe_add_timezone_to_timestamp(ts) for ts in r]
+                              for r in expected])
+
+    def test_dates(self):
+        with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
+            for (date, expected) in self.date_and_expected_results:
+                cursor.execute("SELECT DATE('{date}')".format(date=date))
+                result = cursor.fetchone()[0]
+                self.assertEqual(result, expected)
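The reason `maybe_add_timezone_to_timestamp` has to adjust the expected values, rather than just comparing wall-clock fields, is a property of Python's `datetime`: a naive and an aware datetime never compare equal, even when every field matches. A standard-library illustration:

```python
import datetime

# Why TimestampTestsMixin adjusts its expectations: a naive datetime and an
# aware datetime never compare equal, even with identical wall-clock fields.
naive = datetime.datetime(2021, 9, 30, 7, 27, 35)
aware = naive.replace(tzinfo=datetime.timezone.utc)

print(naive == aware)                       # False -- naive vs aware is never equal
print(naive == aware.replace(tzinfo=None))  # True  -- strip the tz and they match
```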
self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + for (date, expected) in self.date_and_expected_results: + cursor.execute("SELECT DATE('{date}')".format(date=date)) + result = cursor.fetchone()[0] + self.assertEqual(result, expected) diff --git a/tests/e2e/driver_tests.py b/tests/e2e/driver_tests.py new file mode 100644 index 000000000..358f0b263 --- /dev/null +++ b/tests/e2e/driver_tests.py @@ -0,0 +1,589 @@ +from contextlib import contextmanager +from collections import OrderedDict +import datetime +import io +import logging +import os +import sys +import threading +import time +from unittest import loader, skipIf, skipUnless, TestCase +from uuid import uuid4 + +import numpy as np +import pyarrow +import pytz +import thrift + +import databricks.sql as sql +from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError +from tests.e2e.common.predicates import pysql_has_version, pysql_supports_arrow, compare_dbr_versions, is_thrift_v5_plus +from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin +from tests.e2e.common.large_queries_mixin import LargeQueriesMixin +from tests.e2e.common.timestamp_tests import TimestampTestsMixin +from tests.e2e.common.decimal_tests import DecimalTestsMixin +from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin + +log = logging.getLogger(__name__) + +# manually decorate DecimalTestsMixin to need arrow support +for name in loader.getTestCaseNames(DecimalTestsMixin, 'test_'): + fn = getattr(DecimalTestsMixin, name) + decorated = skipUnless(pysql_supports_arrow(), 'Decimal tests need arrow support')(fn) + setattr(DecimalTestsMixin, name, decorated) + +get_args_from_env = True + + +class PySQLTestCase(TestCase): + error_type = Error + conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} + conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} + + def __init__(self, method_name): + super().__init__(method_name) + # If running in local mode, just use environment variables for params. 
+        self.arguments = os.environ if get_args_from_env else {}
+        self.arraysize = 1000
+
+    def connection_params(self, arguments):
+        params = {
+            "server_hostname": arguments["host"],
+            "http_path": arguments["http_path"],
+            **self.auth_params(arguments)
+        }
+
+        return params
+
+    def auth_params(self, arguments):
+        return {
+            "_username": arguments.get("rest_username"),
+            "_password": arguments.get("rest_password"),
+            "access_token": arguments.get("access_token")
+        }
+
+    @contextmanager
+    def connection(self, extra_params=()):
+        connection_params = dict(self.connection_params(self.arguments), **dict(extra_params))
+
+        log.info("Connecting with args: {}".format(connection_params))
+        conn = sql.connect(**connection_params)
+
+        try:
+            yield conn
+        finally:
+            conn.close()
+
+    @contextmanager
+    def cursor(self, extra_params=()):
+        with self.connection(extra_params) as conn:
+            cursor = conn.cursor(arraysize=self.arraysize)
+            try:
+                yield cursor
+            finally:
+                cursor.close()
+
+    def assertEqualRowValues(self, actual, expected):
+        self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0)
+        for act, exp in zip(actual, expected):
+            self.assertSequenceEqual(act, exp)
+
+
+class PySQLLargeQueriesSuite(PySQLTestCase, LargeQueriesMixin):
+    def get_some_rows(self, cursor, fetchmany_size):
+        row = cursor.fetchone()
+        if row:
+            return [row]
+        else:
+            return None
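The `expected_column_types` mapping in the next suite lists several acceptable values per logical type because PEP-249 type objects compare equal to more than one type-code string. A minimal sketch of that comparison trick, following the recipe in PEP 249 itself; this is illustrative only and not the connector's actual class:

```python
# Minimal sketch of the PEP 249 type-object trick that expected_column_types
# (below) relies on. Illustrative only -- not the connector's actual class.
class DBAPITypeObject:
    def __init__(self, *values):
        self.values = values

    def __eq__(self, other):
        # A type object equals every type code it groups together.
        return other in self.values


NUMBER = DBAPITypeObject('int', 'bigint', 'tinyint', 'smallint', 'decimal', 'double')

print(NUMBER == 'bigint')  # True
print(NUMBER == 'string')  # False
print('int' == NUMBER)     # True -- str.__eq__ defers, so the reflected __eq__ runs
```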
+
+
+# Exclude Retry tests because they require specific setups, and LargeQueries is too slow for the
+# core tests
+class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin,
+                         PySQLTestCase):
+    validate_row_value_type = True
+    validate_result = True
+
+    # An output column's type in `description` compares equal to multiple values:
+    # - the type code returned by the client, as a string
+    # - and potentially a PEP-249 type object, like NUMBER or DATETIME
+    def expected_column_types(self, type_):
+        type_mappings = {
+            'boolean': ['boolean', NUMBER],
+            'byte': ['tinyint', NUMBER],
+            'short': ['smallint', NUMBER],
+            'integer': ['int', NUMBER],
+            'long': ['bigint', NUMBER],
+            'decimal': ['decimal', NUMBER],
+            'timestamp': ['timestamp', DATETIME],
+            'date': ['date', DATE],
+            'binary': ['binary', BINARY],
+            'string': ['string', STRING],
+            'array': ['array'],
+            'struct': ['struct'],
+            'map': ['map'],
+            'double': ['double', NUMBER],
+            'null': ['null']
+        }
+        return type_mappings[type_]
+
+    def test_queries(self):
+        if not self._should_have_native_complex_types():
+            array_type = str
+            array_val = "[1,2,3]"
+            struct_type = str
+            struct_val = "{\"a\":1,\"b\":2}"
+            map_type = str
+            map_val = "{1:2,3:4}"
+        else:
+            array_type = np.ndarray
+            array_val = np.array([1, 2, 3])
+            struct_type = dict
+            struct_val = {"a": 1, "b": 2}
+            map_type = list
+            map_val = [(1, 2), (3, 4)]
+
+        null_type = "null" if float(sql.__version__[0:2]) < 2.0 else "string"
+        self.range_queries = CoreTestMixin.range_queries + [
+            ("NULL", null_type, type(None), None),
+            ("array(1, 2, 3)", 'array', array_type, array_val),
+            ("struct(1 as a, 2 as b)", 'struct', struct_type, struct_val),
+            ("map(1, 2, 3, 4)", 'map', map_type, map_val),
+        ]
+
+        self.run_tests_on_queries({})
+
+    @skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
+    def test_incorrect_query_throws_exception(self):
+        with self.cursor({}) as cursor:
+            # Syntax errors should contain the invalid SQL
+            with self.assertRaises(DatabaseError) as cm:
+                cursor.execute("^ FOO BAR")
+            self.assertIn("FOO BAR", str(cm.exception))
+
+            # Database error should contain the missing database
+            with self.assertRaises(DatabaseError) as cm:
+                cursor.execute("USE foo234823498ydfsiusdhf")
+            self.assertIn("foo234823498ydfsiusdhf", str(cm.exception))
+
+            # SQL with extraneous input should send back the extraneous input
+            with self.assertRaises(DatabaseError) as cm:
+                cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234")
+            self.assertIn("table_234234234", str(cm.exception))
+
+    def test_create_table_will_return_empty_result_set(self):
+        with self.cursor({}) as cursor:
+            table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+            try:
+                cursor.execute(
+                    "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format(
+                        table_name))
+                self.assertEqual(cursor.fetchall(), [])
+            finally:
+                cursor.execute("DROP TABLE IF EXISTS {}".format(table_name))
+
+    def test_get_tables(self):
+        with self.cursor({}) as cursor:
+            table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+            table_names = [table_name + '_1', table_name + '_2']
+
+            try:
+                for table in table_names:
+                    cursor.execute(
+                        "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format(
+                            table))
+                cursor.tables(schema_name="defa%")
+                tables = cursor.fetchall()
+                tables_desc = cursor.description
+
+                for table in table_names:
+                    # Test only the schema name and table name here; which of the
+                    # other columns are populated depends on the DBR version.
+                    self.assertIn(['default', table], [list(table[1:3]) for table in tables])
+                self.assertEqual(
+                    tables_desc,
+                    [('TABLE_CAT', 'string', None, None, None, None, None),
+                     ('TABLE_SCHEM', 'string', None, None, None, None, None),
+                     ('TABLE_NAME', 'string', None, None, None, None, None),
+                     ('TABLE_TYPE', 'string', None, None, None, None, None),
+                     ('REMARKS', 'string', None, None, None, None, None),
+                     ('TYPE_CAT', 'string', None, None, None, None, None),
+                     ('TYPE_SCHEM', 'string', None, None, None, None, None),
+                     ('TYPE_NAME', 'string', None, None, None, None, None),
+                     ('SELF_REFERENCING_COL_NAME', 'string', None, None, None, None, None),
+                     ('REF_GENERATION', 'string', None, None, None, None, None)])
+            finally:
+                for table in table_names:
+                    cursor.execute('DROP TABLE IF EXISTS {}'.format(table))
+
+    def test_get_columns(self):
+        with self.cursor({}) as cursor:
+            table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+            table_names = [table_name + '_1', table_name + '_2']
+
+            try:
+                for table in table_names:
+                    cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT "
+                                   "1 AS col_1, "
+                                   "'2' AS col_2, "
+                                   "named_struct('name', 'alice', 'age', 28) as col_3, "
+                                   "map('items', 45, 'cost', 228) as col_4, "
+                                   "array('item1', 'item2', 'item3') as col_5)".format(table))
+
+                cursor.columns(schema_name="defa%", table_name=table_name + '%')
+                cols = cursor.fetchall()
+                cols_desc = cursor.description
+
+                # Catalogue name not consistent across DBR versions, so we skip that
+                cleaned_response = [list(col[1:6]) for col in cols]
+                # We also replace ` as DBR changes how it represents struct names
+                for col in cleaned_response:
+                    col[4] = col[4].replace("`", "")
+
+                self.assertEqual(cleaned_response, [
+                    ['default', table_name + '_1', 'col_1', 4, 'INT'],
+                    ['default', table_name + '_1', 'col_2', 12, 'STRING'],
+                    ['default', table_name + '_1', 'col_3', 2002, 'STRUCT'],
+                    ['default', table_name + '_1', 'col_4', 2000, 'MAP'],
+                    ['default', table_name + '_1', 'col_5', 2003, 'ARRAY'],
+                    ['default', table_name + '_2', 'col_1', 4, 'INT'],
+                    ['default', table_name + '_2', 'col_2', 12, 'STRING'],
+                    ['default', table_name + '_2', 'col_3', 2002, 'STRUCT'],
+                    ['default', table_name + '_2', 'col_4', 2000, 'MAP'],
+                    ['default', table_name + '_2', 'col_5', 2003, 'ARRAY'],
+                ])
+
+                self.assertEqual(cols_desc,
+                                 [('TABLE_CAT', 'string', None, None, None, None, None),
+                                  ('TABLE_SCHEM', 'string', None, None, None, None, None),
+                                  ('TABLE_NAME', 'string', None, None, None, None, None),
+                                  ('COLUMN_NAME', 'string', None, None, None, None, None),
+                                  ('DATA_TYPE', 'int', None, None, None, None, None),
+                                  ('TYPE_NAME', 'string', None, None, None, None, None),
+                                  ('COLUMN_SIZE', 'int', None, None, None, None, None),
+                                  ('BUFFER_LENGTH', 'tinyint', None, None, None, None, None),
+                                  ('DECIMAL_DIGITS', 'int', None, None, None, None, None),
+                                  ('NUM_PREC_RADIX', 'int', None, None, None, None, None),
+                                  ('NULLABLE', 'int', None, None, None, None, None),
+                                  ('REMARKS', 'string', None, None, None, None, None),
+                                  ('COLUMN_DEF', 'string', None, None, None, None, None),
+                                  ('SQL_DATA_TYPE', 'int', None, None, None, None, None),
+                                  ('SQL_DATETIME_SUB', 'int', None, None, None, None, None),
+                                  ('CHAR_OCTET_LENGTH', 'int', None, None, None, None, None),
+                                  ('ORDINAL_POSITION', 'int', None, None, None, None, None),
+                                  ('IS_NULLABLE', 'string', None, None, None, None, None),
+                                  ('SCOPE_CATALOG', 'string', None, None, None, None, None),
+                                  ('SCOPE_SCHEMA', 'string', None, None, None, None, None),
+                                  ('SCOPE_TABLE', 'string', None, None, None, None, None),
+                                  ('SOURCE_DATA_TYPE', 'smallint', None, None, None, None, None),
+                                  ('IS_AUTO_INCREMENT', 'string', None, None, None, None, None)])
+            finally:
+                for table in table_names:
+                    cursor.execute('DROP TABLE IF EXISTS {}'.format(table))
+
+    def test_get_schemas(self):
+        with self.cursor({}) as cursor:
+            database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+            try:
+                cursor.execute('CREATE DATABASE IF NOT EXISTS {}'.format(database_name))
+                cursor.schemas()
+                schemas = cursor.fetchall()
+                schemas_desc = cursor.description
+                # Catalogue name not consistent across DBR versions, so we skip that
+                self.assertIn(database_name, [schema[0] for schema in schemas])
+                self.assertEqual(schemas_desc,
+                                 [('TABLE_SCHEM', 'string', None, None, None, None, None),
+                                  ('TABLE_CATALOG', 'string', None, None, None, None, None)])
+            finally:
+                cursor.execute('DROP DATABASE IF EXISTS {}'.format(database_name))
+
+    def test_get_catalogs(self):
+        with self.cursor({}) as cursor:
+            cursor.catalogs()
+            cursor.fetchall()
+            catalogs_desc = cursor.description
+            self.assertEqual(catalogs_desc, [('TABLE_CAT', 'string', None, None, None, None, None)])
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_get_arrow(self):
+        # These tests are quite lightweight as the arrow fetch methods are used internally
+        # by everything else
+        with self.cursor({}) as cursor:
+            cursor.execute("SELECT * FROM range(10)")
+            table_1 = cursor.fetchmany_arrow(1).to_pydict()
+            self.assertEqual(table_1, OrderedDict([("id", [0])]))
+
+            table_2 = cursor.fetchall_arrow().to_pydict()
+            self.assertEqual(table_2, OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]))
+
+    def test_unicode(self):
+        unicode_str = "数据砖"
+        with self.cursor({}) as cursor:
+            cursor.execute("SELECT '{}'".format(unicode_str))
+            results = cursor.fetchall()
+            self.assertTrue(len(results) == 1 and len(results[0]) == 1)
+            self.assertEqual(results[0][0], unicode_str)
+
+    def test_cancel_during_execute(self):
+        with self.cursor({}) as cursor:
+
+            def execute_really_long_query():
+                cursor.execute("SELECT SUM(A.id - B.id) " +
+                               "FROM range(1000000000) A CROSS JOIN range(100000000) B " +
+                               "GROUP BY (A.id - B.id)")
+
+            exec_thread = threading.Thread(target=execute_really_long_query)
+
+            exec_thread.start()
+            # Make sure the query has started before cancelling
+            time.sleep(15)
+            cursor.cancel()
+            exec_thread.join(5)
+            self.assertFalse(exec_thread.is_alive())
+
+            # Fetching results should throw an exception
+            with self.assertRaises((Error, thrift.Thrift.TException)):
+                cursor.fetchall()
+            with self.assertRaises((Error, thrift.Thrift.TException)):
+                cursor.fetchone()
+            with self.assertRaises((Error, thrift.Thrift.TException)):
+                cursor.fetchmany(10)
+
+            # We should be able to execute a new command on the cursor
+            cursor.execute("SELECT * FROM range(3)")
+            self.assertEqual(len(cursor.fetchall()), 3)
+
+    @skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
+    def test_can_execute_command_after_failure(self):
+        with self.cursor({}) as cursor:
+            with self.assertRaises(DatabaseError):
+                cursor.execute("this is a sytnax error")
+
+            cursor.execute("SELECT 1;")
+
+            res = cursor.fetchall()
+            self.assertEqualRowValues(res, [[1]])
+
+    @skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
+    def test_can_execute_command_after_success(self):
+        with self.cursor({}) as cursor:
+            cursor.execute("SELECT 1;")
+            cursor.execute("SELECT 2;")
+
+            res = cursor.fetchall()
+            self.assertEqualRowValues(res, [[2]])
query = "SELECT * FROM range(3);" + return query + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchone(self): + with self.cursor({}) as cursor: + query = self.generate_multi_row_query() + cursor.execute(query) + + self.assertSequenceEqual(cursor.fetchone(), [0]) + self.assertSequenceEqual(cursor.fetchone(), [1]) + self.assertSequenceEqual(cursor.fetchone(), [2]) + + self.assertEqual(cursor.fetchone(), None) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchall(self): + with self.cursor({}) as cursor: + query = self.generate_multi_row_query() + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchall(), [[0], [1], [2]]) + + self.assertEqual(cursor.fetchone(), None) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchmany_when_stride_fits(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchmany(2), [[0], [1]]) + self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchmany_in_excess(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchmany(3), [[0], [1], [2]]) + self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_iterator_api(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + expected_results = [[0], [1], [2], [3]] + for (i, row) in enumerate(cursor): + self.assertSequenceEqual(row, expected_results[i]) + + def test_temp_view_fetch(self): + with self.cursor({}) as cursor: + query = "create temporary view f as select * from range(10)" + cursor.execute(query) + # TODO assert on a result + # once what is being returned has stabilised + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_socket_timeout(self): + # We we expect to see a BlockingIO error when the socket is opened + # in non-blocking mode, since no poll is done before the read + with self.assertRaises(OperationalError) as cm: + with self.cursor({"_socket_timeout": 0}): + pass + + self.assertIsInstance(cm.exception.args[1], io.BlockingIOError) + + def test_ssp_passthrough(self): + for enable_ansi in (True, False): + with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: + cursor.execute("SET ansi_mode") + self.assertEqual(list(cursor.fetchone()), ["ansi_mode", str(enable_ansi)]) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_timestamps_arrow(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + for (timestamp, expected) in self.timestamp_and_expected_results: + cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + arrow_table = cursor.fetchmany_arrow(1) + if self.should_add_timezone(): + ts_type = pyarrow.timestamp("us", tz="Etc/UTC") + else: + ts_type = pyarrow.timestamp("us") + self.assertEqual(arrow_table.field(0).type, ts_type) + result_value = arrow_table.column(0).combine_chunks()[0].value + # To work consistently across different local timezones, we specify the timezone + # of the expected result to + # be UTC (what it should be by default on the server) + aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) + self.assertEqual(result_value, aware_timestamp and + aware_timestamp.timestamp() 
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_timestamps_arrow(self):
+        with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
+            for (timestamp, expected) in self.timestamp_and_expected_results:
+                cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp))
+                arrow_table = cursor.fetchmany_arrow(1)
+                if self.should_add_timezone():
+                    ts_type = pyarrow.timestamp("us", tz="Etc/UTC")
+                else:
+                    ts_type = pyarrow.timestamp("us")
+                self.assertEqual(arrow_table.field(0).type, ts_type)
+                result_value = arrow_table.column(0).combine_chunks()[0].value
+                # To work consistently across different local timezones, we specify the timezone
+                # of the expected result to be UTC (what it should be by default on the server)
+                aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc)
+                self.assertEqual(result_value,
+                                 aware_timestamp and aware_timestamp.timestamp() * 1000000,
+                                 "timestamp {} did not match {}".format(timestamp, expected))
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_multi_timestamps_arrow(self):
+        with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
+            query, expected = self.multi_query()
+            expected = [[self.maybe_add_timezone_to_timestamp(ts) for ts in row]
+                        for row in expected]
+            cursor.execute(query)
+            table = cursor.fetchall_arrow()
+            # Transpose the columnar result to a list of rows
+            list_of_cols = [c.to_pylist() for c in table]
+            result = [[col[row_index] for col in list_of_cols]
+                      for row_index in range(table.num_rows)]
+            self.assertEqual(result, expected)
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_timezone_with_timestamp(self):
+        if self.should_add_timezone():
+            with self.cursor() as cursor:
+                cursor.execute("SET TIME ZONE 'Europe/Amsterdam'")
+                cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)")
+                amsterdam = pytz.timezone("Europe/Amsterdam")
+                expected = amsterdam.localize(datetime.datetime(2022, 3, 2, 12, 54, 56))
+                result = cursor.fetchone()[0]
+                self.assertEqual(result, expected)
+
+                cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)")
+                arrow_result_table = cursor.fetchmany_arrow(1)
+                arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value
+                ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam")
+
+                self.assertEqual(arrow_result_table.field(0).type, ts_type)
+                self.assertEqual(arrow_result_value, expected.timestamp() * 1000000)
+
+    def _should_have_native_complex_types(self):
+        return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_arrays_are_not_returned_as_strings_arrow(self):
+        if self._should_have_native_complex_types():
+            with self.cursor() as cursor:
+                cursor.execute("SELECT array(1,2,3,4)")
+                arrow_df = cursor.fetchall_arrow()
+
+                list_type = arrow_df.field(0).type
+                self.assertTrue(pyarrow.types.is_list(list_type))
+                self.assertTrue(pyarrow.types.is_integer(list_type.value_type))
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_structs_are_not_returned_as_strings_arrow(self):
+        if self._should_have_native_complex_types():
+            with self.cursor() as cursor:
+                cursor.execute("SELECT named_struct('foo', 42, 'bar', 'baz')")
+                arrow_df = cursor.fetchall_arrow()
+
+                struct_type = arrow_df.field(0).type
+                self.assertTrue(pyarrow.types.is_struct(struct_type))
+
+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_decimal_not_returned_as_strings_arrow(self):
+        if self._should_have_native_complex_types():
+            with self.cursor() as cursor:
+                cursor.execute("SELECT 5E3BD")
+                arrow_df = cursor.fetchall_arrow()
+
+                decimal_type = arrow_df.field(0).type
+                self.assertTrue(pyarrow.types.is_decimal(decimal_type))
+
+
+# Use a RetrySuite to encapsulate these tests, which we'll typically want to run together; however,
+# keep the 429/503 subsuites separate since they execute under different circumstances.
+class PySQLRetryTestSuite:
+    class HTTP429Suite(Client429ResponseMixin, PySQLTestCase):
+        pass  # Mixin covers all
+
+    class HTTP503Suite(Client503ResponseMixin, PySQLTestCase):
+        # The 503Response suite gets a custom error here vs PyODBC
+        def test_retry_disabled(self):
+            self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError)
+
+
+class PySQLUnityCatalogTestSuite(PySQLTestCase):
+    """Simple namespace tests that should be run against a unity-catalog-enabled cluster"""
+
+    @skipIf(pysql_has_version('<', '2'), 'requires pysql v2')
+    def test_initial_namespace(self):
+        table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
+        with self.cursor() as cursor:
+            cursor.execute("USE CATALOG {}".format(self.arguments["catA"]))
+            cursor.execute("CREATE TABLE table_{} (col1 int)".format(table_name))
+        with self.connection({
+                "catalog": self.arguments["catA"],
+                "schema": table_name
+        }) as connection:
+            cursor = connection.cursor()
+            cursor.execute("select current_catalog()")
+            self.assertEqual(cursor.fetchone()[0], self.arguments["catA"])
+            cursor.execute("select current_database()")
+            self.assertEqual(cursor.fetchone()[0], table_name)
+
+
+def main(cli_args):
+    global get_args_from_env
+    get_args_from_env = True
+    print(f"Running tests with version: {sql.__version__}")
+    logging.getLogger("databricks.sql").setLevel(logging.INFO)
+    unittest.main(module=__file__, argv=sys.argv[0:1] + cli_args)
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
\ No newline at end of file
diff --git a/tests/test_arrow_queue.py b/tests/unit/test_arrow_queue.py
similarity index 100%
rename from tests/test_arrow_queue.py
rename to tests/unit/test_arrow_queue.py
diff --git a/tests/test_fetches.py b/tests/unit/test_fetches.py
similarity index 100%
rename from tests/test_fetches.py
rename to tests/unit/test_fetches.py
diff --git a/tests/test_fetches_bench.py b/tests/unit/test_fetches_bench.py
similarity index 100%
rename from tests/test_fetches_bench.py
rename to tests/unit/test_fetches_bench.py
diff --git a/tests/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
similarity index 100%
rename from tests/test_thrift_backend.py
rename to tests/unit/test_thrift_backend.py
diff --git a/tests/tests.py b/tests/unit/tests.py
similarity index 100%
rename from tests/tests.py
rename to tests/unit/tests.py