Skip to content

[PECO-728] Add OAuth support #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError } from './utils';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
import DBSQLLogger from './DBSQLLogger';

Expand Down Expand Up @@ -61,7 +62,21 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
}

private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
const { host, port, path, token, clientId, ...otherOptions } = options;
const {
host,
port,
path,
clientId,
authType,
// @ts-expect-error TS2339: Property 'token' does not exist on type 'ConnectionOptions'
token,
// @ts-expect-error TS2339: Property 'persistence' does not exist on type 'ConnectionOptions'
persistence,
// @ts-expect-error TS2339: Property 'provider' does not exist on type 'ConnectionOptions'
provider,
...otherOptions
} = options;

return {
host,
port: port || 443,
Expand All @@ -76,22 +91,41 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
};
}

private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}

switch (options.authType) {
case undefined:
case 'access-token':
return new PlainHttpAuthentication({
username: 'token',
password: options.token,
});
case 'databricks-oauth':
return new DatabricksOAuth({
host: options.host,
logger: this.logger,
persistence: options.persistence,
});
case 'custom':
return options.provider;
// no default
}
}

/**
* Connects DBSQLClient to endpoint
* @public
* @param options - host, path, and token are required
* @param authProvider - Optional custom authentication provider
* @param authProvider - [DEPRECATED - use `authType: 'custom'] Optional custom authentication provider
* @returns Session object that can be used to execute statements
* @example
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
authProvider =
authProvider ||
new PlainHttpAuthentication({
username: 'token',
password: options.token,
});
authProvider = this.getAuthProvider(options, authProvider);

this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider);

Expand Down
180 changes: 180 additions & 0 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import http, { IncomingMessage, Server, ServerResponse } from 'http';
import { BaseClient, CallbackParamsType, generators } from 'openid-client';
import open from 'open';
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';

export interface AuthorizationCodeOptions {
client: BaseClient;
ports: Array<number>;
logger?: IDBSQLLogger;
}

const scopeDelimiter = ' ';

async function startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = http.createServer(requestHandler);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this ever create issues of trying to send an https request from an http server?

Copy link
Contributor Author

@kravets-levko kravets-levko Jun 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OAuth app we use is configured to allow http only. But it's not an issue, because we receive only authorization token via callback url, and then use that auth token + verifier string in another request to OAuth endpoint to obtain access and refresh tokens. All OAuth endpoints use https. So even is anyone will intercept auth code - it's basically useles without verifier which is not exposed anywhere


return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

async function stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This syntax is really strange to me, where did you get this? The errorListener invokes server off with itself?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here (and similarly in startServer) I just "promisify" server methods. Node's http server has event-based API, but since all our code is Promise-based, I wrapped server creation and stopping routines. How it works: first I create an error handler function and attach it to server's error event. Then I invoke a method I need, and if it was successful - I remove that error listener (therefore I store it in variable), and resolve promise. If called method emits an error - my error handler catches it, removes itself and rejects promise. Error handler here is needed only once to handle a single error, therefore I keep a function to be able to unregister it once it's no longer needed

server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
}

export interface AuthorizationCodeFetchResult {
code: string;
verifier: string;
redirectUri: string;
}

export default class AuthorizationCode {
private readonly client: BaseClient;

private readonly host: string = 'localhost';

private readonly ports: Array<number>;

private readonly logger?: IDBSQLLogger;

constructor(options: AuthorizationCodeOptions) {
this.client = options.client;
this.ports = options.ports;
this.logger = options.logger;
}

private async openUrl(url: string) {
return open(url);
}

public async fetch(scopes: Array<string>): Promise<AuthorizationCodeFetchResult> {
const verifierString = generators.codeVerifier(32);
const challengeString = generators.codeChallenge(verifierString);
const state = generators.state(16);

let receivedParams: CallbackParamsType | undefined;

const server = await this.startServer((req, res) => {
const params = this.client.callbackParams(req);
if (params.state === state) {
receivedParams = params;
res.writeHead(200);
res.end(this.renderCallbackResponse());
server.stop();
} else {
res.writeHead(404);
res.end();
}
});

const redirectUri = `http://${server.host}:${server.port}/`;
const authUrl = this.client.authorizationUrl({
response_type: 'code',
response_mode: 'query',
scope: scopes.join(scopeDelimiter),
code_challenge: challengeString,
code_challenge_method: 'S256',
state,
redirect_uri: redirectUri,
});

await this.openUrl(authUrl);
await server.stopped();

if (!receivedParams || !receivedParams.code) {
if (receivedParams?.error) {
const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`;
throw new Error(errorMessage);
}
throw new Error(`No path parameters were returned to the callback at ${redirectUri}`);
}

return { code: receivedParams.code, verifier: verifierString, redirectUri };
}

private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.ports) {
const host = this.host; // eslint-disable-line prefer-destructuring
try {
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
this.logger?.log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);

let resolveStopped: () => void;
let rejectStopped: (reason?: any) => void;
const stoppedPromise = new Promise<void>((resolve, reject) => {
resolveStopped = resolve;
rejectStopped = reject;
});

return {
host,
port,
server,
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
stopped: () => stoppedPromise,
};
} catch (error) {
// if port already in use - try another one, otherwise re-throw an exception
if (error instanceof Error && 'code' in error && error.code === 'EADDRINUSE') {
this.logger?.log(LogLevel.debug, `Failed to start server at ${host}:${port}: ${error.code}`);
} else {
throw error;
}
}
}

throw new Error('Failed to start server: all ports are in use');
}

private renderCallbackResponse(): string {
const applicationName = 'Databricks Sql Connector';

return `<html lang="en">
<head>
<title>Close this Tab</title>
<style>
body {
font-family: "Barlow", Helvetica, Arial, sans-serif;
padding: 20px;
background-color: #f3f3f3;
}
</style>
</head>
<body>
<h1>Please close this tab.</h1>
<p>
The ${applicationName} received a response. You may close this tab.
</p>
</body>
</html>`;
}
}
102 changes: 102 additions & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import { Issuer, BaseClient } from 'openid-client';
import HiveDriverError from '../../../errors/HiveDriverError';
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
import AuthorizationCode from './AuthorizationCode';

const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server';

export interface OAuthManagerOptions {
host: string;
callbackPorts: Array<number>;
clientId: string;
logger?: IDBSQLLogger;
}

export default class OAuthManager {
private readonly options: OAuthManagerOptions;

private readonly logger?: IDBSQLLogger;

private issuer?: Issuer;

private client?: BaseClient;

constructor(options: OAuthManagerOptions) {
this.options = options;
this.logger = options.logger;
}

private async getClient(): Promise<BaseClient> {
if (!this.issuer) {
const { host } = this.options;
const schema = host.startsWith('https://') ? '' : 'https://';
const trailingSlash = host.endsWith('/') ? '' : '/';
this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`);
}

if (!this.client) {
this.client = new this.issuer.Client({
client_id: this.options.clientId,
token_endpoint_auth_method: 'none',
});
}

return this.client;
}

public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
try {
if (!token.hasExpired) {
// The access token is fine. Just return it.
return token;
}
} catch (error) {
this.logger?.log(LogLevel.error, `${error}`);
throw error;
}

if (!token.refreshToken) {
const message = `OAuth access token expired on ${token.expirationTime}.`;
this.logger?.log(LogLevel.error, message);
throw new HiveDriverError(message);
}

// Try to refresh using the refresh token
this.logger?.log(
LogLevel.debug,
`Attempting to refresh OAuth access token that expired on ${token.expirationTime}`,
);

const client = await this.getClient();
const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken);
if (!accessToken || !refreshToken) {
throw new Error('Failed to refresh token: invalid response');
}
return new OAuthToken(accessToken, refreshToken);
}

public async getToken(scopes: Array<string>): Promise<OAuthToken> {
const client = await this.getClient();
const authCode = new AuthorizationCode({
client,
ports: this.options.callbackPorts,
logger: this.logger,
});

const { code, verifier, redirectUri } = await authCode.fetch(scopes);

const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
grant_type: 'authorization_code',
code,
code_verifier: verifier,
redirect_uri: redirectUri,
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
}

return new OAuthToken(accessToken, refreshToken);
}
}
7 changes: 7 additions & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import OAuthToken from './OAuthToken';

export default interface OAuthPersistence {
persist(host: string, token: OAuthToken): Promise<void>;

read(host: string): Promise<OAuthToken | undefined>;
}
Loading