diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 4e0a4e6d..e6ad8c16 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -17,6 +17,7 @@ import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; +import DatabricksOAuth from './connection/auth/DatabricksOAuth'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; @@ -61,7 +62,21 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { } private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { - const { host, port, path, token, clientId, ...otherOptions } = options; + const { + host, + port, + path, + clientId, + authType, + // @ts-expect-error TS2339: Property 'token' does not exist on type 'ConnectionOptions' + token, + // @ts-expect-error TS2339: Property 'persistence' does not exist on type 'ConnectionOptions' + persistence, + // @ts-expect-error TS2339: Property 'provider' does not exist on type 'ConnectionOptions' + provider, + ...otherOptions + } = options; + return { host, port: port || 443, @@ -76,22 +91,41 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { }; } + private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication { + if (authProvider) { + return authProvider; + } + + switch (options.authType) { + case undefined: + case 'access-token': + return new PlainHttpAuthentication({ + username: 'token', + password: options.token, + }); + case 'databricks-oauth': + return new DatabricksOAuth({ + host: options.host, + logger: this.logger, + persistence: options.persistence, + }); + case 'custom': + return options.provider; + // no default + } + } + /** * Connects DBSQLClient to endpoint * @public * @param options - host, path, and token are required - * @param authProvider - Optional custom authentication provider + * @param authProvider - [DEPRECATED - use `authType: 'custom'] Optional custom authentication provider * @returns Session object that can be used to execute statements * @example * const session = client.connect({host, path, token}); */ public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { - authProvider = - authProvider || - new PlainHttpAuthentication({ - username: 'token', - password: options.token, - }); + authProvider = this.getAuthProvider(options, authProvider); this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider); diff --git a/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts new file mode 100644 index 00000000..490f51d3 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts @@ -0,0 +1,180 @@ +import http, { IncomingMessage, Server, ServerResponse } from 'http'; +import { BaseClient, CallbackParamsType, generators } from 'openid-client'; +import open from 'open'; +import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger'; + +export interface AuthorizationCodeOptions { + client: BaseClient; + ports: Array; + logger?: IDBSQLLogger; +} + +const scopeDelimiter = ' '; + +async function startServer( + host: string, + port: number, + requestHandler: (req: IncomingMessage, res: ServerResponse) => void, +): Promise { + const server = http.createServer(requestHandler); + + return new Promise((resolve, reject) => { + const errorListener = (error: Error) => { + server.off('error', errorListener); + reject(error); + }; + + server.on('error', errorListener); + server.listen(port, host, () => { + server.off('error', errorListener); + resolve(server); + }); + }); +} + +async function stopServer(server: Server): Promise { + if (!server.listening) { + return; + } + + return new Promise((resolve, reject) => { + const errorListener = (error: Error) => { + server.off('error', errorListener); + reject(error); + }; + + server.on('error', errorListener); + server.close(() => { + server.off('error', errorListener); + resolve(); + }); + }); +} + +export interface AuthorizationCodeFetchResult { + code: string; + verifier: string; + redirectUri: string; +} + +export default class AuthorizationCode { + private readonly client: BaseClient; + + private readonly host: string = 'localhost'; + + private readonly ports: Array; + + private readonly logger?: IDBSQLLogger; + + constructor(options: AuthorizationCodeOptions) { + this.client = options.client; + this.ports = options.ports; + this.logger = options.logger; + } + + private async openUrl(url: string) { + return open(url); + } + + public async fetch(scopes: Array): Promise { + const verifierString = generators.codeVerifier(32); + const challengeString = generators.codeChallenge(verifierString); + const state = generators.state(16); + + let receivedParams: CallbackParamsType | undefined; + + const server = await this.startServer((req, res) => { + const params = this.client.callbackParams(req); + if (params.state === state) { + receivedParams = params; + res.writeHead(200); + res.end(this.renderCallbackResponse()); + server.stop(); + } else { + res.writeHead(404); + res.end(); + } + }); + + const redirectUri = `http://${server.host}:${server.port}/`; + const authUrl = this.client.authorizationUrl({ + response_type: 'code', + response_mode: 'query', + scope: scopes.join(scopeDelimiter), + code_challenge: challengeString, + code_challenge_method: 'S256', + state, + redirect_uri: redirectUri, + }); + + await this.openUrl(authUrl); + await server.stopped(); + + if (!receivedParams || !receivedParams.code) { + if (receivedParams?.error) { + const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`; + throw new Error(errorMessage); + } + throw new Error(`No path parameters were returned to the callback at ${redirectUri}`); + } + + return { code: receivedParams.code, verifier: verifierString, redirectUri }; + } + + private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) { + for (const port of this.ports) { + const host = this.host; // eslint-disable-line prefer-destructuring + try { + const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop + this.logger?.log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`); + + let resolveStopped: () => void; + let rejectStopped: (reason?: any) => void; + const stoppedPromise = new Promise((resolve, reject) => { + resolveStopped = resolve; + rejectStopped = reject; + }); + + return { + host, + port, + server, + stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped), + stopped: () => stoppedPromise, + }; + } catch (error) { + // if port already in use - try another one, otherwise re-throw an exception + if (error instanceof Error && 'code' in error && error.code === 'EADDRINUSE') { + this.logger?.log(LogLevel.debug, `Failed to start server at ${host}:${port}: ${error.code}`); + } else { + throw error; + } + } + } + + throw new Error('Failed to start server: all ports are in use'); + } + + private renderCallbackResponse(): string { + const applicationName = 'Databricks Sql Connector'; + + return ` + + Close this Tab + + + +

Please close this tab.

+

+ The ${applicationName} received a response. You may close this tab. +

+ +`; + } +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts new file mode 100644 index 00000000..e0b7ff66 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts @@ -0,0 +1,102 @@ +import { Issuer, BaseClient } from 'openid-client'; +import HiveDriverError from '../../../errors/HiveDriverError'; +import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger'; +import OAuthToken from './OAuthToken'; +import AuthorizationCode from './AuthorizationCode'; + +const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server'; + +export interface OAuthManagerOptions { + host: string; + callbackPorts: Array; + clientId: string; + logger?: IDBSQLLogger; +} + +export default class OAuthManager { + private readonly options: OAuthManagerOptions; + + private readonly logger?: IDBSQLLogger; + + private issuer?: Issuer; + + private client?: BaseClient; + + constructor(options: OAuthManagerOptions) { + this.options = options; + this.logger = options.logger; + } + + private async getClient(): Promise { + if (!this.issuer) { + const { host } = this.options; + const schema = host.startsWith('https://') ? '' : 'https://'; + const trailingSlash = host.endsWith('/') ? '' : '/'; + this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`); + } + + if (!this.client) { + this.client = new this.issuer.Client({ + client_id: this.options.clientId, + token_endpoint_auth_method: 'none', + }); + } + + return this.client; + } + + public async refreshAccessToken(token: OAuthToken): Promise { + try { + if (!token.hasExpired) { + // The access token is fine. Just return it. + return token; + } + } catch (error) { + this.logger?.log(LogLevel.error, `${error}`); + throw error; + } + + if (!token.refreshToken) { + const message = `OAuth access token expired on ${token.expirationTime}.`; + this.logger?.log(LogLevel.error, message); + throw new HiveDriverError(message); + } + + // Try to refresh using the refresh token + this.logger?.log( + LogLevel.debug, + `Attempting to refresh OAuth access token that expired on ${token.expirationTime}`, + ); + + const client = await this.getClient(); + const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken); + if (!accessToken || !refreshToken) { + throw new Error('Failed to refresh token: invalid response'); + } + return new OAuthToken(accessToken, refreshToken); + } + + public async getToken(scopes: Array): Promise { + const client = await this.getClient(); + const authCode = new AuthorizationCode({ + client, + ports: this.options.callbackPorts, + logger: this.logger, + }); + + const { code, verifier, redirectUri } = await authCode.fetch(scopes); + + const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({ + grant_type: 'authorization_code', + code, + code_verifier: verifier, + redirect_uri: redirectUri, + }); + + if (!accessToken) { + throw new Error('Failed to fetch access token'); + } + + return new OAuthToken(accessToken, refreshToken); + } +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts new file mode 100644 index 00000000..c60a9f2f --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts @@ -0,0 +1,7 @@ +import OAuthToken from './OAuthToken'; + +export default interface OAuthPersistence { + persist(host: string, token: OAuthToken): Promise; + + read(host: string): Promise; +} diff --git a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts new file mode 100644 index 00000000..e48b0d05 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts @@ -0,0 +1,38 @@ +export default class OAuthToken { + private readonly _accessToken: string; + + private readonly _refreshToken?: string; + + private _expirationTime?: number; + + constructor(accessToken: string, refreshToken?: string) { + this._accessToken = accessToken; + this._refreshToken = refreshToken; + } + + get accessToken(): string { + return this._accessToken; + } + + get refreshToken(): string | undefined { + return this._refreshToken; + } + + get expirationTime(): number { + // This token has already been verified, and we are just parsing it. + // If it has been tampered with, it will be rejected on the server side. + // This avoids having to fetch the public key from the issuer and perform + // an unnecessary signature verification. + if (this._expirationTime === undefined) { + const accessTokenPayload = Buffer.from(this._accessToken.split('.')[1], 'base64').toString('utf8'); + const decoded = JSON.parse(accessTokenPayload); + this._expirationTime = Number(decoded.exp); + } + return this._expirationTime; + } + + get hasExpired(): boolean { + const now = Math.floor(Date.now() / 1000); // convert it to seconds + return this.expirationTime <= now; + } +} diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts new file mode 100644 index 00000000..955bbfc6 --- /dev/null +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -0,0 +1,72 @@ +import { HttpHeaders } from 'thrift'; +import IAuthentication from '../../contracts/IAuthentication'; +import HttpTransport from '../../transports/HttpTransport'; +import IDBSQLLogger from '../../../contracts/IDBSQLLogger'; +import OAuthPersistence from './OAuthPersistence'; +import OAuthManager from './OAuthManager'; + +interface DatabricksOAuthOptions { + host: string; + redirectPorts?: Array; + clientId?: string; + scopes?: Array; + logger?: IDBSQLLogger; + persistence?: OAuthPersistence; + headers?: HttpHeaders; +} + +const defaultOAuthOptions = { + clientId: 'databricks-sql-connector', + redirectPorts: [8030], + scopes: ['sql', 'offline_access'], +} satisfies Partial; + +export default class DatabricksOAuth implements IAuthentication { + private readonly host: string; + + private readonly redirectPorts: Array; + + private readonly clientId: string; + + private readonly scopes: Array; + + private readonly logger?: IDBSQLLogger; + + private readonly persistence?: OAuthPersistence; + + private readonly headers?: HttpHeaders; + + private readonly manager: OAuthManager; + + constructor(options: DatabricksOAuthOptions) { + this.host = options.host; + this.redirectPorts = options.redirectPorts || defaultOAuthOptions.redirectPorts; + this.clientId = options.clientId || defaultOAuthOptions.clientId; + this.scopes = options.scopes || defaultOAuthOptions.scopes; + this.logger = options.logger; + this.persistence = options.persistence; + this.headers = options.headers; + + this.manager = new OAuthManager({ + host: this.host, + callbackPorts: this.redirectPorts, + clientId: this.clientId, + logger: this.logger, + }); + } + + public async authenticate(transport: HttpTransport): Promise { + let token = await this.persistence?.read(this.host); + if (!token) { + token = await this.manager.getToken(this.scopes); + } + + token = await this.manager.refreshAccessToken(token); + await this.persistence?.persist(this.host, token); + + transport.updateHeaders({ + ...this.headers, + Authorization: `Bearer ${token.accessToken}`, + }); + } +} diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index b486130c..26e7062a 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -1,18 +1,33 @@ import IDBSQLLogger from './IDBSQLLogger'; import IDBSQLSession from './IDBSQLSession'; +import IAuthentication from '../connection/contracts/IAuthentication'; +import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; export interface ClientOptions { logger?: IDBSQLLogger; } -export interface ConnectionOptions { +type AuthOptions = + | { + authType?: 'access-token'; + token: string; + } + | { + authType: 'databricks-oauth'; + persistence?: OAuthPersistence; + } + | { + authType: 'custom'; + provider: IAuthentication; + }; + +export type ConnectionOptions = { host: string; port?: number; path: string; - token: string; clientId?: string; socketTimeout?: number; -} +} & AuthOptions; export interface OpenSessionRequest { initialCatalog?: string; diff --git a/package-lock.json b/package-lock.json index 96efb05c..37c437f7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13,6 +13,8 @@ "apache-arrow": "^10.0.1", "commander": "^9.3.0", "node-int64": "^0.4.0", + "open": "^8.4.2", + "openid-client": "^5.4.2", "patch-package": "^7.0.0", "thrift": "^0.16.0", "uuid": "^9.0.0", @@ -1868,6 +1870,14 @@ "node": ">=8" } }, + "node_modules/define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "engines": { + "node": ">=8" + } + }, "node_modules/define-properties": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.4.tgz", @@ -3534,6 +3544,14 @@ "node": ">=8" } }, + "node_modules/jose": { + "version": "4.14.4", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.14.4.tgz", + "integrity": "sha512-j8GhLiKmUAh+dsFXlX1aJCbt5KMibuKb+d7j1JaOJG6s2UjX1PQlW+OKB/sD4a/5ZYF4RcmYmLSndOoU3Lt/3g==", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -3758,7 +3776,6 @@ "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, "dependencies": { "yallist": "^4.0.0" }, @@ -4160,6 +4177,14 @@ "node": ">=0.10.0" } }, + "node_modules/object-hash": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", + "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==", + "engines": { + "node": ">= 6" + } + }, "node_modules/object-inspect": { "version": "1.12.2", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.2.tgz", @@ -4257,6 +4282,14 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/oidc-token-hash": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", + "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==", + "engines": { + "node": "^10.13.0 || >=12.0.0" + } + }, "node_modules/once": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", @@ -4274,20 +4307,35 @@ } }, "node_modules/open": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", - "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", "dependencies": { - "is-docker": "^2.0.0", - "is-wsl": "^2.1.1" + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" }, "engines": { - "node": ">=8" + "node": ">=12" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/openid-client": { + "version": "5.4.2", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.4.2.tgz", + "integrity": "sha512-lIhsdPvJ2RneBm3nGBBhQchpe3Uka//xf7WPHTIglery8gnckvW7Bd9IaQzekzXJvWthCMyi/xVEyGW0RFPytw==", + "dependencies": { + "jose": "^4.14.1", + "lru-cache": "^6.0.0", + "object-hash": "^2.2.0", + "oidc-token-hash": "^5.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/optionator": { "version": "0.9.1", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", @@ -4430,6 +4478,21 @@ "npm": ">5" } }, + "node_modules/patch-package/node_modules/open": { + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", + "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "dependencies": { + "is-docker": "^2.0.0", + "is-wsl": "^2.1.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/patch-package/node_modules/rimraf": { "version": "2.7.1", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", @@ -5640,8 +5703,7 @@ "node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" }, "node_modules/yaml": { "version": "2.2.2", @@ -7104,6 +7166,11 @@ "strip-bom": "^4.0.0" } }, + "define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==" + }, "define-properties": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.4.tgz", @@ -8323,6 +8390,11 @@ "istanbul-lib-report": "^3.0.0" } }, + "jose": { + "version": "4.14.4", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.14.4.tgz", + "integrity": "sha512-j8GhLiKmUAh+dsFXlX1aJCbt5KMibuKb+d7j1JaOJG6s2UjX1PQlW+OKB/sD4a/5ZYF4RcmYmLSndOoU3Lt/3g==" + }, "js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -8506,7 +8578,6 @@ "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, "requires": { "yallist": "^4.0.0" } @@ -8825,6 +8896,11 @@ "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", "dev": true }, + "object-hash": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", + "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==" + }, "object-inspect": { "version": "1.12.2", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.2.tgz", @@ -8892,6 +8968,11 @@ "es-abstract": "^1.19.1" } }, + "oidc-token-hash": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", + "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==" + }, "once": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", @@ -8909,12 +8990,24 @@ } }, "open": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", - "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", "requires": { - "is-docker": "^2.0.0", - "is-wsl": "^2.1.1" + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" + } + }, + "openid-client": { + "version": "5.4.2", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.4.2.tgz", + "integrity": "sha512-lIhsdPvJ2RneBm3nGBBhQchpe3Uka//xf7WPHTIglery8gnckvW7Bd9IaQzekzXJvWthCMyi/xVEyGW0RFPytw==", + "requires": { + "jose": "^4.14.1", + "lru-cache": "^6.0.0", + "object-hash": "^2.2.0", + "oidc-token-hash": "^5.0.3" } }, "optionator": { @@ -9019,6 +9112,15 @@ "yaml": "^2.2.2" }, "dependencies": { + "open": { + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-7.4.2.tgz", + "integrity": "sha512-MVHddDVweXZF3awtlAS+6pgKLlm/JgxZ90+/NBurBoQctVOOB/zDdVjcyPzQ+0laDGbsWgrRkflI65sQeOgT9Q==", + "requires": { + "is-docker": "^2.0.0", + "is-wsl": "^2.1.1" + } + }, "rimraf": { "version": "2.7.1", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", @@ -9899,8 +10001,7 @@ "yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" }, "yaml": { "version": "2.2.2", diff --git a/package.json b/package.json index 862c872f..92ac273f 100644 --- a/package.json +++ b/package.json @@ -74,6 +74,8 @@ "apache-arrow": "^10.0.1", "commander": "^9.3.0", "node-int64": "^0.4.0", + "open": "^8.4.2", + "openid-client": "^5.4.2", "patch-package": "^7.0.0", "thrift": "^0.16.0", "uuid": "^9.0.0", diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 6722405e..5ffedd47 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -2,10 +2,10 @@ const { expect } = require('chai'); const sinon = require('sinon'); const DBSQLClient = require('../../dist/DBSQLClient').default; const DBSQLSession = require('../../dist/DBSQLSession').default; -const { - auth: { PlainHttpAuthentication }, - connections: { HttpConnection }, -} = require('../../'); + +const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAuthentication').default; +const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default; +const HttpConnection = require('../../dist/connection/connections/HttpConnection').default; const ConnectionProviderMock = (connection) => ({ connect(options, auth) { @@ -227,3 +227,68 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); }); + +describe('DBSQLClient.getAuthProvider', () => { + it('should use access token auth method', () => { + const client = new DBSQLClient(); + + const testAccessToken = 'token'; + const provider = client.getAuthProvider({ + authType: 'access-token', + token: testAccessToken, + }); + + expect(provider).to.be.instanceOf(PlainHttpAuthentication); + expect(provider.password).to.be.equal(testAccessToken); + }); + + it('should use access token auth method by default (compatibility)', () => { + const client = new DBSQLClient(); + + const testAccessToken = 'token'; + const provider = client.getAuthProvider({ + // note: no `authType` provided + token: testAccessToken, + }); + + expect(provider).to.be.instanceOf(PlainHttpAuthentication); + expect(provider.password).to.be.equal(testAccessToken); + }); + + it('should use Databricks OAuth method', () => { + const client = new DBSQLClient(); + + const provider = client.getAuthProvider({ + authType: 'databricks-oauth', + }); + + expect(provider).to.be.instanceOf(DatabricksOAuth); + }); + + it('should use custom auth method', () => { + const client = new DBSQLClient(); + + const customProvider = {}; + + const provider = client.getAuthProvider({ + authType: 'custom', + provider: customProvider, + }); + + expect(provider).to.be.equal(customProvider); + }); + + it('should use custom auth method (legacy way)', () => { + const client = new DBSQLClient(); + + const customProvider = {}; + + const provider = client.getAuthProvider( + // custom provider from second arg should be used no matter what's specified in config + { authType: 'access-token', token: 'token' }, + customProvider, + ); + + expect(provider).to.be.equal(customProvider); + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js new file mode 100644 index 00000000..bc83e559 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/AuthorizationCode.test.js @@ -0,0 +1,276 @@ +const { expect, AssertionError } = require('chai'); +const { EventEmitter } = require('events'); +const sinon = require('sinon'); +const http = require('http'); +const AuthorizationCode = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode').default; + +class HttpServerMock extends EventEmitter { + constructor() { + super(); + this.requestHandler = () => {}; + this.listening = false; + this.listenError = undefined; // error to emit on listen + this.closeError = undefined; // error to emit on close + } + + listen(port, host, callback) { + if (this.listenError) { + this.emit('error', this.listenError); + this.listenError = undefined; + } else if (port < 1000) { + const error = new Error(`Address ${host}:${port} is already in use`); + error.code = 'EADDRINUSE'; + this.emit('error', error); + } else { + this.listening = true; + callback(); + } + } + + close(callback) { + this.requestHandler = () => {}; + this.listening = false; + if (this.closeError) { + this.emit('error', this.closeError); + this.closeError = undefined; + } else { + callback(); + } + } +} + +class OAuthClientMock { + constructor() { + this.code = 'test_authorization_code'; + this.redirectUri = undefined; + } + + authorizationUrl(params) { + this.redirectUri = params.redirect_uri; + return JSON.stringify({ + state: params.state, + code: this.code, + }); + } + + callbackParams(req) { + return req.params; + } +} + +function prepareTestInstances(options) { + const httpServer = new HttpServerMock(); + + const oauthClient = new OAuthClientMock(); + + const authCode = new AuthorizationCode({ + client: oauthClient, + ...options, + }); + + sinon.stub(http, 'createServer').callsFake((requestHandler) => { + httpServer.requestHandler = requestHandler; + return httpServer; + }); + + sinon.stub(authCode, 'openUrl').callsFake((url) => { + const params = JSON.parse(url); + httpServer.requestHandler( + { params }, + { + writeHead: () => {}, + end: () => {}, + }, + ); + }); + + function reloadUrl() { + setTimeout(() => { + const args = authCode.openUrl.firstCall.args; + authCode.openUrl(...args); + }, 10); + } + + return { httpServer, oauthClient, authCode, reloadUrl }; +} + +describe('AuthorizationCode', () => { + afterEach(() => { + http.createServer.restore?.(); + }); + + it('should fetch authorization code', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [80, 8000], + logger: { log: () => {} }, + }); + + const result = await authCode.fetch([]); + expect(http.createServer.callCount).to.be.equal(2); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(result.code).to.be.equal(oauthClient.code); + expect(result.verifier).to.not.be.empty; + expect(result.redirectUri).to.be.equal(oauthClient.redirectUri); + }); + + it('should throw error if cannot start server on any port', async () => { + const { authCode } = prepareTestInstances({ + ports: [80, 443], + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(2); + expect(authCode.openUrl.callCount).to.be.equal(0); + + expect(error.message).to.contain('all ports are in use'); + } + }); + + it('should re-throw unhandled server start errors', async () => { + const { authCode, httpServer } = prepareTestInstances({ + ports: [80], + }); + + const testError = new Error('Test'); + httpServer.listenError = testError; + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(0); + + expect(error).to.be.equal(testError); + } + }); + + it('should re-throw unhandled server stop errors', async () => { + const { authCode, httpServer } = prepareTestInstances({ + ports: [8000], + }); + + const testError = new Error('Test'); + httpServer.closeError = testError; + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error).to.be.equal(testError); + } + }); + + it('should throw an error if no code was returned', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [8000], + }); + + sinon.stub(oauthClient, 'callbackParams').callsFake((req) => { + // Omit authorization code from params + const { code, ...otherParams } = req.params; + return otherParams; + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error.message).to.contain('No path parameters were returned to the callback'); + } + }); + + it('should use error details from callback params', async () => { + const { authCode, oauthClient } = prepareTestInstances({ + ports: [8000], + }); + + sinon.stub(oauthClient, 'callbackParams').callsFake((req) => { + // Omit authorization code from params + const { code, ...otherParams } = req.params; + return { + ...otherParams, + error: 'test_error', + error_description: 'Test error', + }; + }); + + try { + await authCode.fetch([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + + expect(error.message).to.contain('Test error'); + } + }); + + it('should serve 404 for unrecognized requests', async () => { + const { authCode, oauthClient, reloadUrl } = prepareTestInstances({ + ports: [8000], + }); + + sinon + .stub(oauthClient, 'callbackParams') + .onFirstCall() + .callsFake(() => { + // Repeat the same request after currently processed one. + // We won't modify response on subsequent requests so OAuth routine can complete + reloadUrl(); + // Return no params so request cannot be recognized + return {}; + }) + .callThrough(); + + await authCode.fetch([]); + + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(2); + }); + + it('should not attempt to stop server if not running', async () => { + const { authCode, oauthClient, httpServer } = prepareTestInstances({ + ports: [8000], + logger: { log: () => {} }, + }); + + const promise = authCode.fetch([]); + + httpServer.listening = false; + httpServer.closeError = new Error('Test'); + + const result = await promise; + // We set up server to throw an error on close. If nothing happened - it means + // that `authCode` never tried to stop it + expect(result.code).to.be.equal(oauthClient.code); + + expect(http.createServer.callCount).to.be.equal(1); + expect(authCode.openUrl.callCount).to.be.equal(1); + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js new file mode 100644 index 00000000..d7209b83 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js @@ -0,0 +1,246 @@ +const { expect, AssertionError } = require('chai'); +const sinon = require('sinon'); +const { Issuer } = require('openid-client'); +const OAuthManager = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager').default; +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; +const AuthorizationCodeModule = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode'); + +const { createValidAccessToken, createExpiredAccessToken } = require('./utils'); + +class AuthorizationCodeMock { + constructor() { + this.fetchResult = undefined; + } + + async fetch() { + return this.fetchResult; + } +} + +AuthorizationCodeMock.validCode = { + code: 'auth_code', + verifier: 'verifier_string', + redirectUri: 'http://localhost:8000', +}; + +class OAuthClientMock { + constructor() { + this.grantError = undefined; + this.refreshError = undefined; + + this.accessToken = undefined; + this.refreshToken = undefined; + this.recreateTokens(); + } + + recreateTokens() { + const suffix = Math.random().toString(36).substring(2); + this.accessToken = `${createValidAccessToken()}.${suffix}`; + this.refreshToken = `refresh.${suffix}`; + } + + async grant(params) { + if (this.grantError) { + const error = this.grantError; + this.grantError = undefined; + throw error; + } + + expect(params.grant_type).to.be.equal('authorization_code'); + expect(params.code).to.be.equal(AuthorizationCodeMock.validCode.code); + expect(params.code_verifier).to.be.equal(AuthorizationCodeMock.validCode.verifier); + expect(params.redirect_uri).to.be.equal(AuthorizationCodeMock.validCode.redirectUri); + + return { + access_token: this.accessToken, + refresh_token: this.refreshToken, + }; + } + + async refresh(refreshToken) { + if (this.refreshError) { + const error = this.refreshError; + this.refreshError = undefined; + throw error; + } + + expect(refreshToken).to.be.equal(this.refreshToken); + + this.recreateTokens(); + return { + access_token: this.accessToken, + refresh_token: this.refreshToken, + }; + } +} + +function prepareTestInstances(options) { + const oauthClient = new OAuthClientMock(); + sinon.stub(oauthClient, 'grant').callThrough(); + sinon.stub(oauthClient, 'refresh').callThrough(); + + sinon.stub(Issuer, 'discover').returns( + Promise.resolve({ + Client: function () { + return oauthClient; + }, + }), + ); + + const oauthManager = new OAuthManager({ + host: 'https://example.com', + ...options, + }); + + const authCode = new AuthorizationCodeMock(); + authCode.fetchResult = { ...AuthorizationCodeMock.validCode }; + + sinon.stub(AuthorizationCodeModule, 'default').returns(authCode); + + return { oauthClient, oauthManager, authCode }; +} + +describe('OAuthManager', () => { + afterEach(() => { + AuthorizationCodeModule.default.restore?.(); + Issuer.discover.restore?.(); + }); + + it('should get access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + const token = await oauthManager.getToken([]); + expect(oauthClient.grant.called).to.be.true; + expect(token).to.be.instanceOf(OAuthToken); + expect(token.accessToken).to.be.equal(oauthClient.accessToken); + expect(token.refreshToken).to.be.equal(oauthClient.refreshToken); + }); + + it('should throw an error if cannot get access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + // Make it return empty tokens + oauthClient.accessToken = undefined; + oauthClient.refreshToken = undefined; + + try { + await oauthManager.getToken([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.grant.called).to.be.true; + expect(error.message).to.contain('Failed to fetch access token'); + } + }); + + it('should re-throw unhandled errors when getting access token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + const testError = new Error('Test'); + oauthClient.grantError = testError; + + try { + await oauthManager.getToken([]); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.grant.called).to.be.true; + expect(error).to.be.equal(testError); + } + }); + + it('should not refresh valid token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + const token = new OAuthToken(createValidAccessToken(), oauthClient.refreshToken); + expect(token.hasExpired).to.be.false; + + const newToken = await oauthManager.refreshAccessToken(token); + expect(oauthClient.refresh.called).to.be.false; + expect(newToken).to.be.instanceOf(OAuthToken); + expect(newToken.accessToken).to.be.equal(token.accessToken); + expect(newToken.hasExpired).to.be.false; + }); + + it('should throw an error if no refresh token is available', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + try { + const token = new OAuthToken(createExpiredAccessToken()); + expect(token.hasExpired).to.be.true; + + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.false; + expect(error.message).to.contain('token expired'); + } + }); + + it('should throw an error on invalid response', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + oauthClient.refresh.restore(); + sinon.stub(oauthClient, 'refresh').returns({}); + + try { + const token = new OAuthToken(createExpiredAccessToken(), oauthClient.refreshToken); + expect(token.hasExpired).to.be.true; + + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.true; + expect(error.message).to.contain('invalid response'); + } + }); + + it('should throw an error for invalid token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances({ + logger: { log: () => {} }, + }); + + try { + const token = new OAuthToken('invalid_access_token', 'invalid_refresh_token'); + await oauthManager.refreshAccessToken(token); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(oauthClient.refresh.called).to.be.false; + // Random malformed string passed as access token will cause JSON parse errors + expect(error).to.be.instanceof(TypeError); + } + }); + + it('should refresh expired token', async () => { + const { oauthManager, oauthClient } = prepareTestInstances(); + + oauthClient.accessToken = createExpiredAccessToken(); + const token = await oauthManager.getToken([]); + expect(token.hasExpired).to.be.true; + + const newToken = await oauthManager.refreshAccessToken(token); + expect(oauthClient.refresh.called).to.be.true; + expect(newToken).to.be.instanceOf(OAuthToken); + expect(newToken.accessToken).to.be.not.equal(token.accessToken); + expect(newToken.hasExpired).to.be.false; + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js new file mode 100644 index 00000000..1ab4f486 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js @@ -0,0 +1,64 @@ +const { expect } = require('chai'); +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; + +const { createAccessToken } = require('./utils'); + +describe('OAuthToken', () => { + it('should be properly initialized', () => { + const accessToken = 'access'; + const refreshToken = 'refresh'; + + const token1 = new OAuthToken(accessToken); + expect(token1.accessToken).to.be.equal(accessToken); + + const token2 = new OAuthToken(accessToken, refreshToken); + expect(token2.accessToken).to.be.equal(accessToken); + expect(token2.refreshToken).to.be.equal(refreshToken); + }); + + it('should return valid expiration time', () => { + const expirationTime = Math.trunc(Date.now() / 1000); + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + // second attempt - to make sure it returns the same value + expect(token.expirationTime).to.be.equal(expirationTime); + }); + + it('should throw error if cannot get expiration time', () => { + expect(() => { + const token = new OAuthToken('without_payload'); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + + expect(() => { + const token = new OAuthToken('invalid.payload'); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + + expect(() => { + const payload = Buffer.from('qwerty', 'utf8').toString('base64'); + const token = new OAuthToken(`malformed.${payload}`); + expect(token.expirationTime).to.be.equal(undefined); + }).to.throw(); + }); + + it('should test for expired token', () => { + const expirationTime = Math.trunc(Date.now() / 1000) - 1; + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + expect(token.hasExpired).to.be.true; + }); + + it('should test for valid token', () => { + const expirationTime = Math.trunc(Date.now() / 1000) + 1; + const accessToken = createAccessToken(expirationTime); + + const token = new OAuthToken(accessToken); + expect(token.expirationTime).to.be.equal(expirationTime); + expect(token.hasExpired).to.be.false; + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/index.test.js b/tests/unit/connection/auth/DatabricksOAuth/index.test.js new file mode 100644 index 00000000..bdeccc9b --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/index.test.js @@ -0,0 +1,135 @@ +const { expect, AssertionError } = require('chai'); +const sinon = require('sinon'); +const DatabricksOAuth = require('../../../../../dist/connection/auth/DatabricksOAuth/index').default; +const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; +const OAuthManagerModule = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager'); + +const { createValidAccessToken, createExpiredAccessToken } = require('./utils'); + +class OAuthManagerMock { + constructor() { + this.getTokenResult = new OAuthToken(createValidAccessToken()); + this.refreshTokenResult = new OAuthToken(createValidAccessToken()); + } + + async refreshAccessToken(token) { + return token.hasExpired ? this.refreshTokenResult : token; + } + + async getToken() { + return this.getTokenResult; + } +} + +class TransportMock { + constructor() { + this.headers = {}; + } + + updateHeaders(newHeaders) { + this.headers = { + ...this.headers, + ...newHeaders, + }; + } +} + +class OAuthPersistenceMock { + constructor() { + this.token = undefined; + + sinon.stub(this, 'persist').callThrough(); + sinon.stub(this, 'read').callThrough(); + } + + async persist(host, token) { + this.token = token; + } + + async read() { + return this.token; + } +} + +function prepareTestInstances(options) { + const oauthManager = new OAuthManagerMock(); + + sinon.stub(oauthManager, 'refreshAccessToken').callThrough(); + sinon.stub(oauthManager, 'getToken').callThrough(); + + sinon.stub(OAuthManagerModule, 'default').returns(oauthManager); + + const provider = new DatabricksOAuth({ ...options }); + + const transport = new TransportMock(); + sinon.stub(transport, 'updateHeaders').callThrough(); + + return { oauthManager, provider, transport }; +} + +describe('DatabricksOAuth', () => { + afterEach(() => { + OAuthManagerModule.default.restore?.(); + }); + + it('should get persisted token if available', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = new OAuthToken(createValidAccessToken()); + + const { provider, transport } = prepareTestInstances({ persistence }); + + await provider.authenticate(transport); + expect(persistence.read.called).to.be.true; + }); + + it('should get new token if storage not available', async () => { + const { oauthManager, provider, transport } = prepareTestInstances(); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + }); + + it('should get new token if persisted token not available, and store valid token', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = undefined; + const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(persistence.persist.called).to.be.true; + expect(persistence.token).to.be.equal(oauthManager.getTokenResult); + }); + + it('should refresh expired token and store new token', async () => { + const persistence = new OAuthPersistenceMock(); + persistence.token = undefined; + + const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + oauthManager.getTokenResult = new OAuthToken(createExpiredAccessToken()); + oauthManager.refreshTokenResult = new OAuthToken(createValidAccessToken()); + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(oauthManager.refreshAccessToken.called).to.be.true; + expect(oauthManager.refreshAccessToken.firstCall.firstArg).to.be.equal(oauthManager.getTokenResult); + expect(persistence.token).to.be.equal(oauthManager.refreshTokenResult); + expect(persistence.persist.called).to.be.true; + expect(persistence.token).to.be.equal(oauthManager.refreshTokenResult); + }); + + it('should configure transport using valid token', async () => { + const { oauthManager, provider, transport } = prepareTestInstances(); + + const initialHeaders = { + x: 'x', + y: 'y', + }; + + transport.headers = initialHeaders; + + await provider.authenticate(transport); + expect(oauthManager.getToken.called).to.be.true; + expect(transport.updateHeaders.called).to.be.true; + expect(Object.keys(transport.headers)).to.deep.equal([...Object.keys(initialHeaders), 'Authorization']); + }); +}); diff --git a/tests/unit/connection/auth/DatabricksOAuth/utils.js b/tests/unit/connection/auth/DatabricksOAuth/utils.js new file mode 100644 index 00000000..edffe7b1 --- /dev/null +++ b/tests/unit/connection/auth/DatabricksOAuth/utils.js @@ -0,0 +1,20 @@ +function createAccessToken(expirationTime) { + const payload = Buffer.from(JSON.stringify({ exp: expirationTime }), 'utf8').toString('base64'); + return `access.${payload}`; +} + +function createValidAccessToken() { + const expirationTime = Math.trunc(Date.now() / 1000) + 20000; + return createAccessToken(expirationTime); +} + +function createExpiredAccessToken() { + const expirationTime = Math.trunc(Date.now() / 1000) - 1000; + return createAccessToken(expirationTime); +} + +module.exports = { + createAccessToken, + createValidAccessToken, + createExpiredAccessToken, +};