diff --git a/dbsqlcli/dbsqlclirc b/dbsqlcli/dbsqlclirc index 28b735e..4efe0f5 100644 --- a/dbsqlcli/dbsqlclirc +++ b/dbsqlcli/dbsqlclirc @@ -43,6 +43,7 @@ syntax_style = default key_bindings = emacs # DBSQL prompt +# \c - Catalog name # \d - Database name # \h - Hostname # \D - The full current date @@ -51,7 +52,7 @@ key_bindings = emacs # \P - AM/PM # \R - The current time, in 24-hour military time (0–23) # \s - Seconds of the current time -prompt = '\h:\d> ' +prompt = '\h:\c.\d> ' prompt_continuation = '-> ' # enable pager on startup @@ -96,4 +97,4 @@ output.even-row = "" # [credentials] # host_name = "" # http_path = "" -# access_token = "" \ No newline at end of file +# access_token = "" diff --git a/dbsqlcli/main.py b/dbsqlcli/main.py index 510321e..7e67261 100644 --- a/dbsqlcli/main.py +++ b/dbsqlcli/main.py @@ -213,7 +213,8 @@ def change_db(self, arg, **_): None, None, None, - 'You are now connected to database "%s"' % self.sqlexecute.database, + 'You are now connected to database "%s.%s"' + % (self.sqlexecute.catalog, self.sqlexecute.database), ) def change_prompt_format(self, arg, **_): @@ -599,6 +600,7 @@ def get_prompt(self, string): string = string.replace( "\\h", sqlexecute.hostname.replace(".cloud.databricks.com", "") ) + string = string.replace("\\c", sqlexecute.catalog or "(none)") string = string.replace("\\d", sqlexecute.database or "(none)") string = string.replace("\\n", "\n") string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) @@ -702,6 +704,7 @@ def cli( Examples: - dbsqlcli - dbsqlcli my_database + - dbsqlcli my_catalog.my_database """ if (clirc == DBSQLCLIRC) and (not os.path.exists(os.path.expanduser(clirc))): err_msg = ( diff --git a/dbsqlcli/packages/special/dbcommands.py b/dbsqlcli/packages/special/dbcommands.py index f29ca78..9d8e0fd 100644 --- a/dbsqlcli/packages/special/dbcommands.py +++ b/dbsqlcli/packages/special/dbcommands.py @@ -37,9 +37,10 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): "\\l", "\\l", "List databases.", arg_type=RAW_QUERY, case_sensitive=True ) def list_databases(cur, **_): - _databases = cur.schemas().fetchall() - if _databases: - headers = [x[0] for x in _databases] - return [(None, _databases, headers, "")] - else: - return [(None, None, None, "")] + databases = cur.schemas().fetchall() + if databases: + headers = [ + field.title().removeprefix("Table_") for field in databases[0].__fields__ + ] + return [(None, databases, headers, "")] + return [(None, None, None, "")] diff --git a/dbsqlcli/sqlexecute.py b/dbsqlcli/sqlexecute.py index bf11ece..477985a 100644 --- a/dbsqlcli/sqlexecute.py +++ b/dbsqlcli/sqlexecute.py @@ -38,18 +38,32 @@ def read(self, hostname: str) -> Optional[OAuthToken]: class SQLExecute(object): DATABASES_QUERY = "SHOW DATABASES" - def __init__(self, hostname, http_path, access_token, database, auth_type=None): + def __init__( + self, hostname, http_path, access_token, database="default", auth_type=None + ): self.hostname = hostname self.http_path = http_path self.access_token = access_token - self.database = database or "default" self.auth_type = auth_type - - self.connect(database=self.database) + self._set_catalog_database(database) + self.connect() + + def _set_catalog_database(self, database): + """Sets the catalog and database name if a single dot is supplied""" + if database.count(".") == 1: + component = database.split(".") + self.catalog = component[0] + self.database = component[1] + else: + self.catalog = "hive_metastore" + self.database = database def connect(self, database=None): self.close_connection() + if database: + self._set_catalog_database(database) + oauth_params = {} if self.auth_type == AuthType.DATABRICKS_OAUTH.value: oauth_params = { @@ -63,20 +77,14 @@ def connect(self, database=None): server_hostname=self.hostname, http_path=self.http_path, access_token=self.access_token, - schema=database, + catalog=self.catalog, + schema=self.database, _user_agent_entry=USER_AGENT_STRING, **oauth_params, ) - self.database = database or self.database - self.conn = conn - def reconnect(self): - - self.close_connection() - self.connect(database=self.database) - def close_connection(self): """Close any open connection and remove the `conn` attribute""" @@ -138,7 +146,7 @@ def run(self, statement): f"SQL Gateway was timed out. Attempting to reconnect. Attempt {attempts+1}. Error: {e}" ) attempts += 1 - self.reconnect() + self.connect() def get_result(self, cursor): """Get the current result's data from the cursor.""" diff --git a/pyproject.toml b/pyproject.toml index caf9799..04a9bd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-cli" -version = "0.1.x" +version = "0.1.5" description = "A DBCLI client for Databricks SQL" authors = ["Databricks SQL CLI Maintainers "] packages = [{include = "dbsqlcli"}]