From 9bf2327b9ed59015182878b775654892a9306618 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 8 Jun 2023 15:58:25 +0300 Subject: [PATCH 1/8] [PECO-728] Add OAuth support Signed-off-by: Levko Kravets --- .../auth/DatabricksOAuth/AuthorizationCode.ts | 172 ++++++++++++++++++ .../auth/DatabricksOAuth/OAuthManager.ts | 102 +++++++++++ .../auth/DatabricksOAuth/OAuthPersistence.ts | 7 + .../auth/DatabricksOAuth/OAuthToken.ts | 38 ++++ lib/connection/auth/DatabricksOAuth/index.ts | 74 ++++++++ .../auth/PlainHttpAuthentication.ts | 14 +- .../auth/helpers/SaslPackageFactory.ts | 17 -- package-lock.json | 135 ++++++++++++-- package.json | 2 + 9 files changed, 520 insertions(+), 41 deletions(-) create mode 100644 lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts create mode 100644 lib/connection/auth/DatabricksOAuth/OAuthManager.ts create mode 100644 lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts create mode 100644 lib/connection/auth/DatabricksOAuth/OAuthToken.ts create mode 100644 lib/connection/auth/DatabricksOAuth/index.ts delete mode 100644 lib/connection/auth/helpers/SaslPackageFactory.ts diff --git a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts new file mode 100644 index 00000000..862a2162 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts @@ -0,0 +1,172 @@ +import http, { Server, IncomingMessage, ServerResponse } from 'http'; +import { BaseClient, generators } from 'openid-client'; +import open from 'open'; +import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger'; + +export interface AuthorizationCodeOptions { + client: BaseClient; + ports: Array; + logger?: IDBSQLLogger; +} + +const scopeDelimiter = ' '; + +async function startServer( + host: string, + port: number, + requestHandler: (req: IncomingMessage, res: ServerResponse) => void, +): Promise { + const server = http.createServer(requestHandler); + + return new Promise((resolve, reject) => { + const errorListener = (error: Error) => { + server.off('error', errorListener); + reject(error); + }; + + server.on('error', errorListener); + server.listen(port, host, () => { + server.off('error', errorListener); + resolve(server); + }); + }); +} + +async function stopServer(server: Server): Promise { + if (!server.listening) { + return; + } + + return new Promise((resolve, reject) => { + const errorListener = (error: Error) => { + server.off('error', errorListener); + reject(error); + }; + + server.on('error', errorListener); + server.close(() => { + server.off('error', errorListener); + resolve(); + }); + }); +} + +export interface AuthorizationCodeFetchResult { + code: string; + verifier: string; + redirectUri: string; +} + +export default class AuthorizationCode { + private readonly client: BaseClient; + + private readonly host: string = 'localhost'; + + private readonly ports: Array; + + private readonly logger?: IDBSQLLogger; + + constructor(options: AuthorizationCodeOptions) { + this.client = options.client; + this.ports = options.ports; + this.logger = options.logger; + } + + public async fetch(scopes: Array): Promise { + const verifierString = generators.codeVerifier(32); + const challengeString = generators.codeChallenge(verifierString); + const state = generators.state(16); + + let code: string | undefined; + + const server = await this.startServer((req, res) => { + const params = this.client.callbackParams(req); + if (params.state === state) { + code = params.code; + res.writeHead(200); + res.end(this.renderCallbackResponse()); + server.stop(); + } else { + res.writeHead(404); + res.end(); + } + }); + + const redirectUri = `http://${server.host}:${server.port}/`; + const authUrl = this.client.authorizationUrl({ + response_type: 'code', + response_mode: 'query', + scope: scopes.join(scopeDelimiter), + code_challenge: challengeString, + code_challenge_method: 'S256', + state, + redirect_uri: redirectUri, + }); + + await open(authUrl); + await server.stopped(); + + if (!code) { + throw new Error(`No path parameters were returned to the callback at ${redirectUri}`); + } + + return { code, verifier: verifierString, redirectUri }; + } + + private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) { + for (const port of this.ports) { + const host = this.host; // eslint-disable-line prefer-destructuring + try { + const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop + this.logger?.log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`); + + let resolveStopped: () => void; + let rejectStopped: (reason?: any) => void; + const stoppedPromise = new Promise((resolve, reject) => { + resolveStopped = resolve; + rejectStopped = reject; + }); + + return { + host, + port, + server, + stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped), + stopped: () => stoppedPromise, + }; + } catch (error) { + // if port already in use - try another one, otherwise re-throw an exception + if (error instanceof Error && 'code' in error && error.code === 'EADDRINUSE') { + this.logger?.log(LogLevel.debug, `Failed to start server at ${host}:${port}: ${error.code}`); + } else { + throw error; + } + } + } + + throw new Error('Failed to start server: all ports are in use'); + } + + private renderCallbackResponse(): string { + const applicationName = 'Databricks Sql Connector'; + + return ` + + Close this Tab + + + +

Please close this tab.

+

+ The ${applicationName} received a response. You may close this tab. +

+ +`; + } +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts new file mode 100644 index 00000000..e0b7ff66 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts @@ -0,0 +1,102 @@ +import { Issuer, BaseClient } from 'openid-client'; +import HiveDriverError from '../../../errors/HiveDriverError'; +import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger'; +import OAuthToken from './OAuthToken'; +import AuthorizationCode from './AuthorizationCode'; + +const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server'; + +export interface OAuthManagerOptions { + host: string; + callbackPorts: Array; + clientId: string; + logger?: IDBSQLLogger; +} + +export default class OAuthManager { + private readonly options: OAuthManagerOptions; + + private readonly logger?: IDBSQLLogger; + + private issuer?: Issuer; + + private client?: BaseClient; + + constructor(options: OAuthManagerOptions) { + this.options = options; + this.logger = options.logger; + } + + private async getClient(): Promise { + if (!this.issuer) { + const { host } = this.options; + const schema = host.startsWith('https://') ? '' : 'https://'; + const trailingSlash = host.endsWith('/') ? '' : '/'; + this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`); + } + + if (!this.client) { + this.client = new this.issuer.Client({ + client_id: this.options.clientId, + token_endpoint_auth_method: 'none', + }); + } + + return this.client; + } + + public async refreshAccessToken(token: OAuthToken): Promise { + try { + if (!token.hasExpired) { + // The access token is fine. Just return it. + return token; + } + } catch (error) { + this.logger?.log(LogLevel.error, `${error}`); + throw error; + } + + if (!token.refreshToken) { + const message = `OAuth access token expired on ${token.expirationTime}.`; + this.logger?.log(LogLevel.error, message); + throw new HiveDriverError(message); + } + + // Try to refresh using the refresh token + this.logger?.log( + LogLevel.debug, + `Attempting to refresh OAuth access token that expired on ${token.expirationTime}`, + ); + + const client = await this.getClient(); + const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken); + if (!accessToken || !refreshToken) { + throw new Error('Failed to refresh token: invalid response'); + } + return new OAuthToken(accessToken, refreshToken); + } + + public async getToken(scopes: Array): Promise { + const client = await this.getClient(); + const authCode = new AuthorizationCode({ + client, + ports: this.options.callbackPorts, + logger: this.logger, + }); + + const { code, verifier, redirectUri } = await authCode.fetch(scopes); + + const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({ + grant_type: 'authorization_code', + code, + code_verifier: verifier, + redirect_uri: redirectUri, + }); + + if (!accessToken) { + throw new Error('Failed to fetch access token'); + } + + return new OAuthToken(accessToken, refreshToken); + } +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts new file mode 100644 index 00000000..c60a9f2f --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts @@ -0,0 +1,7 @@ +import OAuthToken from './OAuthToken'; + +export default interface OAuthPersistence { + persist(host: string, token: OAuthToken): Promise; + + read(host: string): Promise; +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts new file mode 100644 index 00000000..2c1fee53 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts @@ -0,0 +1,38 @@ +export default class OAuthToken { + private readonly _accessToken: string; + + private readonly _refreshToken?: string; + + private _expirationTime?: number; + + constructor(accessToken: string, refreshToken?: string) { + this._accessToken = accessToken; + this._refreshToken = refreshToken; + } + + get accessToken(): string { + return this._accessToken; + } + + get refreshToken(): string | undefined { + return this._refreshToken; + } + + get expirationTime(): number { + if (this._expirationTime === undefined) { + const accessTokenPayload = Buffer.from(this._accessToken.split('.')[1], 'base64').toString('utf8'); + const decoded = JSON.parse(accessTokenPayload); + this._expirationTime = Number(decoded.exp); + } + return this._expirationTime; + } + + get hasExpired(): boolean { + // This token has already been verified, and we are just parsing it. + // If it has been tampered with, it will be rejected on the server side. + // This avoids having to fetch the public key from the issuer and perform + // an unnecessary signature verification. + const now = Math.floor(Date.now() / 1000); // convert it to seconds + return this.expirationTime <= now; + } +} diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts new file mode 100644 index 00000000..c56e6848 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -0,0 +1,74 @@ +import IAuthentication from '../../contracts/IAuthentication'; +import ITransport from '../../contracts/ITransport'; +import IDBSQLLogger from '../../../contracts/IDBSQLLogger'; +import { AuthOptions } from '../../types/AuthOptions'; +import OAuthPersistence from './OAuthPersistence'; +import OAuthManager from './OAuthManager'; + +interface DatabricksOAuthOptions extends AuthOptions { + host: string; + redirectPorts?: Array; + clientId?: string; + scopes?: Array; + logger?: IDBSQLLogger; + persistence?: OAuthPersistence; + headers?: object; +} + +const defaultOAuthOptions = { + clientId: 'databricks-sql-python', + redirectPorts: [8020, 8021, 8022, 8023, 8024, 8025], + scopes: ['sql', 'offline_access'], +} satisfies Partial; + +export default class DatabricksOAuth implements IAuthentication { + private readonly host: string; + + private readonly redirectPorts: Array; + + private readonly clientId: string; + + private readonly scopes: Array; + + private readonly logger?: IDBSQLLogger; + + private readonly persistence?: OAuthPersistence; + + private readonly headers?: object; + + private readonly manager: OAuthManager; + + constructor(options: DatabricksOAuthOptions) { + this.host = options.host; + this.redirectPorts = options.redirectPorts || defaultOAuthOptions.redirectPorts; + this.clientId = options.clientId || defaultOAuthOptions.clientId; + this.scopes = options.scopes || defaultOAuthOptions.scopes; + this.logger = options.logger; + this.persistence = options.persistence; + this.headers = options.headers; + + this.manager = new OAuthManager({ + host: this.host, + callbackPorts: this.redirectPorts, + clientId: this.clientId, + logger: this.logger, + }); + } + + async authenticate(transport: ITransport): Promise { + let token = await this.persistence?.read(this.host); + if (!token) { + token = await this.manager.getToken(this.scopes); + } + + token = await this.manager.refreshAccessToken(token); + await this.persistence?.persist(this.host, token); + + transport.setOptions('headers', { + ...this.headers, + Authorization: `Bearer ${token.accessToken}`, + }); + + return transport; + } +} diff --git a/lib/connection/auth/PlainHttpAuthentication.ts b/lib/connection/auth/PlainHttpAuthentication.ts index 1b1a1960..fd755cd0 100644 --- a/lib/connection/auth/PlainHttpAuthentication.ts +++ b/lib/connection/auth/PlainHttpAuthentication.ts @@ -2,16 +2,16 @@ import IAuthentication from '../contracts/IAuthentication'; import ITransport from '../contracts/ITransport'; import { AuthOptions } from '../types/AuthOptions'; -type HttpAuthOptions = AuthOptions & { +interface HttpAuthOptions extends AuthOptions { headers?: object; -}; +} export default class PlainHttpAuthentication implements IAuthentication { - private username: string; + private readonly username: string; - private password: string; + private readonly password: string; - private headers: object; + private readonly headers: object; constructor(options: HttpAuthOptions) { this.username = options?.username || 'anonymous'; @@ -19,13 +19,13 @@ export default class PlainHttpAuthentication implements IAuthentication { this.headers = options?.headers || {}; } - authenticate(transport: ITransport): Promise { + async authenticate(transport: ITransport): Promise { transport.setOptions('headers', { ...this.headers, Authorization: this.getToken(this.username, this.password), }); - return Promise.resolve(transport); + return transport; } private getToken(username: string, password: string): string { diff --git a/lib/connection/auth/helpers/SaslPackageFactory.ts b/lib/connection/auth/helpers/SaslPackageFactory.ts deleted file mode 100644 index 4d6c3bc1..00000000 --- a/lib/connection/auth/helpers/SaslPackageFactory.ts +++ /dev/null @@ -1,17 +0,0 @@ -export enum StatusCode { - START = 1, - OK = 2, - BAD = 3, - ERROR = 4, - COMPLETE = 5, -} - -export class SaslPackageFactory { - static create(status: StatusCode, body: Buffer): Buffer { - const bodyLength = Buffer.alloc(4); - - bodyLength.writeUInt32BE(body.length, 0); - - return Buffer.concat([Buffer.from([status]), bodyLength, body]); - } -} diff --git a/package-lock.json b/package-lock.json index 96efb05c..37c437f7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,6 +13,8 @@ "apache-arrow": "^10.0.1", "commander": "^9.3.0", "node-int64": "^0.4.0", + "open": "^8.4.2", + "openid-client": "^5.4.2", "patch-package": "^7.0.0", "thrift": "^0.16.0", "uuid": "^9.0.0", @@ -1868,6 +1870,14 @@ "node": ">=8" } }, + "node_modules/define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "engines": { + "node": ">=8" + } + }, "node_modules/define-properties": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.4.tgz", @@ -3534,6 +3544,14 @@ "node": ">=8" } }, + "node_modules/jose": { + "version": "4.14.4", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.14.4.tgz", + "integrity": "sha512-j8GhLiKmUAh+dsFXlX1aJCbt5KMibuKb+d7j1JaOJG6s2UjX1PQlW+OKB/sD4a/5ZYF4RcmYmLSndOoU3Lt/3g==", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -3758,7 +3776,6 @@ "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, "dependencies": { "yallist": "^4.0.0" }, @@ -4160,6 +4177,14 @@ "node": ">=0.10.0" } }, + "node_modules/object-hash": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", + "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==", + "engines": { + "node": ">= 6" + } + }, "node_modules/object-inspect": { "version": "1.12.2", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.2.tgz", @@ -4257,6 +4282,14 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/oidc-token-hash": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", + "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==", + "engines": { + "node": "^10.13.0 || >=12.0.0" + } + }, "node_modules/once": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", @@ -4274,20 +4307,35 @@ } }, "node_modules/open": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", - "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", "dependencies": { - "is-docker": "^2.0.0", - "is-wsl": "^2.1.1" + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" }, "engines": { - "node": ">=8" + "node": ">=12" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/openid-client": { + "version": "5.4.2", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.4.2.tgz", + "integrity": "sha512-lIhsdPvJ2RneBm3nGBBhQchpe3Uka//xf7WPHTIglery8gnckvW7Bd9IaQzekzXJvWthCMyi/xVEyGW0RFPytw==", + "dependencies": { + "jose": "^4.14.1", + "lru-cache": "^6.0.0", + "object-hash": "^2.2.0", + "oidc-token-hash": "^5.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/optionator": { "version": "0.9.1", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", @@ -4430,6 +4478,21 @@ "npm": ">5" } }, + "node_modules/patch-package/node_modules/open": { + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", + "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "dependencies": { + "is-docker": "^2.0.0", + "is-wsl": "^2.1.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/patch-package/node_modules/rimraf": { "version": "2.7.1", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", @@ -5640,8 +5703,7 @@ "node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" }, "node_modules/yaml": { "version": "2.2.2", @@ -7104,6 +7166,11 @@ "strip-bom": "^4.0.0" } }, + "define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==" + }, "define-properties": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.4.tgz", @@ -8323,6 +8390,11 @@ "istanbul-lib-report": "^3.0.0" } }, + "jose": { + "version": "4.14.4", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.14.4.tgz", + "integrity": "sha512-j8GhLiKmUAh+dsFXlX1aJCbt5KMibuKb+d7j1JaOJG6s2UjX1PQlW+OKB/sD4a/5ZYF4RcmYmLSndOoU3Lt/3g==" + }, "js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -8506,7 +8578,6 @@ "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, "requires": { "yallist": "^4.0.0" } @@ -8825,6 +8896,11 @@ "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", "dev": true }, + "object-hash": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", + "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==" + }, "object-inspect": { "version": "1.12.2", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.2.tgz", @@ -8892,6 +8968,11 @@ "es-abstract": "^1.19.1" } }, + "oidc-token-hash": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", + "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==" + }, "once": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", @@ -8909,12 +8990,24 @@ } }, "open": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", - "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", "requires": { - "is-docker": "^2.0.0", - "is-wsl": "^2.1.1" + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" + } + }, + "openid-client": { + "version": "5.4.2", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.4.2.tgz", + "integrity": "sha512-lIhsdPvJ2RneBm3nGBBhQchpe3Uka//xf7WPHTIglery8gnckvW7Bd9IaQzekzXJvWthCMyi/xVEyGW0RFPytw==", + "requires": { + "jose": "^4.14.1", + "lru-cache": "^6.0.0", + "object-hash": "^2.2.0", + "oidc-token-hash": "^5.0.3" } }, "optionator": { @@ -9019,6 +9112,15 @@ "yaml": "^2.2.2" }, "dependencies": { + "open": { + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", + "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "requires": { + "is-docker": "^2.0.0", + "is-wsl": "^2.1.1" + } + }, "rimraf": { "version": "2.7.1", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", @@ -9899,8 +10001,7 @@ "yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" }, "yaml": { "version": "2.2.2", diff --git a/package.json b/package.json index 862c872f..92ac273f 100644 --- a/package.json +++ b/package.json @@ -74,6 +74,8 @@ "apache-arrow": "^10.0.1", "commander": "^9.3.0", "node-int64": "^0.4.0", + "open": "^8.4.2", + "openid-client": "^5.4.2", "patch-package": "^7.0.0", "thrift": "^0.16.0", "uuid": "^9.0.0", From f96e636a928a94e7696729816cef4d3d33625b93 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 8 Jun 2023 17:50:52 +0300 Subject: [PATCH 2/8] Cleanup DBSQLClient code; remove redundant and no longer needed NoSaslAuthentication Signed-off-by: Levko Kravets --- lib/DBSQLClient.ts | 9 +++------ lib/connection/auth/NoSaslAuthentication.ts | 11 ----------- lib/index.ts | 2 -- .../auth/NoSaslAuthentication.test.js | 18 ------------------ 4 files changed, 3 insertions(+), 37 deletions(-) delete mode 100644 lib/connection/auth/NoSaslAuthentication.ts delete mode 100644 tests/unit/connection/auth/NoSaslAuthentication.test.js diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 04b9d1a8..8ca62b07 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -11,13 +11,13 @@ import IDBSQLSession from './contracts/IDBSQLSession'; import IThriftConnection from './connection/contracts/IThriftConnection'; import IConnectionProvider from './connection/contracts/IConnectionProvider'; import IAuthentication from './connection/contracts/IAuthentication'; -import NoSaslAuthentication from './connection/auth/NoSaslAuthentication'; import HttpConnection from './connection/connections/HttpConnection'; import IConnectionOptions from './connection/contracts/IConnectionOptions'; import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; +import DatabricksOAuth from './connection/auth/DatabricksOAuth'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; @@ -48,8 +48,6 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { private connectionProvider: IConnectionProvider; - private authProvider: IAuthentication; - private readonly logger: IDBSQLLogger; private readonly thrift = thrift; @@ -57,7 +55,6 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { constructor(options?: ClientOptions) { super(); this.connectionProvider = new HttpConnection(); - this.authProvider = new NoSaslAuthentication(); this.logger = options?.logger || new DBSQLLogger(); this.client = null; this.connection = null; @@ -87,7 +84,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { * const session = client.connect({host, path, token}); */ public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { - this.authProvider = + authProvider = authProvider || new PlainHttpAuthentication({ username: 'token', @@ -97,7 +94,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { }, }); - this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), this.authProvider); + this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider); this.client = this.thrift.createClient(TCLIService, this.connection.getConnection()); diff --git a/lib/connection/auth/NoSaslAuthentication.ts b/lib/connection/auth/NoSaslAuthentication.ts deleted file mode 100644 index c8d3795d..00000000 --- a/lib/connection/auth/NoSaslAuthentication.ts +++ /dev/null @@ -1,11 +0,0 @@ -import thrift from 'thrift'; -import IAuthentication from '../contracts/IAuthentication'; -import ITransport from '../contracts/ITransport'; - -export default class NoSaslAuthentication implements IAuthentication { - authenticate(transport: ITransport): Promise { - transport.setOptions('transport', thrift.TBufferedTransport); - - return Promise.resolve(transport); - } -} diff --git a/lib/index.ts b/lib/index.ts index d24345f8..bc12120e 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -4,14 +4,12 @@ import TCLIService_types from '../thrift/TCLIService_types'; import DBSQLClient from './DBSQLClient'; import DBSQLSession from './DBSQLSession'; import DBSQLLogger from './DBSQLLogger'; -import NoSaslAuthentication from './connection/auth/NoSaslAuthentication'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import HttpConnection from './connection/connections/HttpConnection'; import { formatProgress } from './utils'; import { LogLevel } from './contracts/IDBSQLLogger'; export const auth = { - NoSaslAuthentication, PlainHttpAuthentication, }; diff --git a/tests/unit/connection/auth/NoSaslAuthentication.test.js b/tests/unit/connection/auth/NoSaslAuthentication.test.js deleted file mode 100644 index 631a62ca..00000000 --- a/tests/unit/connection/auth/NoSaslAuthentication.test.js +++ /dev/null @@ -1,18 +0,0 @@ -const { expect } = require('chai'); -const thrift = require('thrift'); -const NoSaslAuthentication = require('../../../../dist/connection/auth/NoSaslAuthentication').default; - -describe('NoSaslAuthentication', () => { - it('auth token must be set to header', () => { - const auth = new NoSaslAuthentication(); - const transportMock = { - setOptions(name, value) { - expect(name).to.be.equal('transport'); - expect(value).to.be.equal(thrift.TBufferedTransport); - }, - }; - return auth.authenticate(transportMock).then((transport) => { - expect(transport).to.be.eq(transportMock); - }); - }); -}); From 7c1574218d44ac4e429017b5102e5e38ce4cf0d5 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 20 Jun 2023 14:06:52 +0300 Subject: [PATCH 3/8] DBSQLClient: options for auth types Signed-off-by: Levko Kravets --- lib/DBSQLClient.ts | 51 ++++++++++++++++++++++++++++------- lib/contracts/IDBSQLClient.ts | 21 ++++++++++++--- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index a10592d7..e6ad8c16 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -17,7 +17,7 @@ import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; -// import DatabricksOAuth from './connection/auth/DatabricksOAuth'; +import DatabricksOAuth from './connection/auth/DatabricksOAuth'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; @@ -62,7 +62,21 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { } private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { - const { host, port, path, token, clientId, ...otherOptions } = options; + const { + host, + port, + path, + clientId, + authType, + // @ts-expect-error TS2339: Property 'token' does not exist on type 'ConnectionOptions' + token, + // @ts-expect-error TS2339: Property 'persistence' does not exist on type 'ConnectionOptions' + persistence, + // @ts-expect-error TS2339: Property 'provider' does not exist on type 'ConnectionOptions' + provider, + ...otherOptions + } = options; + return { host, port: port || 443, @@ -77,22 +91,41 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { }; } + private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication { + if (authProvider) { + return authProvider; + } + + switch (options.authType) { + case undefined: + case 'access-token': + return new PlainHttpAuthentication({ + username: 'token', + password: options.token, + }); + case 'databricks-oauth': + return new DatabricksOAuth({ + host: options.host, + logger: this.logger, + persistence: options.persistence, + }); + case 'custom': + return options.provider; + // no default + } + } + /** * Connects DBSQLClient to endpoint * @public * @param options - host, path, and token are required - * @param authProvider - Optional custom authentication provider + * @param authProvider - [DEPRECATED - use `authType: 'custom'] Optional custom authentication provider * @returns Session object that can be used to execute statements * @example * const session = client.connect({host, path, token}); */ public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { - authProvider = - authProvider || - new PlainHttpAuthentication({ - username: 'token', - password: options.token, - }); + authProvider = this.getAuthProvider(options, authProvider); this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider); diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index b486130c..26e7062a 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -1,18 +1,33 @@ import IDBSQLLogger from './IDBSQLLogger'; import IDBSQLSession from './IDBSQLSession'; +import IAuthentication from '../connection/contracts/IAuthentication'; +import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; export interface ClientOptions { logger?: IDBSQLLogger; } -export interface ConnectionOptions { +type AuthOptions = + | { + authType?: 'access-token'; + token: string; + } + | { + authType: 'databricks-oauth'; + persistence?: OAuthPersistence; + } + | { + authType: 'custom'; + provider: IAuthentication; + }; + +export type ConnectionOptions = { host: string; port?: number; path: string; - token: string; clientId?: string; socketTimeout?: number; -} +} & AuthOptions; export interface OpenSessionRequest { initialCatalog?: string; From bf1639c93f7f2588f7fdddf2d8ef60bd532fa077 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Jun 2023 10:19:41 +0300 Subject: [PATCH 4/8] Tests Signed-off-by: Levko Kravets --- .../auth/DatabricksOAuth/AuthorizationCode.ts | 6 +- .../DatabricksOAuth/AuthorizationCode.test.js | 247 ++++++++++++++++++ .../auth/DatabricksOAuth/OAuthManager.test.js | 233 +++++++++++++++++ .../auth/DatabricksOAuth/OAuthToken.test.js | 64 +++++ .../auth/DatabricksOAuth/index.test.js | 135 ++++++++++ .../connection/auth/DatabricksOAuth/utils.js | 20 ++ 6 files changed, 704 insertions(+), 1 deletion(-) create mode 100644 tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js create mode 100644 tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js create mode 100644 tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js create mode 100644 tests/unit/connection/auth/DatabricksOAuth/index.test.js create mode 100644 tests/unit/connection/auth/DatabricksOAuth/utils.js diff --git a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts index 862a2162..5129337a 100644 --- a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts +++ b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts @@ -72,6 +72,10 @@ export default class AuthorizationCode { this.logger = options.logger; } + private async openUrl(url: string) { + return open(url); + } + public async fetch(scopes: Array): Promise { const verifierString = generators.codeVerifier(32); const challengeString = generators.codeChallenge(verifierString); @@ -103,7 +107,7 @@ export default class AuthorizationCode { redirect_uri: redirectUri, }); - await open(authUrl); + await this.openUrl(authUrl); await server.stopped(); if (!code) { diff --git a/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js new file mode 100644 index 00000000..c165b3e0 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js @@ -0,0 +1,247 @@ +const { expect, AssertionError } = require('chai'); +const { EventEmitter } = require('events'); +const sinon = require('sinon'); +const http = require('http'); +const AuthorizationCode = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode').default; + +class HttpServerMock extends EventEmitter { + constructor() { + super(); + this.requestHandler = () => {}; + this.listening = false; + this.listenError = undefined; // error to emit on listen + this.closeError = undefined; // error to emit on close + } + + listen(port, host, callback) { + if (this.listenError) { + this.emit('error', this.listenError); + this.listenError = undefined; + } else if (port < 1000) { + const error = new Error(`Address ${host}:${port} is already in use`); + error.code = 'EADDRINUSE'; + this.emit('error', error); + } else { + this.listening = true; + callback(); + } + } + + close(callback) { + this.requestHandler = () => {}; + this.listening = false; + if (this.closeError) { + this.emit('error', this.closeError); + this.closeError = undefined; + } else { + callback(); + } + } +} + +class OAuthClientMock { + constructor() { + this.code = 'test_authorization_code'; + this.redirectUri = undefined; + } + + authorizationUrl(params) { + this.redirectUri = params.redirect_uri; + return JSON.stringify({ + state: params.state, + code: this.code, + }); + } + + callbackParams(req) { + return req.params; + } +} + +function prepareTestInstances(options) { + const httpServer = new HttpServerMock(); + + const oauthClient = new OAuthClientMock(); + + const authCode = new AuthorizationCode({ + client: oauthClient, + ...options, + }); + + sinon.stub(http, 'createServer').callsFake((requestHandler) => { + httpServer.requestHandler = requestHandler; + return httpServer; + }); + + sinon.stub(authCode, 'openUrl').callsFake((url) => { + const params = JSON.parse(url); + httpServer.requestHandler( + { params }, + { + writeHead: () => {}, + end: () => {}, + }, + ); + }); + + function reloadUrl() { + setTimeout(() => { + const args = authCode.openUrl.firstCall.args; + authCode.openUrl(...args); + }, 10); + } + + return { httpServer, oauthClient, authCode, reloadUrl }; +} + +describe('AuthorizationCode', () => { + afterEach(() => { + http.createServer.restore?.(); + }); + + it('should fetch authorization code', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [80, 8000], + logger: { log: () => {} }, + }); + + const result = await authCode.fetch([]); + expect(http.createServer.callCount).to.be.equal(2); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(result.code).to.be.equal(oauthClient.code); + expect(result.verifier).to.not.be.empty; + expect(result.redirectUri).to.be.equal(oauthClient.redirectUri); + }); + + it('should throw error if cannot start server on any port', async () => { + const { authCode } = prepareTestInstances({ + ports: [80, 443], + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(2); + expect(authCode.openUrl.callCount).to.be.equal(0); + + expect(error.message).to.contain('all ports are in use'); + } + }); + + it('should re-throw unhandled server start errors', async () => { + const { authCode, httpServer } = prepareTestInstances({ + ports: [80], + }); + + const testError = new Error('Test'); + httpServer.listenError = testError; + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(0); + + expect(error).to.be.equal(testError); + } + }); + + it('should re-throw unhandled server stop errors', async () => { + const { authCode, httpServer } = prepareTestInstances({ + ports: [8000], + }); + + const testError = new Error('Test'); + httpServer.closeError = testError; + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error).to.be.equal(testError); + } + }); + + it('should throw an error if no code was returned', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [8000], + }); + + sinon.stub(oauthClient, 'callbackParams').callsFake((req) => { + // Omit authorization code from params + const { code, ...otherParams } = req.params; + return otherParams; + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error.message).to.contain('No path parameters were returned to the callback'); + } + }); + + it('should serve 404 for unrecognized requests', async () => { + const { authCode, oauthClient, reloadUrl } = prepareTestInstances({ + ports: [8000], + }); + + sinon + .stub(oauthClient, 'callbackParams') + .onFirstCall() + .callsFake(() => { + // Repeat the same request after currently processed one. + // We won't modify response on subsequent requests so OAuth routine can complete + reloadUrl(); + // Return no params so request cannot be recognized + return {}; + }) + .callThrough(); + + await authCode.fetch([]); + + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(2); + }); + + it('should not attempt to stop server if not running', async () => { + const { authCode, oauthClient, httpServer } = prepareTestInstances({ + ports: [8000], + logger: { log: () => {} }, + }); + + const promise = authCode.fetch([]); + + httpServer.listening = false; + httpServer.closeError = new Error('Test'); + + const result = await promise; + // We set up server to throw an error on close. If nothing happened - it means + // that `authCode` never tried to stop it + expect(result.code).to.be.equal(oauthClient.code); + + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js new file mode 100644 index 00000000..5630a61b --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js @@ -0,0 +1,233 @@ +const { expect, AssertionError } = require('chai'); +const sinon = require('sinon'); +const OAuthManager = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager').default; +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; +const AuthorizationCodeModule = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode'); + +const { createValidAccessToken, createExpiredAccessToken } = require('./utils'); + +class AuthorizationCodeMock { + constructor() { + this.fetchResult = undefined; + } + + async fetch() { + return this.fetchResult; + } +} + +AuthorizationCodeMock.validCode = { + code: 'auth_code', + verifier: 'verifier_string', + redirectUri: 'http://localhost:8000', +}; + +class OAuthClientMock { + constructor() { + this.grantError = undefined; + this.refreshError = undefined; + + this.accessToken = undefined; + this.refreshToken = undefined; + this.recreateTokens(); + } + + recreateTokens() { + const suffix = Math.random().toString(36).substring(2); + this.accessToken = `${createValidAccessToken()}.${suffix}`; + this.refreshToken = `refresh.${suffix}`; + } + + async grant(params) { + if (this.grantError) { + const error = this.grantError; + this.grantError = undefined; + throw error; + } + + expect(params.grant_type).to.be.equal('authorization_code'); + expect(params.code).to.be.equal(AuthorizationCodeMock.validCode.code); + expect(params.code_verifier).to.be.equal(AuthorizationCodeMock.validCode.verifier); + expect(params.redirect_uri).to.be.equal(AuthorizationCodeMock.validCode.redirectUri); + + return { + access_token: this.accessToken, + refresh_token: this.refreshToken, + }; + } + + async refresh(refreshToken) { + if (this.refreshError) { + const error = this.refreshError; + this.refreshError = undefined; + throw error; + } + + expect(refreshToken).to.be.equal(this.refreshToken); + + this.recreateTokens(); + return { + access_token: this.accessToken, + refresh_token: this.refreshToken, + }; + } +} + +function prepareTestInstances(options) { + const oauthClient = new OAuthClientMock(); + sinon.stub(oauthClient, 'grant').callThrough(); + sinon.stub(oauthClient, 'refresh').callThrough(); + + const oauthManager = new OAuthManager({ ...options }); + sinon.stub(oauthManager, 'getClient').returns(Promise.resolve(oauthClient)); + + const authCode = new AuthorizationCodeMock(); + authCode.fetchResult = { ...AuthorizationCodeMock.validCode }; + + sinon.stub(AuthorizationCodeModule, 'default').returns(authCode); + + return { oauthClient, oauthManager, authCode }; +} + +describe('OAuthManager', () => { + afterEach(() => { + AuthorizationCodeModule.default.restore?.(); + }); + + it('should get access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + const token = await oauthManager.getToken([]); + expect(oauthClient.grant.called).to.be.true; + expect(token).to.be.instanceOf(OAuthToken); + expect(token.accessToken).to.be.equal(oauthClient.accessToken); + expect(token.refreshToken).to.be.equal(oauthClient.refreshToken); + }); + + it('should throw an error if cannot get access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + // Make it return empty tokens + oauthClient.accessToken = undefined; + oauthClient.refreshToken = undefined; + + try { + await oauthManager.getToken([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.grant.called).to.be.true; + expect(error.message).to.contain('Failed to fetch access token'); + } + }); + + it('should re-throw unhandled errors when getting access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + const testError = new Error('Test'); + oauthClient.grantError = testError; + + try { + await oauthManager.getToken([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.grant.called).to.be.true; + expect(error).to.be.equal(testError); + } + }); + + it('should not refresh valid token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + const token = new OAuthToken(createValidAccessToken(), oauthClient.refreshToken); + expect(token.hasExpired).to.be.false; + + const newToken = await oauthManager.refreshAccessToken(token); + expect(oauthClient.refresh.called).to.be.false; + expect(newToken).to.be.instanceOf(OAuthToken); + expect(newToken.accessToken).to.be.equal(token.accessToken); + expect(newToken.hasExpired).to.be.false; + }); + + it('should throw an error if no refresh token is available', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + try { + const token = new OAuthToken(createExpiredAccessToken()); + expect(token.hasExpired).to.be.true; + + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.false; + expect(error.message).to.contain('token expired'); + } + }); + + it('should throw an error on invalid response', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + oauthClient.refresh.restore(); + sinon.stub(oauthClient, 'refresh').returns({}); + + try { + const token = new OAuthToken(createExpiredAccessToken(), oauthClient.refreshToken); + expect(token.hasExpired).to.be.true; + + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.true; + expect(error.message).to.contain('invalid response'); + } + }); + + it('should throw an error for invalid token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + try { + const token = new OAuthToken('invalid_access_token', 'invalid_refresh_token'); + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.false; + // Random malformed string passed as access token will cause JSON parse errors + expect(error).to.be.instanceof(TypeError); + } + }); + + it('should refresh expired token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + const token = new OAuthToken(createExpiredAccessToken(), oauthClient.refreshToken); + expect(token.hasExpired).to.be.true; + + const newToken = await oauthManager.refreshAccessToken(token); + expect(oauthClient.refresh.called).to.be.true; + expect(newToken).to.be.instanceOf(OAuthToken); + expect(newToken.accessToken).to.be.not.equal(token.accessToken); + expect(newToken.hasExpired).to.be.false; + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js new file mode 100644 index 00000000..1ab4f486 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js @@ -0,0 +1,64 @@ +const { expect } = require('chai'); +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; + +const { createAccessToken } = require('./utils'); + +describe('OAuthToken', () => { + it('should be properly initialized', () => { + const accessToken = 'access'; + const refreshToken = 'refresh'; + + const token1 = new OAuthToken(accessToken); + expect(token1.accessToken).to.be.equal(accessToken); + + const token2 = new OAuthToken(accessToken, refreshToken); + expect(token2.accessToken).to.be.equal(accessToken); + expect(token2.refreshToken).to.be.equal(refreshToken); + }); + + it('should return valid expiration time', () => { + const expirationTime = Math.trunc(Date.now() / 1000); + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + // second attempt - to make sure it returns the same value + expect(token.expirationTime).to.be.equal(expirationTime); + }); + + it('should throw error if cannot get expiration time', () => { + expect(() => { + const token = new OAuthToken('without_payload'); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + + expect(() => { + const token = new OAuthToken('invalid.payload'); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + + expect(() => { + const payload = Buffer.from('qwerty', 'utf8').toString('base64'); + const token = new OAuthToken(`malformed.${payload}`); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + }); + + it('should test for expired token', () => { + const expirationTime = Math.trunc(Date.now() / 1000) - 1; + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + expect(token.hasExpired).to.be.true; + }); + + it('should test for valid token', () => { + const expirationTime = Math.trunc(Date.now() / 1000) + 1; + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + expect(token.hasExpired).to.be.false; + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/index.test.js b/tests/unit/connection/auth/DatabricksOAuth/index.test.js new file mode 100644 index 00000000..bdeccc9b --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/index.test.js @@ -0,0 +1,135 @@ +const { expect, AssertionError } = require('chai'); +const sinon = require('sinon'); +const DatabricksOAuth = require('../../../../../dist/connection/auth/DatabricksOAuth/index').default; +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; +const OAuthManagerModule = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager'); + +const { createValidAccessToken, createExpiredAccessToken } = require('./utils'); + +class OAuthManagerMock { + constructor() { + this.getTokenResult = new OAuthToken(createValidAccessToken()); + this.refreshTokenResult = new OAuthToken(createValidAccessToken()); + } + + async refreshAccessToken(token) { + return token.hasExpired ? this.refreshTokenResult : token; + } + + async getToken() { + return this.getTokenResult; + } +} + +class TransportMock { + constructor() { + this.headers = {}; + } + + updateHeaders(newHeaders) { + this.headers = { + ...this.headers, + ...newHeaders, + }; + } +} + +class OAuthPersistenceMock { + constructor() { + this.token = undefined; + + sinon.stub(this, 'persist').callThrough(); + sinon.stub(this, 'read').callThrough(); + } + + async persist(host, token) { + this.token = token; + } + + async read() { + return this.token; + } +} + +function prepareTestInstances(options) { + const oauthManager = new OAuthManagerMock(); + + sinon.stub(oauthManager, 'refreshAccessToken').callThrough(); + sinon.stub(oauthManager, 'getToken').callThrough(); + + sinon.stub(OAuthManagerModule, 'default').returns(oauthManager); + + const provider = new DatabricksOAuth({ ...options }); + + const transport = new TransportMock(); + sinon.stub(transport, 'updateHeaders').callThrough(); + + return { oauthManager, provider, transport }; +} + +describe('DatabricksOAuth', () => { + afterEach(() => { + OAuthManagerModule.default.restore?.(); + }); + + it('should get persisted token if available', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = new OAuthToken(createValidAccessToken()); + + const { provider, transport } = prepareTestInstances({ persistence }); + + await provider.authenticate(transport); + expect(persistence.read.called).to.be.true; + }); + + it('should get new token if storage not available', async () => { + const { oauthManager, provider, transport } = prepareTestInstances(); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + }); + + it('should get new token if persisted token not available, and store valid token', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = undefined; + const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(persistence.persist.called).to.be.true; + expect(persistence.token).to.be.equal(oauthManager.getTokenResult); + }); + + it('should refresh expired token and store new token', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = undefined; + + const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + oauthManager.getTokenResult = new OAuthToken(createExpiredAccessToken()); + oauthManager.refreshTokenResult = new OAuthToken(createValidAccessToken()); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(oauthManager.refreshAccessToken.called).to.be.true; + expect(oauthManager.refreshAccessToken.firstCall.firstArg).to.be.equal(oauthManager.getTokenResult); + expect(persistence.token).to.be.equal(oauthManager.refreshTokenResult); + expect(persistence.persist.called).to.be.true; + expect(persistence.token).to.be.equal(oauthManager.refreshTokenResult); + }); + + it('should configure transport using valid token', async () => { + const { oauthManager, provider, transport } = prepareTestInstances(); + + const initialHeaders = { + x: 'x', + y: 'y', + }; + + transport.headers = initialHeaders; + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(transport.updateHeaders.called).to.be.true; + expect(Object.keys(transport.headers)).to.deep.equal([...Object.keys(initialHeaders), 'Authorization']); + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/utils.js b/tests/unit/connection/auth/DatabricksOAuth/utils.js new file mode 100644 index 00000000..edffe7b1 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/utils.js @@ -0,0 +1,20 @@ +function createAccessToken(expirationTime) { + const payload = Buffer.from(JSON.stringify({ exp: expirationTime }), 'utf8').toString('base64'); + return `access.${payload}`; +} + +function createValidAccessToken() { + const expirationTime = Math.trunc(Date.now() / 1000) + 20000; + return createAccessToken(expirationTime); +} + +function createExpiredAccessToken() { + const expirationTime = Math.trunc(Date.now() / 1000) - 1000; + return createAccessToken(expirationTime); +} + +module.exports = { + createAccessToken, + createValidAccessToken, + createExpiredAccessToken, +}; From b894a0b2e70bd4564388e7219b237e64a0c0d487 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Jun 2023 11:33:19 +0300 Subject: [PATCH 5/8] Tests Signed-off-by: Levko Kravets --- tests/unit/DBSQLClient.test.js | 73 ++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 6722405e..5ffedd47 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -2,10 +2,10 @@ const { expect } = require('chai'); const sinon = require('sinon'); const DBSQLClient = require('../../dist/DBSQLClient').default; const DBSQLSession = require('../../dist/DBSQLSession').default; -const { - auth: { PlainHttpAuthentication }, - connections: { HttpConnection }, -} = require('../../'); + +const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAuthentication').default; +const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default; +const HttpConnection = require('../../dist/connection/connections/HttpConnection').default; const ConnectionProviderMock = (connection) => ({ connect(options, auth) { @@ -227,3 +227,68 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); }); + +describe('DBSQLClient.getAuthProvider', () => { + it('should use access token auth method', () => { + const client = new DBSQLClient(); + + const testAccessToken = 'token'; + const provider = client.getAuthProvider({ + authType: 'access-token', + token: testAccessToken, + }); + + expect(provider).to.be.instanceOf(PlainHttpAuthentication); + expect(provider.password).to.be.equal(testAccessToken); + }); + + it('should use access token auth method by default (compatibility)', () => { + const client = new DBSQLClient(); + + const testAccessToken = 'token'; + const provider = client.getAuthProvider({ + // note: no `authType` provided + token: testAccessToken, + }); + + expect(provider).to.be.instanceOf(PlainHttpAuthentication); + expect(provider.password).to.be.equal(testAccessToken); + }); + + it('should use Databricks OAuth method', () => { + const client = new DBSQLClient(); + + const provider = client.getAuthProvider({ + authType: 'databricks-oauth', + }); + + expect(provider).to.be.instanceOf(DatabricksOAuth); + }); + + it('should use custom auth method', () => { + const client = new DBSQLClient(); + + const customProvider = {}; + + const provider = client.getAuthProvider({ + authType: 'custom', + provider: customProvider, + }); + + expect(provider).to.be.equal(customProvider); + }); + + it('should use custom auth method (legacy way)', () => { + const client = new DBSQLClient(); + + const customProvider = {}; + + const provider = client.getAuthProvider( + // custom provider from second arg should be used no matter what's specified in config + { authType: 'access-token', token: 'token' }, + customProvider, + ); + + expect(provider).to.be.equal(customProvider); + }); +}); From 3fe269c0eb34125dd4f7477ec641db1b36bd244a Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Jun 2023 11:39:01 +0300 Subject: [PATCH 6/8] Fix: move comment to appropriate place Signed-off-by: Levko Kravets --- lib/connection/auth/DatabricksOAuth/OAuthToken.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts index 2c1fee53..e48b0d05 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts @@ -19,6 +19,10 @@ export default class OAuthToken { } get expirationTime(): number { + // This token has already been verified, and we are just parsing it. + // If it has been tampered with, it will be rejected on the server side. + // This avoids having to fetch the public key from the issuer and perform + // an unnecessary signature verification. if (this._expirationTime === undefined) { const accessTokenPayload = Buffer.from(this._accessToken.split('.')[1], 'base64').toString('utf8'); const decoded = JSON.parse(accessTokenPayload); @@ -28,10 +32,6 @@ export default class OAuthToken { } get hasExpired(): boolean { - // This token has already been verified, and we are just parsing it. - // If it has been tampered with, it will be rejected on the server side. - // This avoids having to fetch the public key from the issuer and perform - // an unnecessary signature verification. const now = Math.floor(Date.now() / 1000); // convert it to seconds return this.expirationTime <= now; } From f6eae7c5da3114ee6bb4197d61a6c2258c3df4be Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Jun 2023 11:54:29 +0300 Subject: [PATCH 7/8] Improve tests Signed-off-by: Levko Kravets --- .../auth/DatabricksOAuth/OAuthManager.test.js | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js index 5630a61b..d7209b83 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js @@ -1,5 +1,6 @@ const { expect, AssertionError } = require('chai'); const sinon = require('sinon'); +const { Issuer } = require('openid-client'); const OAuthManager = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager').default; const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; const AuthorizationCodeModule = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode'); @@ -78,8 +79,18 @@ function prepareTestInstances(options) { sinon.stub(oauthClient, 'grant').callThrough(); sinon.stub(oauthClient, 'refresh').callThrough(); - const oauthManager = new OAuthManager({ ...options }); - sinon.stub(oauthManager, 'getClient').returns(Promise.resolve(oauthClient)); + sinon.stub(Issuer, 'discover').returns( + Promise.resolve({ + Client: function () { + return oauthClient; + }, + }), + ); + + const oauthManager = new OAuthManager({ + host: 'https://example.com', + ...options, + }); const authCode = new AuthorizationCodeMock(); authCode.fetchResult = { ...AuthorizationCodeMock.validCode }; @@ -92,6 +103,7 @@ function prepareTestInstances(options) { describe('OAuthManager', () => { afterEach(() => { AuthorizationCodeModule.default.restore?.(); + Issuer.discover.restore?.(); }); it('should get access token', async () => { @@ -221,7 +233,8 @@ describe('OAuthManager', () => { it('should refresh expired token', async () => { const { oauthManager, oauthClient } = prepareTestInstances(); - const token = new OAuthToken(createExpiredAccessToken(), oauthClient.refreshToken); + oauthClient.accessToken = createExpiredAccessToken(); + const token = await oauthManager.getToken([]); expect(token.hasExpired).to.be.true; const newToken = await oauthManager.refreshAccessToken(token); From 2edbb6600c498359e7906fd93a456e3a9847c532 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Jun 2023 22:15:51 +0300 Subject: [PATCH 8/8] Use proper client ID; improve callback handling Signed-off-by: Levko Kravets --- .../auth/DatabricksOAuth/AuthorizationCode.ts | 16 ++++++---- lib/connection/auth/DatabricksOAuth/index.ts | 4 +-- .../DatabricksOAuth/AuthorizationCode.test.js | 29 +++++++++++++++++++ 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts index 5129337a..490f51d3 100644 --- a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts +++ b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts @@ -1,5 +1,5 @@ -import http, { Server, IncomingMessage, ServerResponse } from 'http'; -import { BaseClient, generators } from 'openid-client'; +import http, { IncomingMessage, Server, ServerResponse } from 'http'; +import { BaseClient, CallbackParamsType, generators } from 'openid-client'; import open from 'open'; import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger'; @@ -81,12 +81,12 @@ export default class AuthorizationCode { const challengeString = generators.codeChallenge(verifierString); const state = generators.state(16); - let code: string | undefined; + let receivedParams: CallbackParamsType | undefined; const server = await this.startServer((req, res) => { const params = this.client.callbackParams(req); if (params.state === state) { - code = params.code; + receivedParams = params; res.writeHead(200); res.end(this.renderCallbackResponse()); server.stop(); @@ -110,11 +110,15 @@ export default class AuthorizationCode { await this.openUrl(authUrl); await server.stopped(); - if (!code) { + if (!receivedParams || !receivedParams.code) { + if (receivedParams?.error) { + const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`; + throw new Error(errorMessage); + } throw new Error(`No path parameters were returned to the callback at ${redirectUri}`); } - return { code, verifier: verifierString, redirectUri }; + return { code: receivedParams.code, verifier: verifierString, redirectUri }; } private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) { diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts index 5e2cf9ee..955bbfc6 100644 --- a/lib/connection/auth/DatabricksOAuth/index.ts +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -16,8 +16,8 @@ interface DatabricksOAuthOptions { } const defaultOAuthOptions = { - clientId: 'databricks-sql-python', - redirectPorts: [8020, 8021, 8022, 8023, 8024, 8025], + clientId: 'databricks-sql-connector', + redirectPorts: [8030], scopes: ['sql', 'offline_access'], } satisfies Partial; diff --git a/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js index c165b3e0..bc83e559 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js @@ -202,6 +202,35 @@ describe('AuthorizationCode', () => { } }); + it('should use error details from callback params', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [8000], + }); + + sinon.stub(oauthClient, 'callbackParams').callsFake((req) => { + // Omit authorization code from params + const { code, ...otherParams } = req.params; + return { + ...otherParams, + error: 'test_error', + error_description: 'Test error', + }; + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error.message).to.contain('Test error'); + } + }); + it('should serve 404 for unrecognized requests', async () => { const { authCode, oauthClient, reloadUrl } = prepareTestInstances({ ports: [8000],