diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 497c2eb7a..7875229bb 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -22,14 +22,13 @@ permissions: contents: read # This is required for actions/checkout jobs: - run-integration-tests: - name: Run Integration Tests + run-integration-tests-default: + name: Run Integration Tests (Default) runs-on: ubuntu-latest strategy: fail-fast: false matrix: - versions: [ "default", "latest" ] - dbEngine: ["aurora-mysql", "aurora-postgres" ] + dbEngine: ["aurora-mysql", "aurora-postgres"] steps: - name: Clone repository @@ -64,8 +63,75 @@ jobs: AWS_ACCESS_KEY_ID: ${{ steps.creds.outputs.aws-access-key-id }} AWS_SECRET_ACCESS_KEY: ${{ steps.creds.outputs.aws-secret-access-key }} AWS_SESSION_TOKEN: ${{ steps.creds.outputs.aws-session-token }} - AURORA_MYSQL_DB_ENGINE_VERSION: ${{ matrix.dbEngine }} - AURORA_PG_DB_ENGINE_VERSION: ${{ matrix.versions }} + AURORA_MYSQL_DB_ENGINE_VERSION: default + AURORA_PG_DB_ENGINE_VERSION: default + + - name: "Get Github Action IP" + if: always() + id: ip + uses: haythem/public-ip@v1.3 + + - name: "Remove Github Action IP" + if: always() + run: | + aws ec2 revoke-security-group-ingress \ + --group-name default \ + --protocol -1 \ + --port -1 \ + --cidr ${{ steps.ip.outputs.ipv4 }}/32 \ + 2>&1 > /dev/null; + + - name: Archive results + if: always() + uses: actions/upload-artifact@v4 + with: + name: integration-report-default-${{ matrix.dbEngine }} + path: ./tests/integration/container/reports + retention-days: 5 + + run-integration-tests-latest: + name: Run Integration Tests (Latest) + runs-on: ubuntu-latest + needs: run-integration-tests-default + strategy: + fail-fast: false + matrix: + dbEngine: ["aurora-mysql", "aurora-postgres" ] + + steps: + - name: Clone repository + uses: actions/checkout@v4 + - name: "Set up JDK 8" + uses: actions/setup-java@v3 + with: + distribution: "corretto" + java-version: 8 + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20.x" + - name: Install dependencies + run: npm install --no-save + + - name: Configure AWS Credentials + id: creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.AWS_DEPLOY_ROLE }} + role-session-name: nodejs_int_latest_tests + aws-region: ${{ secrets.AWS_DEFAULT_REGION }} + output-credentials: true + + - name: Run Integration Tests + run: | + ./gradlew --no-parallel --no-daemon test-${{ matrix.dbEngine }} --info + env: + RDS_DB_REGION: ${{ secrets.AWS_DEFAULT_REGION }} + AWS_ACCESS_KEY_ID: ${{ steps.creds.outputs.aws-access-key-id }} + AWS_SECRET_ACCESS_KEY: ${{ steps.creds.outputs.aws-secret-access-key }} + AWS_SESSION_TOKEN: ${{ steps.creds.outputs.aws-session-token }} + AURORA_MYSQL_DB_ENGINE_VERSION: latest + AURORA_PG_DB_ENGINE_VERSION: latest - name: "Get Github Action IP" if: always() @@ -86,6 +152,6 @@ jobs: if: always() uses: actions/upload-artifact@v4 with: - name: integration-report-default-${{ matrix.dbEngine }}-${{ matrix.versions}} + name: integration-report-latest-${{ matrix.dbEngine }} path: ./tests/integration/container/reports retention-days: 5 diff --git a/.prettierrc b/.prettierrc index 60a1bb379..938cb431b 100644 --- a/.prettierrc +++ b/.prettierrc @@ -2,5 +2,5 @@ "semi": true, "trailingComma": "none", "printWidth": 150, - "endOfLine": "lf" + "endOfLine": "auto" } diff --git a/common/lib/authentication/aws_secrets_manager_plugin.ts b/common/lib/authentication/aws_secrets_manager_plugin.ts index c7f2e6213..631b9f068 100644 --- a/common/lib/authentication/aws_secrets_manager_plugin.ts +++ b/common/lib/authentication/aws_secrets_manager_plugin.ts @@ -129,20 +129,18 @@ export class AwsSecretsManagerPlugin extends AbstractConnectionPlugin implements this.pluginService.updateConfigWithProperties(props); return await connectFunc(); } catch (error) { - if (error instanceof Error) { - if ((error.message.includes("password authentication failed") || error.message.includes("Access denied")) && !secretWasFetched) { - // Login unsuccessful with cached credentials - // Try to re-fetch credentials and try again - - secretWasFetched = await this.updateSecret(true); - if (secretWasFetched) { - WrapperProperties.USER.set(props, this.secret?.username ?? ""); - WrapperProperties.PASSWORD.set(props, this.secret?.password ?? ""); - return await connectFunc(); - } + if ((error.message.includes("password authentication failed") || error.message.includes("Access denied")) && !secretWasFetched) { + // Login unsuccessful with cached credentials + // Try to re-fetch credentials and try again + + secretWasFetched = await this.updateSecret(true); + if (secretWasFetched) { + WrapperProperties.USER.set(props, this.secret?.username ?? ""); + WrapperProperties.PASSWORD.set(props, this.secret?.password ?? ""); + return await connectFunc(); } - logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledError", error.name, error.message)); } + logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledError", error.name, error.message)); throw error; } } diff --git a/common/lib/authentication/aws_secrets_manager_plugin_factory.ts b/common/lib/authentication/aws_secrets_manager_plugin_factory.ts index eb0ad30d3..6b3242cd1 100644 --- a/common/lib/authentication/aws_secrets_manager_plugin_factory.ts +++ b/common/lib/authentication/aws_secrets_manager_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../plugin_factory"; -import { PluginService } from "../plugin_service"; import { ConnectionPlugin } from "../connection_plugin"; import { AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; +import { FullServicesContainer } from "../utils/full_services_container"; export class AwsSecretsManagerPluginFactory extends ConnectionPluginFactory { private static awsSecretsManagerPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!AwsSecretsManagerPluginFactory.awsSecretsManagerPlugin) { AwsSecretsManagerPluginFactory.awsSecretsManagerPlugin = await import("./aws_secrets_manager_plugin"); } - return new AwsSecretsManagerPluginFactory.awsSecretsManagerPlugin.AwsSecretsManagerPlugin(pluginService, new Map(properties)); + return new AwsSecretsManagerPluginFactory.awsSecretsManagerPlugin.AwsSecretsManagerPlugin(servicesContainer.pluginService, new Map(properties)); } catch (error: any) { if (error.code === "MODULE_NOT_FOUND") { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "AwsSecretsManagerPlugin")); diff --git a/common/lib/authentication/iam_authentication_plugin_factory.ts b/common/lib/authentication/iam_authentication_plugin_factory.ts index da5e0a32b..6ececd67b 100644 --- a/common/lib/authentication/iam_authentication_plugin_factory.ts +++ b/common/lib/authentication/iam_authentication_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../plugin_factory"; -import { PluginService } from "../plugin_service"; import { ConnectionPlugin } from "../connection_plugin"; import { AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; +import { FullServicesContainer } from "../utils/full_services_container"; export class IamAuthenticationPluginFactory extends ConnectionPluginFactory { private static iamAuthenticationPlugin: any; - async getInstance(pluginService: PluginService, properties: object): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: object): Promise { try { if (!IamAuthenticationPluginFactory.iamAuthenticationPlugin) { IamAuthenticationPluginFactory.iamAuthenticationPlugin = await import("./iam_authentication_plugin"); } - return new IamAuthenticationPluginFactory.iamAuthenticationPlugin.IamAuthenticationPlugin(pluginService); + return new IamAuthenticationPluginFactory.iamAuthenticationPlugin.IamAuthenticationPlugin(servicesContainer.pluginService); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "IamAuthenticationPlugin")); } diff --git a/common/lib/aws_client.ts b/common/lib/aws_client.ts index 0140c847b..9f772be13 100644 --- a/common/lib/aws_client.ts +++ b/common/lib/aws_client.ts @@ -14,8 +14,7 @@ limitations under the License. */ -import { PluginServiceManagerContainer } from "./plugin_service_manager_container"; -import { PluginService, PluginServiceImpl } from "./plugin_service"; +import { PluginService } from "./plugin_service"; import { DatabaseDialect, DatabaseType } from "./database_dialect/database_dialect"; import { ConnectionUrlParser } from "./utils/connection_url_parser"; import { HostListProvider } from "./host_list_provider/host_list_provider"; @@ -23,26 +22,31 @@ import { PluginManager } from "./plugin_manager"; import pkgStream from "stream"; import { ClientWrapper } from "./client_wrapper"; -import { ConnectionProviderManager } from "./connection_provider_manager"; import { DefaultTelemetryFactory } from "./utils/telemetry/default_telemetry_factory"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { DriverDialect } from "./driver_dialect/driver_dialect"; import { WrapperProperties } from "./wrapper_property"; import { DriverConfigurationProfiles } from "./profile/driver_configuration_profiles"; import { ConfigurationProfile } from "./profile/configuration_profile"; -import { AwsWrapperError, TransactionIsolationLevel, ConnectionProvider } from "./"; +import { AwsWrapperError, ConnectionProvider, TransactionIsolationLevel } from "./"; import { Messages } from "./utils/messages"; import { HostListProviderService } from "./host_list_provider_service"; import { SessionStateClient } from "./session_state_client"; -import { DriverConnectionProvider } from "./driver_connection_provider"; +import { ServiceUtils } from "./utils/service_utils"; import { StorageService } from "./utils/storage/storage_service"; +import { MonitorService } from "./utils/monitoring/monitor_service"; import { CoreServicesContainer } from "./utils/core_services_container"; +import { FullServicesContainer } from "./utils/full_services_container"; +import { EventPublisher } from "./utils/events/event"; const { EventEmitter } = pkgStream; export abstract class AwsClient extends EventEmitter implements SessionStateClient { private _defaultPort: number = -1; + private readonly fullServiceContainer: FullServicesContainer; private readonly storageService: StorageService; + private readonly monitorService: MonitorService; + private readonly eventPublisher: EventPublisher; protected telemetryFactory: TelemetryFactory; protected pluginManager: PluginManager; protected pluginService: PluginService; @@ -67,7 +71,7 @@ export abstract class AwsClient extends EventEmitter implements SessionStateClie this.properties = new Map(Object.entries(config)); - this.storageService = CoreServicesContainer.getInstance().getStorageService(); + this.storageService = CoreServicesContainer.getInstance().storageService; const profileName = WrapperProperties.PROFILE_NAME.get(this.properties); if (profileName && profileName.length > 0) { @@ -103,22 +107,27 @@ export abstract class AwsClient extends EventEmitter implements SessionStateClie } } + const coreServicesContainer: CoreServicesContainer = CoreServicesContainer.getInstance(); + this.storageService = coreServicesContainer.storageService; + this.monitorService = coreServicesContainer.monitorService; + this.eventPublisher = coreServicesContainer.eventPublisher; this.telemetryFactory = new DefaultTelemetryFactory(this.properties); - const container = new PluginServiceManagerContainer(); - this.pluginService = new PluginServiceImpl( - container, + + this.fullServiceContainer = ServiceUtils.instance.createStandardServiceContainer( + this.storageService, + this.monitorService, + this.eventPublisher, this, + this.properties, dbType, knownDialectsByCode, - this.properties, - this._configurationProfile?.getDriverDialect() ?? driverDialect - ); - this.pluginManager = new PluginManager( - container, - this.properties, - new ConnectionProviderManager(connectionProvider ?? new DriverConnectionProvider(), WrapperProperties.CONNECTION_PROVIDER.get(this.properties)), - this.telemetryFactory + this._configurationProfile?.getDriverDialect() ?? driverDialect, + this.telemetryFactory, + connectionProvider ); + + this.pluginService = this.fullServiceContainer.pluginService; + this.pluginManager = this.fullServiceContainer.pluginManager; } private async setup() { @@ -130,12 +139,12 @@ export abstract class AwsClient extends EventEmitter implements SessionStateClie await this.setup(); const hostListProvider: HostListProvider = this.pluginService .getDialect() - .getHostListProvider(this.properties, this.properties.get("host"), (this.pluginService)); + .getHostListProvider(this.properties, this.properties.get("host"), this.fullServiceContainer); this.pluginService.setHostListProvider(hostListProvider); await this.pluginService.refreshHostList(); const initialHostInfo = this.pluginService.getInitialConnectionHostInfo(); if (initialHostInfo != null) { - await this.pluginManager.initHostProvider(initialHostInfo, this.properties, (this.pluginService)); + await this.pluginManager.initHostProvider(initialHostInfo, this.properties, this.fullServiceContainer.hostListProviderService); await this.pluginService.refreshHostList(); } } @@ -159,23 +168,23 @@ export abstract class AwsClient extends EventEmitter implements SessionStateClie abstract setReadOnly(readOnly: boolean): Promise; - abstract isReadOnly(): boolean; + abstract isReadOnly(): boolean | undefined; abstract setAutoCommit(autoCommit: boolean): Promise; - abstract getAutoCommit(): boolean; + abstract getAutoCommit(): boolean | undefined; abstract setTransactionIsolation(level: TransactionIsolationLevel): Promise; - abstract getTransactionIsolation(): TransactionIsolationLevel; + abstract getTransactionIsolation(): TransactionIsolationLevel | undefined; abstract setSchema(schema: any): Promise; - abstract getSchema(): string; + abstract getSchema(): string | undefined; abstract setCatalog(catalog: string): Promise; - abstract getCatalog(): string; + abstract getCatalog(): string | undefined; abstract end(): Promise; diff --git a/common/lib/connection_info.ts b/common/lib/connection_info.ts index 2d1951202..77b8e6507 100644 --- a/common/lib/connection_info.ts +++ b/common/lib/connection_info.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { ClientWrapper } from "./client_wrapper"; diff --git a/common/lib/connection_plugin_chain_builder.ts b/common/lib/connection_plugin_chain_builder.ts index 9c399dd7c..1cad14c72 100644 --- a/common/lib/connection_plugin_chain_builder.ts +++ b/common/lib/connection_plugin_chain_builder.ts @@ -43,6 +43,8 @@ import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_en import { ConfigurationProfile } from "./profile/configuration_profile"; import { HostMonitoring2PluginFactory } from "./plugins/efm2/host_monitoring2_plugin_factory"; import { BlueGreenPluginFactory } from "./plugins/bluegreen/blue_green_plugin_factory"; +import { GlobalDbFailoverPluginFactory } from "./plugins/gdb_failover/global_db_failover_plugin_factory"; +import { FullServicesContainer } from "./utils/full_services_container"; /* Type alias used for plugin factory sorting. It holds a reference to a plugin @@ -65,6 +67,7 @@ export class ConnectionPluginChainBuilder { ["readWriteSplitting", { factory: ReadWriteSplittingPluginFactory, weight: 600 }], ["failover", { factory: FailoverPluginFactory, weight: 700 }], ["failover2", { factory: Failover2PluginFactory, weight: 710 }], + ["gdbFailover", { factory: GlobalDbFailoverPluginFactory, weight: 720 }], ["efm", { factory: HostMonitoringPluginFactory, weight: 800 }], ["efm2", { factory: HostMonitoring2PluginFactory, weight: 810 }], ["fastestResponseStrategy", { factory: FastestResponseStrategyPluginFactory, weight: 900 }], @@ -86,6 +89,7 @@ export class ConnectionPluginChainBuilder { [ReadWriteSplittingPluginFactory, 600], [FailoverPluginFactory, 700], [Failover2PluginFactory, 710], + [GlobalDbFailoverPluginFactory, 720], [HostMonitoringPluginFactory, 800], [HostMonitoring2PluginFactory, 810], [LimitlessConnectionPluginFactory, 950], @@ -99,7 +103,7 @@ export class ConnectionPluginChainBuilder { ]); static async getPlugins( - pluginService: PluginService, + servicesContainer: FullServicesContainer, props: Map, connectionProviderManager: ConnectionProviderManager, configurationProfile: ConfigurationProfile | null @@ -162,10 +166,10 @@ export class ConnectionPluginChainBuilder { for (const pluginFactoryInfo of pluginFactoryInfoList) { const factoryObj = new pluginFactoryInfo.factory(); - plugins.push(await factoryObj.getInstance(pluginService, props)); + plugins.push(await factoryObj.getInstance(servicesContainer, props)); } - plugins.push(new DefaultPlugin(pluginService, connectionProviderManager)); + plugins.push(new DefaultPlugin(servicesContainer, connectionProviderManager)); return plugins; } diff --git a/common/lib/database_dialect/database_dialect.ts b/common/lib/database_dialect/database_dialect.ts index 22f2cf28a..2077934af 100644 --- a/common/lib/database_dialect/database_dialect.ts +++ b/common/lib/database_dialect/database_dialect.ts @@ -21,6 +21,7 @@ import { FailoverRestriction } from "../plugins/failover/failover_restriction"; import { ErrorHandler } from "../error_handler"; import { TransactionIsolationLevel } from "../utils/transaction_isolation_level"; import { HostRole } from "../host_role"; +import { FullServicesContainer } from "../utils/full_services_container"; export enum DatabaseType { MYSQL, @@ -41,7 +42,7 @@ export interface DatabaseDialect { getErrorHandler(): ErrorHandler; getHostRole(targetClient: ClientWrapper): Promise; isDialect(targetClient: ClientWrapper): Promise; - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider; + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider; isClientValid(targetClient: ClientWrapper): Promise; getDatabaseType(): DatabaseType; getDialectName(): string; diff --git a/common/lib/database_dialect/database_dialect_codes.ts b/common/lib/database_dialect/database_dialect_codes.ts index 4815f48ec..4351a6cb5 100644 --- a/common/lib/database_dialect/database_dialect_codes.ts +++ b/common/lib/database_dialect/database_dialect_codes.ts @@ -15,11 +15,13 @@ */ export class DatabaseDialectCodes { + static readonly GLOBAL_AURORA_MYSQL: string = "global-aurora-mysql"; static readonly AURORA_MYSQL: string = "aurora-mysql"; static readonly RDS_MYSQL: string = "rds-mysql"; static readonly MYSQL: string = "mysql"; // https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/multi-az-db-clusters-concepts.html static readonly RDS_MULTI_AZ_MYSQL: string = "rds-multi-az-mysql"; + static readonly GLOBAL_AURORA_PG: string = "global-aurora-pg"; static readonly AURORA_PG: string = "aurora-pg"; static readonly RDS_PG: string = "rds-pg"; // https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/multi-az-db-clusters-concepts.html diff --git a/common/lib/database_dialect/database_dialect_manager.ts b/common/lib/database_dialect/database_dialect_manager.ts index e8127f2a0..b462e60bd 100644 --- a/common/lib/database_dialect/database_dialect_manager.ts +++ b/common/lib/database_dialect/database_dialect_manager.ts @@ -95,6 +95,14 @@ export class DatabaseDialectManager implements DatabaseDialectProvider { if (this.dbType === DatabaseType.MYSQL) { const type = this.rdsHelper.identifyRdsType(host); + if (type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER) { + this.canUpdate = false; + this.dialectCode = DatabaseDialectCodes.GLOBAL_AURORA_MYSQL; + this.dialect = this.knownDialectsByCode.get(DatabaseDialectCodes.GLOBAL_AURORA_MYSQL); + this.logCurrentDialect(); + return this.dialect; + } + if (type.isRdsCluster) { this.canUpdate = true; this.dialectCode = DatabaseDialectCodes.AURORA_MYSQL; @@ -128,6 +136,14 @@ export class DatabaseDialectManager implements DatabaseDialectProvider { return this.dialect; } + if (type == RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER) { + this.canUpdate = false; + this.dialectCode = DatabaseDialectCodes.GLOBAL_AURORA_PG; + this.dialect = this.knownDialectsByCode.get(DatabaseDialectCodes.GLOBAL_AURORA_PG); + this.logCurrentDialect(); + return this.dialect; + } + if (type.isRdsCluster) { this.canUpdate = true; this.dialectCode = DatabaseDialectCodes.AURORA_PG; diff --git a/common/lib/database_dialect/topology_aware_database_dialect.ts b/common/lib/database_dialect/topology_aware_database_dialect.ts index cdd1929cf..df0553c8c 100644 --- a/common/lib/database_dialect/topology_aware_database_dialect.ts +++ b/common/lib/database_dialect/topology_aware_database_dialect.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { HostRole } from "../host_role"; import { ClientWrapper } from "../client_wrapper"; @@ -34,3 +34,7 @@ export interface TopologyAwareDatabaseDialect { export interface GlobalAuroraTopologyDialect extends TopologyAwareDatabaseDialect { getRegionByInstanceId(targetClient: ClientWrapper, instanceId: string): Promise; } + +export function isDialectTopologyAware(dialect: any): dialect is TopologyAwareDatabaseDialect { + return dialect; +} diff --git a/common/lib/driver_dialect/driver_dialect.ts b/common/lib/driver_dialect/driver_dialect.ts index 347805448..05bf05911 100644 --- a/common/lib/driver_dialect/driver_dialect.ts +++ b/common/lib/driver_dialect/driver_dialect.ts @@ -30,7 +30,7 @@ export interface DriverDialect { setConnectTimeout(props: Map, wrapperConnectTimeout?: any): void; - setQueryTimeout(props: Map, sql?: any, wrapperConnectTimeout?: any): void; + setQueryTimeout(props: Map, sql?: any, wrapperQueryTimeout?: any): void; setKeepAliveProperties(props: Map, keepAliveProps: any): void; diff --git a/common/lib/highest_weight_host_selector.ts b/common/lib/highest_weight_host_selector.ts index 370d68773..75d42d48c 100644 --- a/common/lib/highest_weight_host_selector.ts +++ b/common/lib/highest_weight_host_selector.ts @@ -26,7 +26,7 @@ export class HighestWeightHostSelector implements HostSelector { getHost(hosts: HostInfo[], role: HostRole, props?: Map): HostInfo { const eligibleHosts: HostInfo[] = hosts - .filter((host: HostInfo) => host.role === role && host.availability === HostAvailability.AVAILABLE) + .filter((host: HostInfo) => (role === null || host.role === role) && host.availability === HostAvailability.AVAILABLE) .sort((hostA: HostInfo, hostB: HostInfo) => (hostA.weight > hostB.weight ? -1 : hostA.weight < hostB.weight ? 1 : 0)); if (eligibleHosts.length === 0) { diff --git a/common/lib/plugin_service_manager_container.ts b/common/lib/host_availability/host_availability_cache_item.ts similarity index 50% rename from common/lib/plugin_service_manager_container.ts rename to common/lib/host_availability/host_availability_cache_item.ts index e82728227..c068f2de3 100644 --- a/common/lib/plugin_service_manager_container.ts +++ b/common/lib/host_availability/host_availability_cache_item.ts @@ -14,26 +14,12 @@ limitations under the License. */ -import { PluginService } from "./plugin_service"; -import { PluginManager } from "./plugin_manager"; +import { HostAvailability } from "./host_availability"; -export class PluginServiceManagerContainer { - private _pluginService?: PluginService | null; - private _pluginManager?: PluginManager | null; +export class HostAvailabilityCacheItem { + readonly availability: HostAvailability; - get pluginService(): PluginService | null { - return this._pluginService ?? null; - } - - set pluginService(service: PluginService | null) { - this._pluginService = service; - } - - get pluginManager(): PluginManager | null { - return this._pluginManager ?? null; - } - - set pluginManager(service: PluginManager | null) { - this._pluginManager = service; + constructor(availability: HostAvailability) { + this.availability = availability; } } diff --git a/common/lib/host_list_provider/aurora_topology_utils.ts b/common/lib/host_list_provider/aurora_topology_utils.ts new file mode 100644 index 000000000..5a16ac359 --- /dev/null +++ b/common/lib/host_list_provider/aurora_topology_utils.ts @@ -0,0 +1,68 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { TopologyQueryResult, TopologyUtils } from "./topology_utils"; +import { ClientWrapper } from "../client_wrapper"; +import { DatabaseDialect } from "../database_dialect/database_dialect"; +import { HostInfo } from "../host_info"; +import { isDialectTopologyAware } from "../database_dialect/topology_aware_database_dialect"; +import { Messages } from "../utils/messages"; + +/** + * TopologyUtils implementation for Aurora clusters using a single HostInfo template. + */ +export class AuroraTopologyUtils extends TopologyUtils { + async queryForTopology( + targetClient: ClientWrapper, + dialect: DatabaseDialect, + initialHost: HostInfo, + clusterInstanceTemplate: HostInfo + ): Promise { + if (!isDialectTopologyAware(dialect)) { + throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); + } + + return await dialect + .queryForTopology(targetClient) + .then((res: TopologyQueryResult[]) => this.verifyWriter(this.createHosts(res, initialHost, clusterInstanceTemplate))); + } + + public createHosts(topologyQueryResults: TopologyQueryResult[], initialHost: HostInfo, clusterInstanceTemplate: HostInfo): HostInfo[] { + const hostsMap = new Map(); + topologyQueryResults.forEach((row) => { + const lastUpdateTime = row.lastUpdateTime ?? Date.now(); + + const host = this.createHost( + row.id, + row.host, + row.isWriter, + row.weight, + lastUpdateTime, + initialHost, + clusterInstanceTemplate, + row.endpoint, + row.port + ); + + const existing = hostsMap.get(host.host); + if (!existing || existing.lastUpdateTime < host.lastUpdateTime) { + hostsMap.set(host.host, host); + } + }); + + return Array.from(hostsMap.values()); + } +} diff --git a/common/lib/host_list_provider/connection_string_host_list_provider.ts b/common/lib/host_list_provider/connection_string_host_list_provider.ts index 0885d9b3c..a7973111a 100644 --- a/common/lib/host_list_provider/connection_string_host_list_provider.ts +++ b/common/lib/host_list_provider/connection_string_host_list_provider.ts @@ -65,17 +65,12 @@ export class ConnectionStringHostListProvider implements StaticHostListProvider this.isInitialized = true; } - refresh(): Promise; - refresh(client: ClientWrapper): Promise; - refresh(client?: ClientWrapper): Promise; - refresh(client?: ClientWrapper | undefined): Promise { + refresh(): Promise { this.init(); return Promise.resolve(this.hostList); } - forceRefresh(): Promise; - forceRefresh(client: ClientWrapper): Promise; - forceRefresh(client?: ClientWrapper): Promise { + forceRefresh(): Promise { this.init(); return Promise.resolve(this.hostList); } @@ -89,7 +84,7 @@ export class ConnectionStringHostListProvider implements StaticHostListProvider return null; } const instance = await this.hostListProviderService.getDialect().getHostAliasAndParseResults(client.client); - const topology = await this.refresh(client.client); + const topology = await this.refresh(); if (!topology || topology.length == 0) { return null; } @@ -97,19 +92,6 @@ export class ConnectionStringHostListProvider implements StaticHostListProvider return topology.filter((hostInfo) => instance === hostInfo.hostId)[0]; } - createHost(host: string, isWriter: boolean, weight: number, lastUpdateTime: number, port?: number): HostInfo { - return this.hostListProviderService - .getHostInfoBuilder() - .withHost(host ?? "") - .withPort(port ?? this.initialPort) - .withRole(isWriter ? HostRole.WRITER : HostRole.READER) - .withAvailability(HostAvailability.AVAILABLE) - .withWeight(weight) - .withLastUpdateTime(lastUpdateTime) - .withHostId(host) - .build(); - } - getHostProviderType(): string { return this.constructor.name; } diff --git a/common/lib/host_list_provider/global_aurora_host_list_provider.ts b/common/lib/host_list_provider/global_aurora_host_list_provider.ts new file mode 100644 index 000000000..7b17e3a5f --- /dev/null +++ b/common/lib/host_list_provider/global_aurora_host_list_provider.ts @@ -0,0 +1,71 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { RdsHostListProvider } from "./rds_host_list_provider"; +import { FullServicesContainer } from "../utils/full_services_container"; +import { HostInfo } from "../host_info"; +import { WrapperProperties } from "../wrapper_property"; +import { ClusterTopologyMonitor, ClusterTopologyMonitorImpl } from "./monitoring/cluster_topology_monitor"; +import { GlobalAuroraTopologyMonitor } from "./monitoring/global_aurora_topology_monitor"; +import { MonitorInitializer } from "../utils/monitoring/monitor"; +import { ClientWrapper } from "../client_wrapper"; +import { DatabaseDialect } from "../database_dialect/database_dialect"; +import { parseInstanceTemplates } from "../utils/utils"; + +export class GlobalAuroraHostListProvider extends RdsHostListProvider { + protected instanceTemplatesByRegion: Map; + protected override initSettings(): void { + super.initSettings(); + + const instanceTemplates = WrapperProperties.GLOBAL_CLUSTER_INSTANCE_HOST_PATTERNS.get(this.properties); + this.instanceTemplatesByRegion = parseInstanceTemplates( + instanceTemplates, + (hostPattern: string) => this.validateHostPatternSetting(hostPattern), + () => this.hostListProviderService.getHostInfoBuilder() + ); + } + + protected override async getOrCreateMonitor(): Promise { + const initializer: MonitorInitializer = { + createMonitor: (servicesContainer: FullServicesContainer): ClusterTopologyMonitor => { + return new GlobalAuroraTopologyMonitor( + servicesContainer, + this.topologyUtils, + this.clusterId, + this.initialHost, + this.properties, + this.clusterInstanceTemplate, + this.refreshRateNano, + this.highRefreshRateNano, + this.instanceTemplatesByRegion + ); + } + }; + + return await this.servicesContainers.monitorService.runIfAbsent( + ClusterTopologyMonitorImpl, + this.clusterId, + this.servicesContainers, + this.properties, + initializer + ); + } + + override async getCurrentTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { + this.init(); + return await this.topologyUtils.queryForTopology(targetClient, dialect, this.initialHost, this.instanceTemplatesByRegion); + } +} diff --git a/common/lib/host_list_provider/global_topology_utils.ts b/common/lib/host_list_provider/global_topology_utils.ts index 7e9131db8..7b197dc15 100644 --- a/common/lib/host_list_provider/global_topology_utils.ts +++ b/common/lib/host_list_provider/global_topology_utils.ts @@ -1,45 +1,40 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { TopologyQueryResult, TopologyUtils } from "./topology_utils"; import { ClientWrapper } from "../client_wrapper"; import { DatabaseDialect } from "../database_dialect/database_dialect"; import { HostInfo } from "../host_info"; -import { isDialectTopologyAware } from "../utils/utils"; +import { isDialectTopologyAware } from "../database_dialect/topology_aware_database_dialect"; import { Messages } from "../utils/messages"; import { AwsWrapperError } from "../utils/errors"; -export class GlobalTopologyUtils extends TopologyUtils { - async queryForTopology( - targetClient: ClientWrapper, - dialect: DatabaseDialect, - initialHost: HostInfo, - clusterInstanceTemplate: HostInfo - ): Promise { - throw new AwsWrapperError("Not implemented"); - } +export interface GdbTopologyUtils { + getRegion(instanceId: string, targetClient: ClientWrapper, dialect: DatabaseDialect): Promise; +} - async queryForTopologyWithRegion( +export class GlobalTopologyUtils extends TopologyUtils implements GdbTopologyUtils { + async queryForTopology( targetClient: ClientWrapper, dialect: DatabaseDialect, initialHost: HostInfo, instanceTemplateByRegion: Map ): Promise { if (!isDialectTopologyAware(dialect)) { - throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); + throw new AwsWrapperError(Messages.get("RdsHostListProvider.incorrectDialect")); } return await dialect @@ -47,6 +42,16 @@ export class GlobalTopologyUtils extends TopologyUtils { .then((res: TopologyQueryResult[]) => this.verifyWriter(this.createHostsWithTemplateMap(res, initialHost, instanceTemplateByRegion))); } + async getRegion(instanceId: string, targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { + if (!isDialectTopologyAware(dialect)) { + throw new AwsWrapperError(Messages.get("RdsHostListProvider.incorrectDialect")); + } + + const results = await dialect.queryForTopology(targetClient); + const match = results.find((row) => row.id === instanceId); + return match?.awsRegion ?? null; + } + private createHostsWithTemplateMap( topologyQueryResults: TopologyQueryResult[], initialHost: HostInfo, diff --git a/common/lib/host_list_provider/host_list_provider.ts b/common/lib/host_list_provider/host_list_provider.ts index a8ea22f49..b7e1005a2 100644 --- a/common/lib/host_list_provider/host_list_provider.ts +++ b/common/lib/host_list_provider/host_list_provider.ts @@ -19,25 +19,13 @@ import { HostRole } from "../host_role"; import { DatabaseDialect } from "../database_dialect/database_dialect"; import { ClientWrapper } from "../client_wrapper"; -export type DynamicHostListProvider = HostListProvider; - export type StaticHostListProvider = HostListProvider; -export interface BlockingHostListProvider extends HostListProvider { - forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise; - - clearAll(): Promise; -} - export interface HostListProvider { refresh(): Promise; - refresh(client: ClientWrapper): Promise; - forceRefresh(): Promise; - forceRefresh(client: ClientWrapper): Promise; - getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise; identifyConnection(targetClient: ClientWrapper): Promise; @@ -46,3 +34,7 @@ export interface HostListProvider { getClusterId(): string; } + +export interface DynamicHostListProvider extends HostListProvider { + forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise; +} diff --git a/common/lib/host_list_provider/monitoring/cluster_topology_monitor.ts b/common/lib/host_list_provider/monitoring/cluster_topology_monitor.ts index 2808c55a1..e9b4dcfd3 100644 --- a/common/lib/host_list_provider/monitoring/cluster_topology_monitor.ts +++ b/common/lib/host_list_provider/monitoring/cluster_topology_monitor.ts @@ -17,50 +17,76 @@ import { HostInfo } from "../../host_info"; import { PluginService } from "../../plugin_service"; import { HostAvailability } from "../../host_availability/host_availability"; -import { logTopology, sleep } from "../../utils/utils"; +import { convertMsToNanos, convertNanosToMs, getTimeInNanos, logTopology, sleep } from "../../utils/utils"; import { logger } from "../../../logutils"; import { HostRole } from "../../host_role"; import { ClientWrapper } from "../../client_wrapper"; -import { AwsWrapperError } from "../../utils/errors"; -import { MonitoringRdsHostListProvider } from "./monitoring_host_list_provider"; +import { AwsTimeoutError, AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; -import { CoreServicesContainer } from "../../utils/core_services_container"; import { Topology } from "../topology"; import { StorageService } from "../../utils/storage/storage_service"; import { TopologyUtils } from "../topology_utils"; import { RdsUtils } from "../../utils/rds_utils"; - -export interface ClusterTopologyMonitor { +import { AbstractMonitor, Monitor } from "../../utils/monitoring/monitor"; +import { FullServicesContainer } from "../../utils/full_services_container"; +import { HostListProviderService } from "../../host_list_provider_service"; +import { Event, EventSubscriber } from "../../utils/events/event"; +import { MonitorResetEvent } from "../../utils/events/monitor_reset_event"; +import { ServiceUtils } from "../../utils/service_utils"; +import { WrapperProperties } from "../../wrapper_property"; + +export interface ClusterTopologyMonitor extends Monitor, EventSubscriber { forceRefresh(client: ClientWrapper, timeoutMs: number): Promise; close(): Promise; - forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise; + /** + * Initiates a topology update. + * + * @param verifyTopology defines whether extra measures should be taken to verify the topology. If false, the + * method will return as soon as topology is successfully retrieved from any instance. If + * true, extra steps are taken to verify the topology is accurate. + * @param timeoutMs timeout in msec to wait until the topology gets refreshed (if verifyWriter has a value of + * false) or verified (if verifyTopology has a value of true). + * @return true if successful, false if unsuccessful or the timeout is reached + * @throws AwsWrapperError if wrapper timed out while fetching the topology. + */ + forceMonitoringRefresh(verifyTopology: boolean, timeoutMs: number): Promise; + + canDispose(): boolean; } -export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { - private static readonly TOPOLOGY_CACHE_EXPIRATION_NANOS: number = 5 * 60 * 1_000_000_000; // 5 minutes. - static readonly MONITORING_PROPERTY_PREFIX: string = "topology_monitoring_"; +export class ClusterTopologyMonitorImpl extends AbstractMonitor implements ClusterTopologyMonitor { + private static readonly MONITOR_TERMINATION_TIMEOUT_SEC: number = 30; + private static readonly STABLE_TOPOLOGIES_DURATION_NS: bigint = convertMsToNanos(15000); // 15 seconds. + protected static readonly DEFAULT_CONNECTION_TIMEOUT_MS: number = 5000; + protected static readonly DEFAULT_QUERY_TIMEOUT_MS: number = 5000; private readonly clusterId: string; - private readonly initialHostInfo: HostInfo; + protected readonly initialHostInfo: HostInfo; + private readonly servicesContainer: FullServicesContainer; private readonly _monitoringProperties: Map; private readonly _pluginService: PluginService; - private readonly _hostListProvider: MonitoringRdsHostListProvider; - private readonly refreshRateMs: number; - private readonly highRefreshRateMs: number; + protected readonly hostListProviderService: HostListProviderService; + private readonly refreshRateNs: number; + private readonly highRefreshRateNs: number; private readonly storageService: StorageService; - private readonly topologyUtils: TopologyUtils; private readonly rdsUtils: RdsUtils = new RdsUtils(); - private readonly instanceTemplate: HostInfo; + protected readonly instanceTemplate: HostInfo; - private writerHostInfo: HostInfo = null; + private writerHostInfo: HostInfo | null = null; private isVerifiedWriterConnection: boolean = false; - private monitoringClient: ClientWrapper = null; - private highRefreshRateEndTimeMs: number = -1; - private highRefreshPeriodAfterPanicMs: number = 30000; // 30 seconds. - private ignoreNewTopologyRequestsEndTimeMs: number = -1; - private ignoreTopologyRequestMs: number = 10000; // 10 seconds. + private monitoringClient: ClientWrapper | null = null; + private highRefreshRateEndTimeNs: bigint = BigInt(0); + + public readonly topologyUtils: TopologyUtils; + public readonly readerTopologiesById: Map = new Map(); + public readonly completedOneCycle: Map = new Map(); + // When comparing topologies, we don't want to check HostInfo.weight, which is used in HostInfo#equals. + // We use this function to compare the other fields. + protected readonly hostInfoExtractor = (host: HostInfo): string => { + return `${host.host}:${host.port}:${host.availability}:${host.role}`; + }; // Tracking of the host monitors. private hostMonitors: Map = new Map(); @@ -70,46 +96,50 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { public hostMonitorsLatestTopology: HostInfo[] = []; // Controls for stopping asynchronous monitoring tasks. - private stopMonitoring: boolean = false; public hostMonitorsStop: boolean = false; - private untrackedPromises: Promise[] = []; // Signals to other methods that asynchronous tasks have completed/should be completed. private requestToUpdateTopology: boolean = false; + private submittedHosts: Map> = new Map(); + private stableTopologiesStartNs: bigint; constructor( + servicesContainer: FullServicesContainer, topologyUtils: TopologyUtils, clusterId: string, initialHostInfo: HostInfo, props: Map, instanceTemplate: HostInfo, - pluginService: PluginService, - hostListProvider: MonitoringRdsHostListProvider, - refreshRateMs: number, - highRefreshRateMs: number + refreshRateNs: number, + highRefreshRateNs: number ) { + super(ClusterTopologyMonitorImpl.MONITOR_TERMINATION_TIMEOUT_SEC); this.topologyUtils = topologyUtils; this.clusterId = clusterId; - this.storageService = CoreServicesContainer.getInstance().getStorageService(); // TODO: store serviceContainer instead this.initialHostInfo = initialHostInfo; this.instanceTemplate = instanceTemplate; - this._pluginService = pluginService; - this._hostListProvider = hostListProvider; - this.refreshRateMs = refreshRateMs; - this.highRefreshRateMs = highRefreshRateMs; + this.servicesContainer = servicesContainer; + this.storageService = this.servicesContainer.storageService; + this._pluginService = this.servicesContainer.pluginService; + this.hostListProviderService = this.servicesContainer.hostListProviderService; + this.refreshRateNs = refreshRateNs; + this.highRefreshRateNs = highRefreshRateNs; this._monitoringProperties = new Map(props); for (const [key, val] of props) { - if (key.startsWith(ClusterTopologyMonitorImpl.MONITORING_PROPERTY_PREFIX)) { - this._monitoringProperties.set(key.substring(ClusterTopologyMonitorImpl.MONITORING_PROPERTY_PREFIX.length), val); + if (key.startsWith(WrapperProperties.TOPOLOGY_MONITORING_PROPERTY_PREFIX)) { + this._monitoringProperties.set(key.substring(WrapperProperties.TOPOLOGY_MONITORING_PROPERTY_PREFIX.length), val); this._monitoringProperties.delete(key); } } - this.untrackedPromises.push(this.run()); - } - get hostListProvider(): MonitoringRdsHostListProvider { - return this._hostListProvider; + const connectTimeout = + this._monitoringProperties.get(WrapperProperties.WRAPPER_CONNECT_TIMEOUT.name) ?? ClusterTopologyMonitorImpl.DEFAULT_CONNECTION_TIMEOUT_MS; + const queryTimeout = + this._monitoringProperties.get(WrapperProperties.WRAPPER_QUERY_TIMEOUT.name) ?? ClusterTopologyMonitorImpl.DEFAULT_QUERY_TIMEOUT_MS; + const driverDialect = this._pluginService.getDriverDialect(); + driverDialect.setConnectTimeout(this._monitoringProperties, connectTimeout); + driverDialect.setQueryTimeout(this._monitoringProperties, undefined, queryTimeout); } get pluginService(): PluginService { @@ -121,27 +151,35 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { } async close(): Promise { - this.stopMonitoring = true; this.hostMonitorsStop = true; this.requestToUpdateTopology = true; - await Promise.all(this.untrackedPromises); - await this.closeConnection(this.monitoringClient); - await this.closeConnection(this.hostMonitorsWriterClient); - await this.closeConnection(this.hostMonitorsReaderClient); - this.untrackedPromises = []; + await Promise.all(this.submittedHosts.values()); + + const monitoringClientToClose = this.monitoringClient; + const hostMonitorsWriterClientToClose = this.hostMonitorsWriterClient; + const hostMonitorsReaderClientToClose = this.hostMonitorsReaderClient; + + this.monitoringClient = null; + this.hostMonitorsWriterClient = null; + this.hostMonitorsReaderClient = null; + + await this.closeConnection(monitoringClientToClose); + if (hostMonitorsWriterClientToClose && hostMonitorsWriterClientToClose !== monitoringClientToClose) { + await this.closeConnection(hostMonitorsWriterClientToClose); + } + if ( + hostMonitorsReaderClientToClose && + hostMonitorsReaderClientToClose !== monitoringClientToClose && + hostMonitorsReaderClientToClose !== hostMonitorsWriterClientToClose + ) { + await this.closeConnection(hostMonitorsReaderClientToClose); + } + + this.submittedHosts.clear(); this.hostMonitors.clear(); } async forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise { - if (Date.now() < this.ignoreNewTopologyRequestsEndTimeMs) { - // Previous failover has just completed, use results without triggering new update. - const currentHosts = this.getStoredHosts(); - if (currentHosts !== null) { - logger.info(Messages.get("ClusterTopologyMonitoring.ignoringNewTopologyRequest")); - return currentHosts; - } - } - if (shouldVerifyWriter) { this.isVerifiedWriterConnection = false; if (this.monitoringClient) { @@ -157,15 +195,16 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { async forceRefresh(client: ClientWrapper, timeoutMs: number): Promise { if (this.isVerifiedWriterConnection) { + // Get the monitoring task to refresh the topology using a verified connection. return await this.waitTillTopologyGetsUpdated(timeoutMs); } - // Otherwise use provided unverified connection to update topology. + // Otherwise, use the provided unverified connection to update the topology. return await this.fetchTopologyAndUpdateCache(client); } async waitTillTopologyGetsUpdated(timeoutMs: number): Promise { - // Signal to any monitor that might be in delay, that topology should be updated. + // Notify the monitoring task, which may be sleeping, that topology should be refreshed immediately. this.requestToUpdateTopology = true; const currentHosts: HostInfo[] = this.getStoredHosts(); @@ -183,7 +222,7 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { } if (Date.now() >= endTime) { - throw new AwsWrapperError(Messages.get("ClusterTopologyMonitor.timeoutError", timeoutMs.toString())); + throw new AwsTimeoutError(Messages.get("ClusterTopologyMonitor.timeoutError", timeoutMs.toString())); } return latestHosts; } @@ -194,7 +233,7 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { } try { - const hosts: HostInfo[] = await this._hostListProvider.sqlQueryForTopology(client); + const hosts: HostInfo[] = await this.queryForTopology(client); if (hosts) { this.updateTopologyCache(hosts); } @@ -206,12 +245,11 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { } private async openAnyClientAndUpdateTopology(): Promise { - let writerVerifiedByThisTask = false; if (!this.monitoringClient) { let client: ClientWrapper; try { - client = await this._pluginService.forceConnect(this.initialHostInfo, this._monitoringProperties); - } catch { + client = await this.servicesContainer.pluginService.forceConnect(this.initialHostInfo, this._monitoringProperties); + } catch (connectError) { // Unable to connect to host; return null; } @@ -226,7 +264,6 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { if (this.rdsUtils.isRdsInstance(this.initialHostInfo.host)) { this.writerHostInfo = this.initialHostInfo; logger.info(Messages.get("ClusterTopologyMonitor.writerMonitoringConnection", this.writerHostInfo.host)); - writerVerifiedByThisTask = true; } else { const pair: [string, string] = await this.topologyUtils.getInstanceId(this.monitoringClient); const instanceTemplate: HostInfo = await this.getInstanceTemplate(pair[1], this.monitoringClient); @@ -238,20 +275,13 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { // Do nothing. logger.error(Messages.get("ClusterTopologyMonitor.invalidWriterQuery", error?.message)); } - } else { + } else if (client) { // Monitoring connection already set by another task, close the new connection. await this.closeConnection(client); } } const hosts: HostInfo[] = await this.fetchTopologyAndUpdateCache(this.monitoringClient); - if (writerVerifiedByThisTask) { - if (this.ignoreNewTopologyRequestsEndTimeMs === -1) { - this.ignoreNewTopologyRequestsEndTimeMs = 0; - } else { - this.ignoreNewTopologyRequestsEndTimeMs = Date.now() + this.ignoreTopologyRequestMs; - } - } if (hosts === null) { this.isVerifiedWriterConnection = false; @@ -264,294 +294,530 @@ export class ClusterTopologyMonitorImpl implements ClusterTopologyMonitor { return Promise.resolve(this.instanceTemplate); } + queryForTopology(client: ClientWrapper): Promise { + return this.topologyUtils.queryForTopology(client, this.pluginService.getDialect(), this.initialHostInfo, this.instanceTemplate); + } + + updateHostsAvailability(hosts: HostInfo[]): void { + if (!hosts) { + return; + } + + hosts.forEach((host) => { + host.setAvailability(this.readerTopologiesById.has(host.hostId) ? HostAvailability.AVAILABLE : HostAvailability.NOT_AVAILABLE); + }); + } + updateTopologyCache(hosts: HostInfo[]): void { this.storageService.set(this.clusterId, new Topology(hosts)); this.requestToUpdateTopology = false; } - async getWriterHostIdIfConnected(client: ClientWrapper, hostId: string): Promise { - const writerHost: string = await this.hostListProvider.getWriterId(client); - // Returns the hostId of the writer if client is connected to that writer, otherwise returns null. - return writerHost === hostId ? writerHost : null; + protected clearTopologyCache(): void { + this.servicesContainer.storageService.remove(Topology, this.clusterId); } - async closeConnection(client: ClientWrapper): Promise { - if (client !== null) { - await client.abort(); - client = null; - } + async closeConnection(client: ClientWrapper | null): Promise { + await client?.abort(); } async updateMonitoringClient(newClient: ClientWrapper | null): Promise { const clientToClose = this.monitoringClient; this.monitoringClient = newClient; - if (clientToClose) { - await clientToClose.abort(); - } + await clientToClose?.abort(); } - private isInPanicMode(): boolean { - return !this.monitoringClient || !this.isVerifiedWriterConnection; + async stop(): Promise { + this._stop = true; + this.hostMonitorsStop = true; + + await Promise.all(this.submittedHosts.values()); + + await this.closeHostMonitors(); + + const hostMonitorsWriterClientToClose = this.hostMonitorsWriterClient; + const hostMonitorsReaderClientToClose = this.hostMonitorsReaderClient; + const monitoringClientToClose = this.monitoringClient; + + this.hostMonitorsWriterClient = null; + this.hostMonitorsReaderClient = null; + this.monitoringClient = null; + + await this.closeConnection(hostMonitorsWriterClientToClose); + if (hostMonitorsReaderClientToClose && hostMonitorsReaderClientToClose !== hostMonitorsWriterClientToClose) { + await this.closeConnection(hostMonitorsReaderClientToClose); + } + if ( + monitoringClientToClose && + monitoringClientToClose !== hostMonitorsWriterClientToClose && + monitoringClientToClose !== hostMonitorsReaderClientToClose + ) { + await this.closeConnection(monitoringClientToClose); + } + + this.submittedHosts.clear(); + + return super.stop(); } - async run(): Promise { - logger.debug(Messages.get("ClusterTopologyMonitor.startMonitoring")); + async monitor(): Promise { try { - while (!this.stopMonitoring) { + logger.debug(Messages.get("ClusterTopologyMonitor.startMonitoring", this.clusterId, this.initialHostInfo.host)); + this.servicesContainer.eventPublisher.subscribe(this, new Set([MonitorResetEvent])); + + while (!this._stop) { + this.lastActivityTimestampNanos = getTimeInNanos(); + if (this.isInPanicMode()) { - // Panic Mode: high refresh rate in effect. + if (this.submittedHosts.size === 0) { + logger.debug(Messages.get("ClusterTopologyMonitor.startingHostMonitoringTasks")); - if (this.hostMonitors.size === 0) { - // Initialize host tasks. - logger.debug(Messages.get("ClusterTopologyMonitor.startingHostMonitors")); + // Start host monitoring tasks. this.hostMonitorsStop = false; - if (this.hostMonitorsReaderClient !== null) { - await this.closeConnection(this.hostMonitorsReaderClient); - } - if (this.hostMonitorsWriterClient !== null) { - await this.closeConnection(this.hostMonitorsWriterClient); - } - this.hostMonitorsWriterClient = null; - this.hostMonitorsReaderClient = null; + await this.hostMonitorClientCleanUp(); this.hostMonitorsWriterInfo = null; this.hostMonitorsLatestTopology = []; - // Use any client to gather topology information. let hosts: HostInfo[] = this.getStoredHosts(); - if (!hosts) { + if (hosts === null) { + // Use any available connection to get the topology. hosts = await this.openAnyClientAndUpdateTopology(); } - // Set up host monitors. - if (hosts && !this.isVerifiedWriterConnection) { - for (const hostInfo of hosts) { - if (!this.hostMonitors.has(hostInfo.host)) { - const hostMonitor = new HostMonitor(this, hostInfo, this.writerHostInfo); - const hostRun = hostMonitor.run(); - this.hostMonitors.set(hostInfo.host, hostMonitor); - this.untrackedPromises.push(hostRun); - } + await this.closeHostMonitors(); + + if (!(hosts !== null && !this.isVerifiedWriterConnection)) { + await this.delay(true); + continue; + } + + for (const hostInfo of hosts) { + if (!this.submittedHosts.get(hostInfo.host)) { + const minimalServiceContainer = ServiceUtils.instance.createMinimalServiceContainerFrom( + this.servicesContainer, + this._monitoringProperties + ); + await minimalServiceContainer.pluginManager.init(); + const hostMonitor = new HostMonitor(minimalServiceContainer, this, hostInfo, this.writerHostInfo); + const promise = hostMonitor.run(); + this.submittedHosts.set(hostInfo.host, promise); } } - // If topology is not correctly updated, will try on the next round. + + // We will try again in the next iteration. } else { - // Host monitors already running, check if a writer has been detected. - const writerClient = this.hostMonitorsWriterClient; - const writerHostInfo = this.hostMonitorsWriterInfo; - if (writerClient && writerHostInfo && writerHostInfo !== this.writerHostInfo) { - // Writer detected, update monitoringClient. - logger.info(Messages.get("ClusterTopologyMonitor.writerPickedUpFromHostMonitors", writerHostInfo.hostId)); - await this.updateMonitoringClient(writerClient); - this.writerHostInfo = writerHostInfo; + // The host monitors are running, so we check if the writer has been detected. + const writerClient: ClientWrapper | null = this.hostMonitorsWriterClient; + const writerClientHostInfo: HostInfo | null = this.hostMonitorsWriterInfo; + + if (writerClient && writerClientHostInfo) { + logger.debug(Messages.get("ClusterTopologyMonitor.writerPickedUpFromHostMonitors", writerClientHostInfo.toString())); + + this.monitoringClient = writerClient; + this.writerHostInfo = writerClientHostInfo; this.isVerifiedWriterConnection = true; - if (this.ignoreNewTopologyRequestsEndTimeMs === -1) { - this.ignoreNewTopologyRequestsEndTimeMs = 0; - } else { - this.ignoreNewTopologyRequestsEndTimeMs = Date.now() + this.ignoreTopologyRequestMs; - } - if (this.highRefreshRateEndTimeMs === -1) { - this.highRefreshRateEndTimeMs = 0; - } else { - this.highRefreshRateEndTimeMs = Date.now() + this.highRefreshPeriodAfterPanicMs; - } + this.highRefreshRateEndTimeNs = getTimeInNanos() + BigInt(this.highRefreshRateNs); - // Stop monitoring of each host, writer detected. this.hostMonitorsStop = true; - this.hostMonitors.clear(); + await this.closeHostMonitors(); + this.submittedHosts.clear(); + this.stableTopologiesStartNs = BigInt(0); + this.readerTopologiesById.clear(); + this.completedOneCycle.clear(); + + await this.delay(true); continue; } else { - // No writer detected, update host monitors with any new hosts in the topology. - const hosts: HostInfo[] = this.hostMonitorsLatestTopology; - if (hosts !== null && !this.hostMonitorsStop) { + // Update host monitors with the new instances in the topology. + const hosts: HostInfo[] | null = this.hostMonitorsLatestTopology; + if (hosts && !this.hostMonitorsStop) { for (const hostInfo of hosts) { - if (!this.hostMonitors.has(hostInfo.host)) { - const hostMonitor = new HostMonitor(this, hostInfo, this.writerHostInfo); - const hostRun = hostMonitor.run(); - this.hostMonitors.set(hostInfo.host, hostMonitor); - this.untrackedPromises.push(hostRun); + if (!this.submittedHosts.get(hostInfo.host)) { + const minimalServiceContainer = ServiceUtils.instance.createMinimalServiceContainerFrom( + this.servicesContainer, + this._monitoringProperties + ); + await minimalServiceContainer.pluginManager.init(); + // Intentionally not calling await on hostMonitor.run(). + const hostMonitor = new HostMonitor(minimalServiceContainer, this, hostInfo, this.writerHostInfo); + const promise = hostMonitor.run(); + this.submittedHosts.set(hostInfo.host, promise); } } } } } - // Trigger a delay before retrying. + + this.checkForStableReaderTopologies(); await this.delay(true); } else { - // Regular mode: lower refresh rate than panic mode. - if (this.hostMonitors.size !== 0) { - // Stop host monitors. - this.hostMonitorsStop = true; - this.hostMonitors.clear(); + // We are in regular mode. + if (this.submittedHosts.size !== 0) { + await this.closeHostMonitors(); + this.submittedHosts.clear(); + this.stableTopologiesStartNs = BigInt(0); + this.readerTopologiesById.clear(); + this.completedOneCycle.clear(); } - const hosts = await this.fetchTopologyAndUpdateCache(this.monitoringClient); + + const hosts: HostInfo[] = await this.fetchTopologyAndUpdateCache(this.monitoringClient); if (hosts === null) { - // Unable to gather topology, switch to panic mode. + // Attempt to fetch topology failed, so we switch to panic mode. + const clientToClose = this.monitoringClient; + this.monitoringClient = null; + await this.closeConnection(clientToClose); this.isVerifiedWriterConnection = false; - await this.updateMonitoringClient(null); + this.writerHostInfo = null; + await this.delay(false); continue; } - if (this.highRefreshRateEndTimeMs > 0 && Date.now() > this.highRefreshRateEndTimeMs) { - this.highRefreshRateEndTimeMs = 0; + + if (this.highRefreshRateEndTimeNs > 0 && getTimeInNanos() > this.highRefreshRateEndTimeNs) { + this.highRefreshRateEndTimeNs = BigInt(0); } - if (this.highRefreshRateEndTimeMs < 0) { - // Log topology when not in high refresh rate. - this.logTopology(`[clusterTopologyMonitor] `); + + // We avoid logging the topology while using the high refresh rate because it is too noisy. + if (this.highRefreshRateEndTimeNs === BigInt(0)) { + logger.debug(logTopology(this.getStoredHosts(), "")); } - // Set an easily interruptible delay between topology refreshes. + await this.delay(false); } - if (this.ignoreNewTopologyRequestsEndTimeMs > 0 && Date.now() > this.ignoreNewTopologyRequestsEndTimeMs) { - this.ignoreNewTopologyRequestsEndTimeMs = 0; - } } - } catch (error) { - logger.error(Messages.get("ClusterTopologyMonitor.errorDuringMonitoring", error?.message)); } finally { - this.stopMonitoring = true; - await this.updateMonitoringClient(null); - logger.debug(Messages.get("ClusterTopologyMonitor.endMonitoring")); + this._stop = true; + await this.closeHostMonitors(); + await this.hostMonitorClientCleanUp(); + + this.servicesContainer.eventPublisher.unsubscribe(this, new Set([MonitorResetEvent])); + + logger.debug(Messages.get("ClusterTopologyMonitor.stopHostMonitoringTask", this.initialHostInfo.host)); } + + return Promise.resolve(); + } + + protected checkForStableReaderTopologies(): void { + const latestHosts: HostInfo[] = this.getStoredHosts(); + if (!latestHosts || latestHosts.length === 0) { + this.stableTopologiesStartNs = BigInt(0); + return; + } + + const readerIds: string[] = latestHosts.map((host) => host.hostId); + for (const id of readerIds) { + const completedCycle = this.completedOneCycle.get(id) ?? false; + if (!completedCycle) { + // Not all reader monitors have completed a cycle. We shouldn't conclude that reader topologies are stable until + // each reader monitor has made at least one attempt to fetch topology information, even if unsuccessful. + this.stableTopologiesStartNs = BigInt(0); + return; + } + } + + const readerTopologyValues = Array.from(this.readerTopologiesById.values()); + const readerTopology: HostInfo[] | undefined = readerTopologyValues.length > 0 ? readerTopologyValues[0] : undefined; + if (!readerTopology) { + // readerTopologiesById has been cleared since checking its size. + this.stableTopologiesStartNs = BigInt(0); + return; + } + + // Check whether the topologies match. HostInfos are compared using their host, port, role, and availability fields. + // Using the first HostInfo in the topology as the reference. + // Note that monitors that encounter errors will remove their entry from the map, so only entries from + // successful monitors are checked. + const reference = JSON.stringify(readerTopology.map(this.hostInfoExtractor).sort()); + const allTopologiesMatch = readerTopologyValues.every((hosts) => JSON.stringify(hosts.map(this.hostInfoExtractor).sort()) === reference); + + if (!allTopologiesMatch) { + // The topologies detected by each reader do not match. + this.stableTopologiesStartNs = BigInt(0); + return; + } + + // All reader topologies match. + if (this.stableTopologiesStartNs === BigInt(0)) { + this.stableTopologiesStartNs = getTimeInNanos(); + } + + if (getTimeInNanos() > this.stableTopologiesStartNs + ClusterTopologyMonitorImpl.STABLE_TOPOLOGIES_DURATION_NS) { + // Reader topologies have been consistent for STABLE_TOPOLOGIES_DURATION_NS, so the topology should be accurate. + this.stableTopologiesStartNs = BigInt(0); + this.updateHostsAvailability(readerTopology); + logger.debug( + logTopology( + readerTopology, + Messages.get( + "ClusterTopologyMonitor.matchingReaderTopologies", + String(convertNanosToMs(ClusterTopologyMonitorImpl.STABLE_TOPOLOGIES_DURATION_NS)) + ) + ) + ); + this.updateTopologyCache(readerTopology); + } + } + + protected async reset(): Promise { + logger.debug(Messages.get("ClusterTopologyMonitor.reset", this.clusterId, this.initialHostInfo.host)); + + this.hostMonitorsStop = true; + await this.closeHostMonitors(); + await this.hostMonitorClientCleanUp(); + this.hostMonitorsStop = false; + this.submittedHosts.clear(); + this.stableTopologiesStartNs = BigInt(0); + this.readerTopologiesById.clear(); + this.completedOneCycle.clear(); + + this.hostMonitorsWriterInfo = null; + this.hostMonitorsLatestTopology = []; + + await this.updateMonitoringClient(null); + this.isVerifiedWriterConnection = false; + this.writerHostInfo = null; + this.highRefreshRateEndTimeNs = BigInt(0); + this.requestToUpdateTopology = false; + this.clearTopologyCache(); + + // This breaks any waiting/sleeping cycles in the monitoring task. + this.requestToUpdateTopology = true; + } + + async processEvent(event: Event): Promise { + if (event instanceof MonitorResetEvent) { + logger.debug(Messages.get("ClusterTopologyMonitor.resetEventReceived")); + const resetEvent = event as MonitorResetEvent; + if (resetEvent.clusterId === this.clusterId) { + await this.reset(); + } + } + } + + protected async hostMonitorClientCleanUp(): Promise { + const writerClientToClose = this.hostMonitorsWriterClient; + const readerClientToClose = this.hostMonitorsReaderClient; + + this.hostMonitorsWriterClient = null; + this.hostMonitorsReaderClient = null; + + if (writerClientToClose && this.monitoringClient !== writerClientToClose) { + try { + await this.closeConnection(writerClientToClose); + } catch (e: any) { + // Ignore + } + } + + if (readerClientToClose && this.monitoringClient !== readerClientToClose && writerClientToClose !== readerClientToClose) { + try { + await this.closeConnection(readerClientToClose); + } catch (e: any) { + // Ignore + } + } + } + + protected async closeHostMonitors(): Promise { + await Promise.all(this.submittedHosts.values()); + this.submittedHosts.clear(); + await this.hostMonitorClientCleanUp(); + } + + private isInPanicMode(): boolean { + return !this.monitoringClient || !this.isVerifiedWriterConnection; } private getStoredHosts(): HostInfo[] | null { - const topology = this.storageService.get(Topology, this.clusterId); - return topology == null ? null : topology.hosts; + return this.storageService.get(Topology, this.clusterId)?.hosts ?? null; } - private async delay(useHighRefreshRate: boolean) { - if (Date.now() < this.highRefreshRateEndTimeMs) { + private async delay(useHighRefreshRate: boolean): Promise { + if (getTimeInNanos() < this.highRefreshRateEndTimeNs) { useHighRefreshRate = true; } - const endTime = Date.now() + (useHighRefreshRate ? this.highRefreshRateMs : this.refreshRateMs); + const delayNs = useHighRefreshRate ? this.highRefreshRateNs : this.refreshRateNs; + const endTime: bigint = getTimeInNanos() + BigInt(delayNs); await sleep(50); - while (Date.now() < endTime && !this.requestToUpdateTopology) { + while (getTimeInNanos() < endTime && !this.requestToUpdateTopology && !this._stop) { await sleep(50); } } - - logTopology(msgPrefix: string) { - const hosts: HostInfo[] = this.getStoredHosts(); - if (hosts && hosts.length !== 0) { - logger.debug(logTopology(hosts, msgPrefix)); - } - } } export class HostMonitor { + private static readonly INITIAL_BACKOFF_MS = 100; + private static readonly MAX_BACKOFF_MS = 10000; + + protected readonly servicesContainer: FullServicesContainer; protected readonly monitor: ClusterTopologyMonitorImpl; protected readonly hostInfo: HostInfo; - protected readonly writerHostInfo: HostInfo; + protected readonly writerHostInfo: HostInfo | null; protected writerChanged: boolean = false; + protected connectionAttempts: number = 0; + protected client: ClientWrapper | null = null; - constructor(monitor: ClusterTopologyMonitorImpl, hostInfo: HostInfo, writerHostInfo: HostInfo) { + constructor(servicesContainer: FullServicesContainer, monitor: ClusterTopologyMonitorImpl, hostInfo: HostInfo, writerHostInfo: HostInfo | null) { + this.servicesContainer = servicesContainer; this.monitor = monitor; this.hostInfo = hostInfo; this.writerHostInfo = writerHostInfo; } async run() { - let client: ClientWrapper | null = null; let updateTopology: boolean = false; const startTime: number = Date.now(); logger.debug(Messages.get("HostMonitor.startMonitoring", this.hostInfo.hostId)); + const pluginService = this.servicesContainer.pluginService; try { while (!this.monitor.hostMonitorsStop) { - if (!client) { + if (!this.client) { try { - client = await this.monitor.pluginService.forceConnect(this.hostInfo, this.monitor.monitoringProperties); - this.monitor.pluginService.setAvailability(this.hostInfo.allAliases, HostAvailability.AVAILABLE); + this.client = await pluginService.forceConnect(this.hostInfo, this.monitor.monitoringProperties); + this.connectionAttempts = 0; } catch (error) { - this.monitor.pluginService.setAvailability(this.hostInfo.allAliases, HostAvailability.NOT_AVAILABLE); + // A problem occurred while connecting. + if (pluginService.isNetworkError(error)) { + // It's a network issue that's expected during a cluster failover. + // We will try again on the next iteration. + await sleep(100); + this.monitor.completedOneCycle.set(this.hostInfo.hostId, true); + this.monitor.readerTopologiesById.delete(this.hostInfo.hostId); + continue; + } else if (pluginService.isLoginError(error)) { + throw new AwsWrapperError(Messages.get("HostMonitor.loginErrorDuringMonitoring"), error); + } else { + // It might be some transient error. Let's try again. + // If the error repeats, we will try again after a longer delay. + const backoff = this.calculateBackoffWithJitter(this.connectionAttempts++); + await sleep(backoff); + this.monitor.completedOneCycle.set(this.hostInfo.hostId, true); + this.monitor.readerTopologiesById.delete(this.hostInfo.hostId); + continue; + } } } - if (client) { - let writerId = null; + if (this.client) { + let isWriter: boolean = false; try { - writerId = await this.monitor.getWriterHostIdIfConnected(client, this.hostInfo.hostId); + isWriter = await this.monitor.topologyUtils.isWriterInstance(this.client); } catch (error) { logger.error(Messages.get("ClusterTopologyMonitor.invalidWriterQuery", error?.message)); - await this.monitor.closeConnection(client); - client = null; + await this.monitor.closeConnection(this.client); + this.client = null; } - if (writerId) { - // First connection after failover may be stale. - if ((await this.monitor.pluginService.getHostRole(client)) !== HostRole.WRITER) { - logger.debug(Messages.get("HostMonitor.writerIsStale", writerId)); - writerId = null; + if (isWriter) { + try { + // First connection after failover may be stale. + const hostRole = await this.monitor.pluginService.getHostRole(this.client); + if (hostRole !== HostRole.WRITER) { + isWriter = false; + } + } catch (error: any) { + // Invalid connection, retry. + this.monitor.completedOneCycle.set(this.hostInfo.hostId, true); + this.monitor.readerTopologiesById.delete(this.hostInfo.hostId); + continue; } } - if (writerId) { + if (isWriter) { + // This prevents us from closing the connection in the finally block. if (this.monitor.hostMonitorsWriterClient) { - await this.monitor.closeConnection(client); + // The writer connection is already set up, probably by another host monitor. + await this.monitor.closeConnection(this.client); } else { - logger.debug(Messages.get("HostMonitor.detectedWriter", writerId, this.hostInfo.host)); - const updatedHosts: HostInfo[] = await this.monitor.fetchTopologyAndUpdateCache(client); - if (updatedHosts && this.monitor.hostMonitorsWriterClient === null) { - this.monitor.hostMonitorsWriterClient = client; - this.monitor.hostMonitorsWriterInfo = this.hostInfo; - this.monitor.hostMonitorsStop = true; - this.monitor.logTopology(`[hostMonitor ${this.hostInfo.hostId}] `); - } else { - await this.monitor.closeConnection(client); - } + // Successfully updated the host monitor writer connection. + logger.debug(Messages.get("HostMonitor.detectedWriter", this.hostInfo.hostId, this.hostInfo.url)); + + this.servicesContainer.importantEventService.registerEvent(() => + Messages.get("HostMonitor.detectedWriter", this.hostInfo.hostId, this.hostInfo.url) + ); + + await this.monitor.fetchTopologyAndUpdateCache(this.client); + this.hostInfo.setAvailability(HostAvailability.AVAILABLE); + this.monitor.hostMonitorsWriterClient = this.client; + this.monitor.hostMonitorsWriterInfo = this.hostInfo; + // Connection is already assigned to this.monitor.hostMonitorsWriterClient + // so we need to reset client without closing it. + this.client = null; + this.monitor.hostMonitorsStop = true; + logger.debug(logTopology(this.monitor.hostMonitorsLatestTopology, `[hostMonitor ${this.hostInfo.hostId}] `)); } - client = null; return; - } else if (client) { + } else if (this.client) { // Client is a reader. if (!this.monitor.hostMonitorsWriterClient) { - // While the writer hasn't been identified, reader client can update topology. + // We can use this reader connection to update the topology while we wait for the writer connection to + // be established. if (updateTopology) { - await this.readerTaskFetchTopology(client, this.writerHostInfo); - } else if (this.monitor.hostMonitorsReaderClient === null) { - this.monitor.hostMonitorsReaderClient = client; + await this.readerTaskFetchTopology(this.client, this.writerHostInfo); + } else if (!this.monitor.hostMonitorsReaderClient) { + this.monitor.hostMonitorsReaderClient = this.client; updateTopology = true; - await this.readerTaskFetchTopology(client, this.writerHostInfo); + await this.readerTaskFetchTopology(this.client, this.writerHostInfo); + } else { + await this.readerTaskFetchTopology(this.client, this.writerHostInfo); } } } } + + this.monitor.completedOneCycle.set(this.hostInfo.hostId, true); await sleep(100); } } catch (error) { // Close the monitor. } finally { - await this.monitor.closeConnection(client); + this.monitor.completedOneCycle.set(this.hostInfo.hostId, true); + this.monitor.readerTopologiesById.delete(this.hostInfo.hostId); + + await this.monitor.closeConnection(this.client); logger.debug(Messages.get("HostMonitor.endMonitoring", this.hostInfo.hostId, (Date.now() - startTime).toString())); } } - private async readerTaskFetchTopology(client: any, writerHostInfo: HostInfo) { + private async readerTaskFetchTopology(client: ClientWrapper, writerHostInfo: HostInfo | null) { if (!client) { return; } - let hosts: HostInfo[]; + let hosts: HostInfo[] | null; try { - hosts = await this.monitor.hostListProvider.sqlQueryForTopology(client); - if (hosts === null) { + hosts = await this.monitor.queryForTopology(client); + if (!hosts) { return; } - this.monitor.hostMonitorsLatestTopology = hosts; } catch (error) { return; } + // Share this topology so that the main monitoring task can adjust the node monitoring tasks. + this.monitor.hostMonitorsLatestTopology = hosts; + this.monitor.readerTopologiesById.set(this.hostInfo.hostId, hosts); + if (this.writerChanged) { + this.monitor.updateHostsAvailability(hosts); this.monitor.updateTopologyCache(hosts); logger.debug(logTopology(hosts, `[hostMonitor ${this.hostInfo.hostId}] `)); return; } - const latestWriterHostInfo: HostInfo = hosts.find((x) => x.role === HostRole.WRITER); + const latestWriterHostInfo = hosts.find((x) => x.role === HostRole.WRITER); if (latestWriterHostInfo && writerHostInfo && latestWriterHostInfo.hostAndPort !== writerHostInfo.hostAndPort) { this.writerChanged = true; logger.debug(Messages.get("HostMonitor.writerHostChanged", writerHostInfo.hostAndPort, latestWriterHostInfo.hostAndPort)); + this.monitor.updateHostsAvailability(hosts); this.monitor.updateTopologyCache(hosts); logger.debug(logTopology(hosts, `[hostMonitor ${this.hostInfo.hostId}] `)); } } + + private calculateBackoffWithJitter(attempt: number): number { + let backoff = HostMonitor.INITIAL_BACKOFF_MS * Math.round(Math.pow(2, Math.min(attempt, 6))); + backoff = Math.min(backoff, HostMonitor.MAX_BACKOFF_MS); + return Math.round(backoff * (0.5 + Math.random() * 0.5)); + } } diff --git a/common/lib/host_list_provider/monitoring/global_aurora_topology_monitor.ts b/common/lib/host_list_provider/monitoring/global_aurora_topology_monitor.ts new file mode 100644 index 000000000..9582c5222 --- /dev/null +++ b/common/lib/host_list_provider/monitoring/global_aurora_topology_monitor.ts @@ -0,0 +1,69 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { ClusterTopologyMonitorImpl } from "./cluster_topology_monitor"; +import { GdbTopologyUtils, GlobalTopologyUtils } from "../global_topology_utils"; +import { FullServicesContainer } from "../../utils/full_services_container"; +import { HostInfo } from "../../host_info"; +import { ClientWrapper } from "../../client_wrapper"; +import { AwsWrapperError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; +import { TopologyUtils } from "../topology_utils"; + +function isGdbTopologyUtils(utils: TopologyUtils): utils is TopologyUtils & GdbTopologyUtils { + return "getRegion" in utils && typeof (utils as unknown as GdbTopologyUtils).getRegion === "function"; +} + +export class GlobalAuroraTopologyMonitor extends ClusterTopologyMonitorImpl { + protected readonly instanceTemplatesByRegion: Map; + declare public readonly topologyUtils: TopologyUtils; + + constructor( + servicesContainer: FullServicesContainer, + topologyUtils: TopologyUtils, + clusterId: string, + initialHostInfo: HostInfo, + properties: Map, + instanceTemplate: HostInfo, + refreshRateNano: number, + highRefreshRateNano: number, + instanceTemplatesByRegion: Map + ) { + super(servicesContainer, topologyUtils, clusterId, initialHostInfo, properties, instanceTemplate, refreshRateNano, highRefreshRateNano); + + this.instanceTemplatesByRegion = instanceTemplatesByRegion; + this.topologyUtils = topologyUtils; + } + + protected override async getInstanceTemplate(hostId: string, targetClient: ClientWrapper): Promise { + if (!isGdbTopologyUtils(this.topologyUtils)) { + throw new AwsWrapperError(Messages.get("GlobalAuroraTopologyMonitor.invalidTopologyUtils")); + } + + const dialect = this.hostListProviderService.getDialect(); + const region = await this.topologyUtils.getRegion(hostId, targetClient, dialect); + + if (region) { + const instanceTemplate = this.instanceTemplatesByRegion.get(region); + if (!instanceTemplate) { + throw new AwsWrapperError(Messages.get("GlobalAuroraTopologyMonitor.cannotFindRegionTemplate", region)); + } + return instanceTemplate; + } + + return this.instanceTemplate; + } +} diff --git a/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts b/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts deleted file mode 100644 index c3f9b87d8..000000000 --- a/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts +++ /dev/null @@ -1,112 +0,0 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -import { RdsHostListProvider } from "../rds_host_list_provider"; -import { HostInfo, AwsWrapperError } from "../../"; -import { PluginService } from "../../plugin_service"; -import { ClusterTopologyMonitor, ClusterTopologyMonitorImpl } from "./cluster_topology_monitor"; -import { HostListProviderService } from "../../host_list_provider_service"; -import { ClientWrapper } from "../../client_wrapper"; -import { DatabaseDialect } from "../../database_dialect/database_dialect"; -import { Messages } from "../../utils/messages"; -import { WrapperProperties } from "../../wrapper_property"; -import { BlockingHostListProvider } from "../host_list_provider"; -import { logger } from "../../../logutils"; -import { SlidingExpirationCacheWithCleanupTask } from "../../utils/sliding_expiration_cache_with_cleanup_task"; -import { isDialectTopologyAware } from "../../utils/utils"; -import { TopologyUtils } from "../topology_utils"; - -export class MonitoringRdsHostListProvider extends RdsHostListProvider implements BlockingHostListProvider { - static readonly CACHE_CLEANUP_NANOS: bigint = BigInt(60_000_000_000); // 1 minute. - static readonly MONITOR_EXPIRATION_NANOS: bigint = BigInt(15 * 60_000_000_000); // 15 minutes. - static readonly DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS = 5000; // 5 seconds. - - private static monitors: SlidingExpirationCacheWithCleanupTask = new SlidingExpirationCacheWithCleanupTask( - MonitoringRdsHostListProvider.CACHE_CLEANUP_NANOS, - () => true, - async (item: ClusterTopologyMonitor) => { - try { - await item.close(); - } catch { - // Ignore. - } - }, - "MonitoringRdsHostListProvider.monitors" - ); - - private readonly pluginService: PluginService; - - constructor( - properties: Map, - originalUrl: string, - topologyUtils: TopologyUtils, - hostListProviderService: HostListProviderService, - pluginService: PluginService - ) { - super(properties, originalUrl, topologyUtils, hostListProviderService); - this.pluginService = pluginService; - } - - async clearAll(): Promise { - RdsHostListProvider.clearAll(); - await MonitoringRdsHostListProvider.monitors.clear(); - } - - async getCurrentTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { - const monitor: ClusterTopologyMonitor = this.initMonitor(); - - try { - return await monitor.forceRefresh(targetClient, MonitoringRdsHostListProvider.DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS); - } catch (error) { - logger.info(Messages.get("MonitoringHostListProvider.errorForceRefresh", error.message)); - return null; - } - } - - async sqlQueryForTopology(targetClient: ClientWrapper): Promise { - return await this.topologyUtils.queryForTopology(targetClient, this.pluginService.getDialect(), this.initialHost, this.clusterInstanceTemplate); - } - - async forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise { - const monitor: ClusterTopologyMonitor = this.initMonitor(); - - return await monitor.forceMonitoringRefresh(shouldVerifyWriter, timeoutMs); - } - - protected initMonitor(): ClusterTopologyMonitor { - const monitor: ClusterTopologyMonitor = MonitoringRdsHostListProvider.monitors.computeIfAbsent( - this.clusterId, - () => - new ClusterTopologyMonitorImpl( - this.topologyUtils, - this.clusterId, - this.initialHost, - this.properties, - this.clusterInstanceTemplate, - this.pluginService, - this, - WrapperProperties.CLUSTER_TOPOLOGY_REFRESH_RATE_MS.get(this.properties), - WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get(this.properties) - ), - MonitoringRdsHostListProvider.MONITOR_EXPIRATION_NANOS - ); - - if (monitor === null) { - throw new AwsWrapperError(Messages.get("MonitoringHostListProvider.requiresMonitor")); - } - return monitor; - } -} diff --git a/common/lib/host_list_provider/rds_host_list_provider.ts b/common/lib/host_list_provider/rds_host_list_provider.ts index aae69e4f7..9558e1aa4 100644 --- a/common/lib/host_list_provider/rds_host_list_provider.ts +++ b/common/lib/host_list_provider/rds_host_list_provider.ts @@ -21,29 +21,36 @@ import { RdsUrlType } from "../utils/rds_url_type"; import { RdsUtils } from "../utils/rds_utils"; import { HostListProviderService } from "../host_list_provider_service"; import { ConnectionUrlParser } from "../utils/connection_url_parser"; -import { AwsWrapperError } from "../utils/errors"; +import { AwsTimeoutError, AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; import { WrapperProperties } from "../wrapper_property"; import { logger } from "../../logutils"; -import { isDialectTopologyAware, logTopology } from "../utils/utils"; +import { isDialectTopologyAware } from "../database_dialect/topology_aware_database_dialect"; import { DatabaseDialect } from "../database_dialect/database_dialect"; import { ClientWrapper } from "../client_wrapper"; import { CoreServicesContainer } from "../utils/core_services_container"; import { StorageService } from "../utils/storage/storage_service"; import { Topology } from "./topology"; import { TopologyUtils } from "./topology_utils"; +import { FullServicesContainer } from "../utils/full_services_container"; +import { PluginService } from "../plugin_service"; +import { ClusterTopologyMonitor, ClusterTopologyMonitorImpl } from "./monitoring/cluster_topology_monitor"; +import { MonitorInitializer } from "../utils/monitoring/monitor"; export class RdsHostListProvider implements DynamicHostListProvider { + private static readonly DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS: number = 5000; private readonly originalUrl: string; - private readonly rdsHelper: RdsUtils; + protected readonly rdsHelper: RdsUtils; + protected readonly servicesContainers: FullServicesContainer; + private readonly pluginService: PluginService; private readonly storageService: StorageService; protected readonly topologyUtils: TopologyUtils; protected readonly properties: Map; private rdsUrlType: RdsUrlType; private initialHostList: HostInfo[]; protected initialHost: HostInfo; - private refreshRateNano: number; - private hostList?: HostInfo[]; + protected refreshRateNano: number; + protected highRefreshRateNano: number; protected readonly connectionUrlParser: ConnectionUrlParser; protected readonly hostListProviderService: HostListProviderService; @@ -51,18 +58,34 @@ export class RdsHostListProvider implements DynamicHostListProvider { public isInitialized: boolean = false; public clusterInstanceTemplate?: HostInfo; - constructor(properties: Map, originalUrl: string, topologyUtils: TopologyUtils, hostListProviderService: HostListProviderService) { + constructor(properties: Map, originalUrl: string, topologyUtils: TopologyUtils, servicesContainers: FullServicesContainer) { this.rdsHelper = new RdsUtils(); this.topologyUtils = topologyUtils; - this.hostListProviderService = hostListProviderService; - this.connectionUrlParser = hostListProviderService.getConnectionUrlParser(); + this.servicesContainers = servicesContainers; + this.pluginService = this.servicesContainers.pluginService; + this.storageService = this.servicesContainers.storageService; + this.hostListProviderService = this.servicesContainers.hostListProviderService; + this.connectionUrlParser = this.hostListProviderService.getConnectionUrlParser(); this.originalUrl = originalUrl; this.properties = properties; - this.storageService = CoreServicesContainer.getInstance().getStorageService(); // TODO: store the service container instead. + this.refreshRateNano = WrapperProperties.CLUSTER_TOPOLOGY_REFRESH_RATE_MS.get(this.properties) * 1000000; + this.highRefreshRateNano = WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get(this.properties) * 1000000; + } + + init(): void { + if (this.isInitialized) { + return; + } + + this.initSettings(); - let port = WrapperProperties.PORT.get(properties); + this.isInitialized = true; + } + + protected initSettings(): void { + let port = WrapperProperties.PORT.get(this.properties); if (port == null) { - port = hostListProviderService.getDialect().getDefaultPort(); + port = this.hostListProviderService.getDialect().getDefaultPort(); } this.initialHostList = this.connectionUrlParser.getHostsFromConnectionUrl(this.originalUrl, false, port, () => @@ -74,15 +97,8 @@ export class RdsHostListProvider implements DynamicHostListProvider { this.initialHost = this.initialHostList[0]; this.hostListProviderService.setInitialConnectionHostInfo(this.initialHost); - this.refreshRateNano = WrapperProperties.CLUSTER_TOPOLOGY_REFRESH_RATE_MS.get(this.properties) * 1000000; - this.rdsUrlType = this.rdsHelper.identifyRdsType(this.initialHost.host); - } - - init(): void { - if (this.isInitialized) { - return; - } + this.clusterId = WrapperProperties.CLUSTER_ID.get(this.properties); const hostInfoBuilder = this.hostListProviderService.getHostInfoBuilder(); this.clusterInstanceTemplate = hostInfoBuilder @@ -91,27 +107,51 @@ export class RdsHostListProvider implements DynamicHostListProvider { .build(); this.validateHostPatternSetting(this.clusterInstanceTemplate.host); + this.rdsUrlType = this.rdsHelper.identifyRdsType(this.initialHost.host); + } - this.clusterId = WrapperProperties.CLUSTER_ID.get(this.properties); + protected async getOrCreateMonitor(): Promise { + const initializer: MonitorInitializer = { + createMonitor: (servicesContainer: FullServicesContainer): ClusterTopologyMonitor => { + return new ClusterTopologyMonitorImpl( + servicesContainer, + this.topologyUtils, + this.clusterId, + this.initialHost, + this.properties, + this.clusterInstanceTemplate, + this.refreshRateNano, + this.highRefreshRateNano + ); + } + }; + + return await this.servicesContainers.monitorService.runIfAbsent( + ClusterTopologyMonitorImpl, + this.clusterId, + this.servicesContainers, + this.properties, + initializer + ); + } - this.isInitialized = true; + async forceRefresh(): Promise { + return this.forceMonitoringRefresh(false, RdsHostListProvider.DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS); } - async forceRefresh(): Promise; - async forceRefresh(targetClient: ClientWrapper): Promise; - async forceRefresh(targetClient?: ClientWrapper): Promise { + async forceMonitoringRefresh(verifyTopology: boolean, timeoutMs: number): Promise { this.init(); - const currentClient = targetClient ?? this.hostListProviderService.getCurrentClient().targetClient; - if (currentClient) { - const results: FetchTopologyResult = await this.getTopology(currentClient, true); - this.hostList = results.hosts; - return Array.from(this.hostList); + if (!this.pluginService.isDialectConfirmed()) { + // We need to confirm the dialect before creating a topology monitor so that it uses the correct SQL queries. + // We will return the original hosts parsed from the connections string until the dialect has been confirmed. + return this.initialHostList; } - throw new AwsWrapperError("Could not retrieve targetClient."); + + return await this.forceRefreshMonitor(verifyTopology, timeoutMs); } - async getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise { + async getHostRole(client: ClientWrapper, _dialect: DatabaseDialect): Promise { return this.topologyUtils.getHostRole(client); } @@ -134,7 +174,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { return null; } - let topology = await this.refresh(targetClient); + let topology = await this.refresh(); let isForcedRefresh = false; if (!topology) { @@ -161,47 +201,37 @@ export class RdsHostListProvider implements DynamicHostListProvider { return matches.length === 0 ? null : matches[0]; } - async refresh(): Promise; - async refresh(targetClient: ClientWrapper): Promise; - async refresh(targetClient?: ClientWrapper): Promise { + async refresh(): Promise { this.init(); - const currentClient = targetClient ?? this.hostListProviderService.getCurrentClient().targetClient; - const results: FetchTopologyResult = await this.getTopology(currentClient, false); - logger.debug(logTopology(results.hosts, results.isCachedData ? "[From cache] " : "")); - this.hostList = results.hosts; - return this.hostList; + const results: FetchTopologyResult = await this.getTopology(); + return results.hosts; } - async getTopology(targetClient: ClientWrapper | undefined, forceUpdate: boolean): Promise { + async getTopology(): Promise { this.init(); - if (!this.clusterId) { - throw new AwsWrapperError(Messages.get("RdsHostListProvider.noClusterId")); - } - - const cachedHosts: HostInfo[] | null = this.getStoredTopology(); + const storedTopology: HostInfo[] | null = this.getStoredTopology(); - // This clusterId is a primary one and is about to create a new entry in the cache. - // When a primary entry is created it needs to be suggested for other (non-primary) entries. - // Remember a flag to do suggestion after cache is updated. - if (!cachedHosts || forceUpdate) { + if (!storedTopology) { // need to re-fetch the topology. - if (!targetClient || !(await this.hostListProviderService.isClientValid(targetClient))) { + + if (!this.pluginService.isDialectConfirmed()) { + // We need to confirm the dialect before creating a topology monitor so that it uses the correct SQL queries. + // We will return the original hosts parsed from the connections string until the dialect has been confirmed. return new FetchTopologyResult(false, this.initialHostList); } - const hosts = await this.getCurrentTopology(targetClient, this.hostListProviderService.getDialect()); + const hosts = await this.forceRefreshMonitor(false, RdsHostListProvider.DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS); if (hosts && hosts.length > 0) { - this.storageService.set(this.clusterId, new Topology(hosts)); return new FetchTopologyResult(false, hosts); } } - if (!cachedHosts) { + if (!storedTopology) { return new FetchTopologyResult(false, this.initialHostList); } else { - return new FetchTopologyResult(true, cachedHosts); + return new FetchTopologyResult(true, storedTopology); } } @@ -209,12 +239,16 @@ export class RdsHostListProvider implements DynamicHostListProvider { return await this.topologyUtils.queryForTopology(targetClient, dialect, this.initialHost, this.clusterInstanceTemplate); } - private getHostEndpoint(hostName: string): string | null { - if (!this.clusterInstanceTemplate || !this.clusterInstanceTemplate.host) { - return null; + protected async forceRefreshMonitor(verifyTopology: boolean, timeoutMs: number): Promise { + const monitor = await this.getOrCreateMonitor(); + try { + return await monitor.forceMonitoringRefresh(verifyTopology, timeoutMs); + } catch (error) { + if (error instanceof AwsTimeoutError) { + return null; + } + throw error; } - const host = this.clusterInstanceTemplate.host; - return host.replace("?", hostName); } getStoredTopology(): HostInfo[] | null { @@ -227,18 +261,13 @@ export class RdsHostListProvider implements DynamicHostListProvider { return topology == null ? null : topology.hosts; } - static clearAll(): void { - // No-op - // TODO: remove if still not used after full service container refactoring - } - clear(): void { if (this.clusterId) { - CoreServicesContainer.getInstance().getStorageService().remove(Topology, this.clusterId); + this.servicesContainers.storageService.remove(Topology, this.clusterId); } } - private validateHostPatternSetting(hostPattern: string) { + protected validateHostPatternSetting(hostPattern: string) { if (!this.rdsHelper.isDnsPatternValid(hostPattern)) { const message: string = Messages.get("RdsHostListProvider.invalidPattern.suggestedClusterId"); logger.error(message); @@ -246,7 +275,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { } const rdsUrlType: RdsUrlType = this.rdsHelper.identifyRdsType(hostPattern); - if (rdsUrlType == RdsUrlType.RDS_PROXY) { + if (rdsUrlType == RdsUrlType.RDS_PROXY || rdsUrlType == RdsUrlType.RDS_PROXY_ENDPOINT) { const message: string = Messages.get("RdsHostListProvider.clusterInstanceHostPatternNotSupportedForRDSProxy"); logger.error(message); throw new AwsWrapperError(message); diff --git a/common/lib/host_list_provider/topology.ts b/common/lib/host_list_provider/topology.ts index 35a534606..b7d4e344b 100644 --- a/common/lib/host_list_provider/topology.ts +++ b/common/lib/host_list_provider/topology.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { HostInfo } from "../host_info"; diff --git a/common/lib/host_list_provider/topology_utils.ts b/common/lib/host_list_provider/topology_utils.ts index 2542c63f5..500f3046d 100644 --- a/common/lib/host_list_provider/topology_utils.ts +++ b/common/lib/host_list_provider/topology_utils.ts @@ -1,23 +1,22 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { ClientWrapper } from "../client_wrapper"; import { DatabaseDialect } from "../database_dialect/database_dialect"; import { HostInfo } from "../host_info"; -import { isDialectTopologyAware } from "../utils/utils"; import { Messages } from "../utils/messages"; import { HostRole } from "../host_role"; import { HostAvailability } from "../host_availability/host_availability"; @@ -25,6 +24,11 @@ import { HostInfoBuilder } from "../host_info_builder"; import { AwsWrapperError } from "../utils/errors"; import { TopologyAwareDatabaseDialect } from "../database_dialect/topology_aware_database_dialect"; +/** + * Type representing an instance template - either a single HostInfo or a Map of region to HostInfo. + */ +export type InstanceTemplate = HostInfo | Map; + /** * Options for creating a TopologyQueryResult instance. */ @@ -66,11 +70,11 @@ export class TopologyQueryResult { } /** - * A class defining utility methods that can be used to retrieve and process a variety of database topology + * An abstract class defining utility methods that can be used to retrieve and process a variety of database topology * information. This class can be overridden to define logic specific to various database engine deployments * (e.g. Aurora, Multi-AZ, Global Aurora etc.). */ -export class TopologyUtils { +export abstract class TopologyUtils { protected readonly dialect: TopologyAwareDatabaseDialect; protected readonly hostInfoBuilder: HostInfoBuilder; @@ -84,25 +88,17 @@ export class TopologyUtils { * * @param targetClient the client wrapper to use to query the database. * @param dialect the database dialect to use for the topology query. - * @param clusterInstanceTemplate the template {@link HostInfo} to use when constructing new {@link HostInfo} objects from - * the data returned by the topology query. + * @param initialHost the initial host info. + * @param instanceTemplate the template for constructing host info objects. * @returns a list of {@link HostInfo} objects representing the results of the topology query. * @throws TypeError if the dialect is not topology-aware. */ - async queryForTopology( + abstract queryForTopology( targetClient: ClientWrapper, dialect: DatabaseDialect, initialHost: HostInfo, - clusterInstanceTemplate: HostInfo - ): Promise { - if (!isDialectTopologyAware(dialect)) { - throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); - } - - return await dialect - .queryForTopology(targetClient) - .then((res: TopologyQueryResult[]) => this.verifyWriter(this.createHosts(res, initialHost, clusterInstanceTemplate))); - } + instanceTemplate: InstanceTemplate + ): Promise; public createHost( instanceId: string | undefined, @@ -123,7 +119,6 @@ export class TopologyUtils { } const finalEndpoint = endpoint ?? this.getHostEndpoint(hostname, instanceTemplate) ?? ""; - const finalPort = port ?? (instanceTemplate?.isPortSpecified() ? instanceTemplate?.port : initialHost?.port); const host: HostInfo = this.hostInfoBuilder @@ -139,46 +134,8 @@ export class TopologyUtils { return host; } - /** - * Creates {@link HostInfo} objects from the given topology query results. - * - * @param topologyQueryResults the result set returned by the topology query describing the cluster topology - * @param initialHost the {@link HostInfo} describing the initial connection. - * @param clusterInstanceTemplate the template used to construct the new {@link HostInfo} objects. - * @returns a list of {@link HostInfo} objects representing the topology. - */ - public createHosts(topologyQueryResults: TopologyQueryResult[], initialHost: HostInfo, clusterInstanceTemplate: HostInfo): HostInfo[] { - const hostsMap = new Map(); - topologyQueryResults.forEach((row) => { - const lastUpdateTime = row.lastUpdateTime ?? Date.now(); - - const host = this.createHost( - row.id, - row.host, - row.isWriter, - row.weight, - lastUpdateTime, - initialHost, - clusterInstanceTemplate, - row.endpoint, - row.port - ); - - const existing = hostsMap.get(host.host); - if (!existing || existing.lastUpdateTime < host.lastUpdateTime) { - hostsMap.set(host.host, host); - } - }); - - return Array.from(hostsMap.values()); - } - /** * Gets the host endpoint by replacing the placeholder in the cluster instance template. - * - * @param hostName the host name to use in the endpoint. - * @param clusterInstanceTemplate the template containing the endpoint pattern. - * @returns the constructed endpoint, or null if the template is invalid. */ protected getHostEndpoint(hostName: string, clusterInstanceTemplate: HostInfo): string | null { if (!clusterInstanceTemplate || !clusterInstanceTemplate.host) { @@ -191,9 +148,6 @@ export class TopologyUtils { /** * Verifies that the topology contains exactly one writer instance. * If multiple writers are found, selects the most recently updated one. - * - * @param allHosts the list of all hosts from the topology query. - * @returns the verified list of hosts with exactly one writer, or null if no writer is found. */ protected async verifyWriter(allHosts: HostInfo[]): Promise { if (allHosts === null || allHosts.length === 0) { diff --git a/common/lib/host_list_provider_service.ts b/common/lib/host_list_provider_service.ts index 094750d92..155377723 100644 --- a/common/lib/host_list_provider_service.ts +++ b/common/lib/host_list_provider_service.ts @@ -14,7 +14,7 @@ limitations under the License. */ -import { BlockingHostListProvider, HostListProvider } from "./host_list_provider/host_list_provider"; +import { HostListProvider } from "./host_list_provider/host_list_provider"; import { HostInfo } from "./host_info"; import { AwsClient } from "./aws_client"; import { DatabaseDialect } from "./database_dialect/database_dialect"; @@ -28,7 +28,7 @@ export interface HostListProviderService { setHostListProvider(hostListProvider: HostListProvider): void; - isStaticHostListProvider(): boolean; + isDynamicHostListProvider(): boolean; setInitialConnectionHostInfo(initialConnectionHostInfo: HostInfo): void; @@ -53,6 +53,4 @@ export interface HostListProviderService { getTelemetryFactory(): TelemetryFactory; setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void; - - isBlockingHostListProvider(arg: any): arg is BlockingHostListProvider; } diff --git a/common/lib/host_selector.ts b/common/lib/host_selector.ts index ac995cf37..e32fbb934 100644 --- a/common/lib/host_selector.ts +++ b/common/lib/host_selector.ts @@ -18,5 +18,5 @@ import { HostInfo } from "./host_info"; import { HostRole } from "./host_role"; export interface HostSelector { - getHost(hosts: HostInfo[], role: HostRole, props?: Map): HostInfo; + getHost(hosts: HostInfo[], role: HostRole | null, props?: Map): HostInfo; } diff --git a/common/lib/internal_pooled_connection_provider.ts b/common/lib/internal_pooled_connection_provider.ts index aa9846aa2..a104bfe28 100644 --- a/common/lib/internal_pooled_connection_provider.ts +++ b/common/lib/internal_pooled_connection_provider.ts @@ -27,7 +27,7 @@ import { lookup, LookupAddress } from "dns"; import { promisify } from "util"; import { HostInfoBuilder } from "./host_info_builder"; import { RdsUrlType } from "./utils/rds_url_type"; -import { AwsWrapperError } from "./index"; +import { AwsWrapperError } from "./utils/errors"; import { Messages } from "./utils/messages"; import { HostSelector } from "./host_selector"; import { RandomHostSelector } from "./random_host_selector"; diff --git a/common/lib/least_connections_host_selector.ts b/common/lib/least_connections_host_selector.ts index 58a6ed867..f02f355e9 100644 --- a/common/lib/least_connections_host_selector.ts +++ b/common/lib/least_connections_host_selector.ts @@ -32,7 +32,7 @@ export class LeastConnectionsHostSelector implements HostSelector { getHost(hosts: HostInfo[], role: HostRole, props?: Map): HostInfo { const eligibleHosts: HostInfo[] = hosts - .filter((host: HostInfo) => host.role === role && host.availability === HostAvailability.AVAILABLE) + .filter((host: HostInfo) => (role === null || host.role === role) && host.availability === HostAvailability.AVAILABLE) .sort((hostA: HostInfo, hostB: HostInfo) => { const hostACount = this.getNumConnections(hostA, LeastConnectionsHostSelector.databasePools); const hostBCount = this.getNumConnections(hostB, LeastConnectionsHostSelector.databasePools); diff --git a/common/lib/mysql_client_wrapper.ts b/common/lib/mysql_client_wrapper.ts index eb438d533..52367053e 100644 --- a/common/lib/mysql_client_wrapper.ts +++ b/common/lib/mysql_client_wrapper.ts @@ -66,7 +66,7 @@ export class MySQLClientWrapper implements ClientWrapper { async abort(): Promise { try { - return await ClientUtils.queryWithTimeout(this.client?.destroy(), this.properties); + this.client?.destroy(); } catch (error: any) { // ignore } diff --git a/common/lib/partial_plugin_service.ts b/common/lib/partial_plugin_service.ts new file mode 100644 index 000000000..2097d093e --- /dev/null +++ b/common/lib/partial_plugin_service.ts @@ -0,0 +1,550 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { PluginService } from "./plugin_service"; +import { HostInfo } from "./host_info"; +import { AwsClient } from "./aws_client"; +import { DynamicHostListProvider, HostListProvider } from "./host_list_provider/host_list_provider"; +import { ConnectionUrlParser } from "./utils/connection_url_parser"; +import { DatabaseDialect } from "./database_dialect/database_dialect"; +import { HostInfoBuilder } from "./host_info_builder"; +import { AwsTimeoutError, AwsWrapperError, UnsupportedMethodError } from "./"; +import { HostAvailability } from "./host_availability/host_availability"; +import { HostAvailabilityCacheItem } from "./host_availability/host_availability_cache_item"; +import { HostChangeOptions } from "./host_change_options"; +import { HostRole } from "./host_role"; +import { SessionStateService } from "./session_state_service"; +import { HostAvailabilityStrategyFactory } from "./host_availability/host_availability_strategy_factory"; +import { ClientWrapper } from "./client_wrapper"; +import { logger } from "../logutils"; +import { Messages } from "./utils/messages"; +import { getWriter, logTopology } from "./utils/utils"; +import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; +import { DriverDialect } from "./driver_dialect/driver_dialect"; +import { AllowedAndBlockedHosts } from "./allowed_and_blocked_hosts"; +import { ConnectionPlugin } from "./connection_plugin"; +import { FullServicesContainer } from "./utils/full_services_container"; +import { HostListProviderService } from "./host_list_provider_service"; +import { StorageService } from "./utils/storage/storage_service"; +import { CoreServicesContainer } from "./utils/core_services_container"; + +/** + * A PluginService containing some methods that are not intended to be called. This class is intended to be used + * by monitors, which require a PluginService, but are not expected to need or use some of the methods defined + * by the PluginService interface. The methods that are not expected to be called will throw an + * UnsupportedOperationException when called. + */ +export class PartialPluginService implements PluginService, HostListProviderService { + private static readonly DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS = 5000; // 5 seconds + + protected readonly servicesContainer: FullServicesContainer; + protected readonly storageService: StorageService; + protected readonly props: Map; + protected hostListProvider: HostListProvider | null = null; + protected hosts: HostInfo[] = []; + protected currentHostInfo: HostInfo | null = null; + protected initialConnectionHostInfo: HostInfo | null = null; + protected isInTransactionFlag: boolean = false; + protected readonly dialect: DatabaseDialect; + protected readonly driverDialect: DriverDialect; + protected allowedAndBlockedHosts: AllowedAndBlockedHosts | null = null; + private _isPooledClient: boolean = false; + private connectionUrlParser: ConnectionUrlParser; + + constructor( + servicesContainer: FullServicesContainer, + props: Map, + dialect: DatabaseDialect, + driverDialect: DriverDialect, + connectionUrlParser: ConnectionUrlParser + ) { + this.servicesContainer = servicesContainer; + this.storageService = servicesContainer.storageService; + this.servicesContainer.hostListProviderService = this; + this.servicesContainer.pluginService = this; + + this.props = props; + this.dialect = dialect; + this.driverDialect = driverDialect; + this.connectionUrlParser = connectionUrlParser; + + this.hostListProvider = this.dialect.getHostListProvider(this.props, this.props.get("host"), this.servicesContainer); + } + + getCurrentClient(): AwsClient { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "getCurrentClient")); + } + + getCurrentHostInfo(): HostInfo | null { + if (!this.currentHostInfo) { + this.currentHostInfo = this.initialConnectionHostInfo; + + if (!this.currentHostInfo) { + if (this.getAllHosts().length === 0) { + throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty")); + } + + const writerHost = getWriter(this.getAllHosts()); + if (writerHost) { + this.currentHostInfo = writerHost; + const allowedHosts = this.getHosts(); + if (!allowedHosts.some((hostInfo: HostInfo) => hostInfo.host === writerHost.host && hostInfo.port === writerHost.port)) { + throw new AwsWrapperError( + Messages.get( + "PluginService.currentHostNotAllowed", + this.currentHostInfo ? this.currentHostInfo.host : "", + logTopology(allowedHosts, "") + ) + ); + } + } + + if (!this.currentHostInfo) { + const hosts = this.getHosts(); + if (hosts.length > 0) { + this.currentHostInfo = hosts[0]; + } + } + } + + if (!this.currentHostInfo) { + throw new AwsWrapperError(Messages.get("PluginService.currentHostNotDefined")); + } + + logger.debug(`Set current host to: ${this.currentHostInfo.host}`); + } + + return this.currentHostInfo; + } + + setCurrentHostInfo(value: HostInfo): void { + this.currentHostInfo = value; + } + + setInitialConnectionHostInfo(initialConnectionHostInfo: HostInfo): void { + this.initialConnectionHostInfo = initialConnectionHostInfo; + } + + getInitialConnectionHostInfo(): HostInfo | null { + return this.initialConnectionHostInfo; + } + + acceptsStrategy(role: HostRole, strategy: string): boolean { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "acceptsStrategy")); + } + + getHostInfoByStrategy(role: HostRole, strategy: string, hosts?: HostInfo[]): HostInfo | undefined { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "getHostInfoByStrategy")); + } + + getHostRole(client: any): Promise | undefined { + return this.dialect.getHostRole(client); + } + + getDriverDialect(): DriverDialect { + return this.driverDialect; + } + + getConnectionUrlParser(): ConnectionUrlParser { + return this.connectionUrlParser; + } + + setCurrentClient(newClient: ClientWrapper, hostInfo: HostInfo): Promise> { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "setCurrentClient")); + } + + protected compare(hostInfoA: HostInfo, hostInfoB: HostInfo): Set { + const changes: Set = new Set(); + + if (hostInfoA.host !== hostInfoB.host || hostInfoA.port !== hostInfoB.port) { + changes.add(HostChangeOptions.HOSTNAME); + } + + if (hostInfoA.role !== hostInfoB.role) { + if (hostInfoB.role === HostRole.WRITER) { + changes.add(HostChangeOptions.PROMOTED_TO_WRITER); + } else if (hostInfoB.role === HostRole.READER) { + changes.add(HostChangeOptions.PROMOTED_TO_READER); + } + } + + if (hostInfoA.availability !== hostInfoB.availability) { + if (hostInfoB.availability === HostAvailability.AVAILABLE) { + changes.add(HostChangeOptions.WENT_UP); + } else if (hostInfoB.availability === HostAvailability.NOT_AVAILABLE) { + changes.add(HostChangeOptions.WENT_DOWN); + } + } + + if (changes.size > 0) { + changes.add(HostChangeOptions.HOST_CHANGED); + } + + return changes; + } + + getAllHosts(): HostInfo[] { + return this.hosts; + } + + getHosts(): HostInfo[] { + const hostPermissions = this.allowedAndBlockedHosts; + if (!hostPermissions) { + return this.hosts; + } + + let hosts = this.hosts; + const allowedHostIds = hostPermissions.getAllowedHostIds(); + const blockedHostIds = hostPermissions.getBlockedHostIds(); + + if (allowedHostIds && allowedHostIds.size > 0) { + hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId)); + } + + if (blockedHostIds && blockedHostIds.size > 0) { + hosts = hosts.filter((host: HostInfo) => !blockedHostIds.has(host.hostId)); + } + + return hosts; + } + + setAvailability(hostAliases: Set, availability: HostAvailability): void { + if (hostAliases.size === 0) { + return; + } + + const hostsToChange = [ + ...new Set( + this.getAllHosts().filter( + (host: HostInfo) => hostAliases.has(host.asAlias) || [...host.aliases].some((hostAlias: string) => hostAliases.has(hostAlias)) + ) + ) + ]; + + if (hostsToChange.length === 0) { + return; + } + + const changes = new Map>(); + for (const host of hostsToChange) { + const currentAvailability = host.getAvailability(); + host.availability = availability; + this.storageService.set(host.url, new HostAvailabilityCacheItem(availability)); + if (currentAvailability !== availability) { + let hostChanges: Set; + if (availability === HostAvailability.AVAILABLE) { + hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]); + } else { + hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]); + } + changes.set(host.url, hostChanges); + } + } + + if (changes.size > 0) { + this.servicesContainer.pluginManager?.notifyHostListChanged(changes); + } + } + + isInTransaction(): boolean { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "isInTransaction")); + } + + isDialectConfirmed(): boolean { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "isDialectConfirmed")); + } + + setInTransaction(inTransaction: boolean): void { + this.isInTransactionFlag = inTransaction; + } + + getHostListProvider(): HostListProvider | null { + return this.hostListProvider; + } + + async refreshHostList(): Promise { + const updatedHostList = await this.getHostListProvider()?.refresh(); + if (updatedHostList && updatedHostList !== this.hosts) { + this.updateHostAvailability(updatedHostList); + this.setHostList(this.hosts, updatedHostList); + } + } + + async forceRefreshHostList(): Promise { + await this.forceMonitoringRefresh(false, PartialPluginService.DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS); + } + + async forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise { + const hostListProvider = this.getHostListProvider(); + + if (!this.isDynamicHostListProvider()) { + const providerName = hostListProvider?.constructor.name ?? "null"; + throw new UnsupportedMethodError(Messages.get("PluginService.requiredDynamicHostListProvider", providerName)); + } + + try { + const updatedHostList = await (hostListProvider as DynamicHostListProvider).forceMonitoringRefresh(shouldVerifyWriter, timeoutMs); + if (updatedHostList) { + this.updateHostAvailability(updatedHostList); + this.setHostList(this.hosts, updatedHostList); + return true; + } + } catch (err) { + if (err instanceof AwsTimeoutError) { + logger.debug(Messages.get("PluginService.forceMonitoringRefreshTimeout", timeoutMs.toString())); + } + } + + return false; + } + + protected setHostList(oldHosts: HostInfo[] | null, newHosts: HostInfo[] | null): void { + const oldHostMap: Map = oldHosts ? new Map(oldHosts.map((e) => [e.url, e])) : new Map(); + + const newHostMap: Map = newHosts ? new Map(newHosts.map((e) => [e.url, e])) : new Map(); + + const changes: Map> = new Map(); + + oldHostMap.forEach((value, key) => { + const correspondingNewHost = newHostMap.get(key); + if (!correspondingNewHost) { + changes.set(key, new Set([HostChangeOptions.HOST_DELETED])); + } else { + const hostChanges = this.compare(value, correspondingNewHost); + if (hostChanges.size > 0) { + changes.set(key, hostChanges); + } + } + }); + + newHostMap.forEach((value, key) => { + if (!oldHostMap.has(key)) { + changes.set(key, new Set([HostChangeOptions.HOST_ADDED])); + } + }); + + if (changes.size > 0) { + this.hosts = newHosts ? newHosts : []; + this.servicesContainer.pluginManager?.notifyHostListChanged(changes); + } + } + + isDynamicHostListProvider(): boolean { + const provider = this.getHostListProvider(); + return provider !== null && typeof (provider as any).forceMonitoringRefresh === "function"; + } + + setHostListProvider(hostListProvider: HostListProvider): void { + this.hostListProvider = hostListProvider; + } + + connect(hostInfo: HostInfo, props: Map): Promise; + connect(hostInfo: HostInfo, props: Map, pluginToSkip: ConnectionPlugin | null): Promise; + connect(hostInfo: HostInfo, props: Map, pluginToSkip?: ConnectionPlugin | null): Promise { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "connect")); + } + + forceConnect(hostInfo: HostInfo, props: Map): Promise; + forceConnect(hostInfo: HostInfo, props: Map, pluginToSkip: ConnectionPlugin | null): Promise; + forceConnect(hostInfo: HostInfo, props: Map, pluginToSkip?: ConnectionPlugin | null): Promise { + const pluginManager = this.servicesContainer.pluginManager; + if (!pluginManager) { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "forceConnect")); + } + return pluginManager.forceConnect(hostInfo, props, true, pluginToSkip ?? null); + } + + protected updateHostAvailability(hosts: HostInfo[]): void { + hosts.forEach((host) => { + const cacheItem = this.storageService.get(HostAvailabilityCacheItem, host.url); + if (cacheItem != null) { + host.availability = cacheItem.availability; + } + }); + } + + // Error handler methods + isLoginError(e: Error): boolean { + return this.dialect.getErrorHandler().isLoginError(e); + } + + isNetworkError(e: Error): boolean { + return this.dialect.getErrorHandler().isNetworkError(e); + } + + isSyntaxError(e: Error): boolean { + return this.dialect.getErrorHandler().isSyntaxError(e); + } + + hasLoginError(): boolean { + return this.dialect.getErrorHandler().hasLoginError(); + } + + hasNetworkError(): boolean { + return this.dialect.getErrorHandler().hasNetworkError(); + } + + getUnexpectedError(): Error | null { + return this.dialect.getErrorHandler().getUnexpectedError(); + } + + attachErrorListener(clientWrapper: ClientWrapper | undefined): void { + this.dialect.getErrorHandler().attachErrorListener(clientWrapper); + } + + attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void { + this.dialect.getErrorHandler().attachNoOpErrorListener(clientWrapper); + } + + removeErrorListener(clientWrapper: ClientWrapper | undefined): void { + this.dialect.getErrorHandler().removeErrorListener(clientWrapper); + } + + getDialect(): DatabaseDialect { + return this.dialect; + } + + async updateDialect(targetClient: ClientWrapper): Promise { + // Do nothing. This method is called after connecting in DefaultConnectionPlugin but the dialect passed to the + // constructor should already be updated and verified. + } + + async identifyConnection(targetClient: ClientWrapper): Promise { + const provider = this.getHostListProvider(); + if (!provider) { + return Promise.reject(new AwsWrapperError(Messages.get("PluginService.errorIdentifyConnection"))); + } + return provider.identifyConnection(targetClient); + } + + async fillAliases(targetClient: ClientWrapper, hostInfo: HostInfo): Promise { + if (!hostInfo) { + return; + } + + if (hostInfo.aliases.size > 0) { + logger.debug(Messages.get("PluginService.nonEmptyAliases", [...hostInfo.aliases].join(", "))); + return; + } + + hostInfo.addAlias(hostInfo.asAlias); + + try { + const res = await this.dialect.getHostAliasAndParseResults(targetClient); + if (res) { + hostInfo.addAlias(res); + } + } catch (error) { + logger.debug(Messages.get("PluginService.failedToRetrieveHostPort")); + } + + try { + const host = await this.identifyConnection(targetClient); + if (host && host.allAliases) { + hostInfo.addAlias(...host.allAliases); + } + } catch (error) { + // Ignore errors from identifyConnection + logger.debug(Messages.get("PluginService.failedToRetrieveHostPort")); + } + } + + getHostInfoBuilder(): HostInfoBuilder { + return new HostInfoBuilder({ hostAvailabilityStrategy: new HostAvailabilityStrategyFactory().create(this.props) }); + } + + getProperties(): Map { + return this.props; + } + + getTelemetryFactory(): TelemetryFactory { + const pluginManager = this.servicesContainer.pluginManager; + if (!pluginManager) { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "getTelemetryFactory")); + } + return pluginManager.getTelemetryFactory(); + } + + getSessionStateService(): SessionStateService { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "getSessionStateService")); + } + + async updateState(sql: string): Promise { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "updateState")); + } + + updateInTransaction(sql: string): void { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "updateInTransaction")); + } + + async isClientValid(targetClient: ClientWrapper): Promise { + return await this.getDialect().isClientValid(targetClient); + } + + async abortCurrentClient(): Promise { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "abortCurrentClient")); + } + + async abortTargetClient(targetClient: ClientWrapper | undefined | null): Promise { + if (targetClient) { + await targetClient.abort(); + } + } + + updateConfigWithProperties(props: Map): void { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "updateConfigWithProperties")); + } + + setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void { + this.allowedAndBlockedHosts = allowedAndBlockedHosts; + } + + setStatus(clazz: any, status: T | null, clusterBound: boolean): void; + setStatus(clazz: any, status: T | null, key: string): void; + setStatus(clazz: any, status: T | null, clusterBoundOrKey: boolean | string): void { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "setStatus")); + } + + getStatus(clazz: any, clusterBound: boolean): T; + getStatus(clazz: any, key: string): T; + getStatus(clazz: any, clusterBoundOrKey: boolean | string): T { + throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "getStatus")); + } + + isPluginInUse(plugin: any): boolean { + try { + return this.servicesContainer.pluginManager?.isPluginInUse(plugin) ?? false; + } catch (e) { + return false; + } + } + + getPlugin(pluginClazz: new (...args: any[]) => T): T | null { + return this.servicesContainer.pluginManager?.unwrapPlugin(pluginClazz) ?? null; + } + + static clearCache(): void { + CoreServicesContainer.getInstance().storageService.clear(HostAvailabilityCacheItem); + } + + isPooledClient(): boolean { + return this._isPooledClient; + } + + setIsPooledClient(isPooledClient: boolean): void { + this._isPooledClient = isPooledClient; + } +} diff --git a/common/lib/plugin_factory.ts b/common/lib/plugin_factory.ts index 7360ba1dc..1e69580b6 100644 --- a/common/lib/plugin_factory.ts +++ b/common/lib/plugin_factory.ts @@ -16,9 +16,10 @@ import { PluginService } from "./plugin_service"; import { ConnectionPlugin } from "./connection_plugin"; +import { FullServicesContainer } from "./utils/full_services_container"; export class ConnectionPluginFactory { - getInstance(pluginService: PluginService, properties: Map): Promise { + getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { return; } } diff --git a/common/lib/plugin_manager.ts b/common/lib/plugin_manager.ts index 410b4cf99..1edffdf54 100644 --- a/common/lib/plugin_manager.ts +++ b/common/lib/plugin_manager.ts @@ -19,7 +19,6 @@ import { HostInfo } from "./host_info"; import { ConnectionPluginChainBuilder } from "./connection_plugin_chain_builder"; import { AwsWrapperError } from "./utils/errors"; import { Messages } from "./utils/messages"; -import { PluginServiceManagerContainer } from "./plugin_service_manager_container"; import { HostListProviderService } from "./host_list_provider_service"; import { HostChangeOptions } from "./host_change_options"; import { OldConnectionSuggestionAction } from "./old_connection_suggestion_action"; @@ -32,6 +31,7 @@ import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level"; import { ConnectionProvider } from "./connection_provider"; import { ConnectionPluginFactory } from "./plugin_factory"; import { ConfigurationProfile } from "./profile/configuration_profile"; +import { FullServicesContainer } from "./utils/full_services_container"; import { CoreServicesContainer } from "./utils/core_services_container"; type PluginFunc = (plugin: ConnectionPlugin, targetFunc: () => Promise) => Promise; @@ -80,17 +80,16 @@ export class PluginManager { private readonly props: Map; private _plugins: ConnectionPlugin[] = []; private readonly connectionProviderManager: ConnectionProviderManager; - private pluginServiceManagerContainer: PluginServiceManagerContainer; + private fullServicesContainer: FullServicesContainer; protected telemetryFactory: TelemetryFactory; constructor( - pluginServiceManagerContainer: PluginServiceManagerContainer, + fullServicesContainer: FullServicesContainer, props: Map, connectionProviderManager: ConnectionProviderManager, telemetryFactory: TelemetryFactory ) { - this.pluginServiceManagerContainer = pluginServiceManagerContainer; - this.pluginServiceManagerContainer.pluginManager = this; + this.fullServicesContainer = fullServicesContainer; this.connectionProviderManager = connectionProviderManager; this.props = props; this.telemetryFactory = telemetryFactory; @@ -99,17 +98,15 @@ export class PluginManager { async init(configurationProfile?: ConfigurationProfile | null): Promise; async init(configurationProfile: ConfigurationProfile | null, plugins: ConnectionPlugin[]): Promise; async init(configurationProfile: ConfigurationProfile | null, plugins?: ConnectionPlugin[]) { - if (this.pluginServiceManagerContainer.pluginService != null) { - if (plugins) { - this._plugins = plugins; - } else { - this._plugins = await ConnectionPluginChainBuilder.getPlugins( - this.pluginServiceManagerContainer.pluginService, - this.props, - this.connectionProviderManager, - configurationProfile - ); - } + if (plugins) { + this._plugins = plugins; + } else { + this._plugins = await ConnectionPluginChainBuilder.getPlugins( + this.fullServicesContainer, + this.props, + this.connectionProviderManager, + configurationProfile + ); } for (const plugin of this._plugins) { PluginManager.PLUGINS.add(plugin); @@ -129,8 +126,8 @@ export class PluginManager { } const telemetryContext = this.telemetryFactory.openTelemetryContext(methodName, TelemetryTraceLevel.NESTED); - const currentClient: ClientWrapper = this.pluginServiceManagerContainer.pluginService.getCurrentClient().targetClient; - this.pluginServiceManagerContainer.pluginService.attachNoOpErrorListener(currentClient); + const currentClient: ClientWrapper = this.fullServicesContainer.pluginService.getCurrentClient().targetClient; + this.fullServicesContainer.pluginService.attachNoOpErrorListener(currentClient); try { return await telemetryContext.start(() => { return this.executeWithSubscribedPlugins( @@ -143,7 +140,7 @@ export class PluginManager { ); }); } finally { - this.pluginServiceManagerContainer.pluginService.attachErrorListener(currentClient); + this.fullServicesContainer.pluginService.attachErrorListener(currentClient); } } @@ -396,7 +393,7 @@ export class PluginManager { } PluginManager.STRATEGY_PLUGIN_CHAIN_CACHE.clear(); - CoreServicesContainer.releaseResources(); + await CoreServicesContainer.releaseResources(); PluginManager.PLUGINS = new Set(); } @@ -445,7 +442,7 @@ export class PluginManager { for (const p of this._plugins) { if (p instanceof iface) { - return p as any; + return p as T; } } return null; diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index 18144455c..68f69c979 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -14,18 +14,18 @@ limitations under the License. */ -import { PluginServiceManagerContainer } from "./plugin_service_manager_container"; import { ErrorHandler } from "./error_handler"; import { HostInfo } from "./host_info"; import { AwsClient } from "./aws_client"; import { HostListProviderService } from "./host_list_provider_service"; -import { BlockingHostListProvider, HostListProvider } from "./host_list_provider/host_list_provider"; +import { DynamicHostListProvider, HostListProvider } from "./host_list_provider/host_list_provider"; import { ConnectionUrlParser } from "./utils/connection_url_parser"; import { DatabaseDialect, DatabaseType } from "./database_dialect/database_dialect"; import { HostInfoBuilder } from "./host_info_builder"; -import { AwsWrapperError } from "./"; +import { AwsTimeoutError, AwsWrapperError, UnsupportedMethodError } from "./utils/errors"; import { HostAvailability } from "./host_availability/host_availability"; -import { CacheMap } from "./utils/cache_map"; +import { HostAvailabilityCacheItem } from "./host_availability/host_availability_cache_item"; +import { StatusCacheItem } from "./utils/status_cache_item"; import { HostChangeOptions } from "./host_change_options"; import { HostRole } from "./host_role"; import { WrapperProperties } from "./wrapper_property"; @@ -45,6 +45,9 @@ import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { DriverDialect } from "./driver_dialect/driver_dialect"; import { AllowedAndBlockedHosts } from "./allowed_and_blocked_hosts"; import { ConnectionPlugin } from "./connection_plugin"; +import { FullServicesContainer } from "./utils/full_services_container"; +import { StorageService } from "./utils/storage/storage_service"; +import { CoreServicesContainer } from "./utils/core_services_container"; export interface PluginService extends ErrorHandler { isInTransaction(): boolean; @@ -73,28 +76,22 @@ export interface PluginService extends ErrorHandler { getDialect(): DatabaseDialect; + isDialectConfirmed(): boolean; + getDriverDialect(): DriverDialect; getHostInfoBuilder(): HostInfoBuilder; - isStaticHostListProvider(): boolean; + isDynamicHostListProvider(): boolean; acceptsStrategy(role: HostRole, strategy: string): boolean; forceRefreshHostList(): Promise; - forceRefreshHostList(targetClient: ClientWrapper): Promise; - - forceRefreshHostList(targetClient?: ClientWrapper): Promise; - forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise; refreshHostList(): Promise; - refreshHostList(targetClient: ClientWrapper): Promise; - - refreshHostList(targetClient?: ClientWrapper): Promise; - getAllHosts(): HostInfo[]; getHosts(): HostInfo[]; @@ -157,29 +154,28 @@ export interface PluginService extends ErrorHandler { } export class PluginServiceImpl implements PluginService, HostListProviderService { - private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes + private static readonly DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS = 5000; // 5 seconds. private readonly _currentClient: AwsClient; private _currentHostInfo?: HostInfo; private _hostListProvider?: HostListProvider; private _initialConnectionHostInfo?: HostInfo; private _isInTransaction: boolean = false; - private pluginServiceManagerContainer: PluginServiceManagerContainer; + private servicesContainer: FullServicesContainer; protected hosts: HostInfo[] = []; private dbDialectProvider: DatabaseDialectProvider; private readonly initialHost: string; private dialect: DatabaseDialect; + private _isDialectConfirmed: boolean = false; private readonly driverDialect: DriverDialect; protected readonly sessionStateService: SessionStateService; - protected static readonly hostAvailabilityExpiringCache: CacheMap = new CacheMap(); + protected storageService: StorageService; readonly props: Map; private allowedAndBlockedHosts: AllowedAndBlockedHosts | null = null; - protected static readonly statusesExpiringCache: CacheMap = new CacheMap(); - protected static readonly DEFAULT_STATUS_CACHE_EXPIRE_NANO: number = 3_600_000_000_000; // 60 minutes protected _isPooledClient: boolean = false; constructor( - container: PluginServiceManagerContainer, + container: FullServicesContainer, client: AwsClient, dbType: DatabaseType, knownDialectsByCode: Map, @@ -187,12 +183,12 @@ export class PluginServiceImpl implements PluginService, HostListProviderService driverDialect: DriverDialect ) { this._currentClient = client; - this.pluginServiceManagerContainer = container; + this.servicesContainer = container; + this.storageService = container.storageService; this.props = props; this.dbDialectProvider = new DatabaseDialectManager(knownDialectsByCode, dbType, this.props); this.driverDialect = driverDialect; this.initialHost = props.get(WrapperProperties.HOST.name); - container.pluginService = this; this.dialect = WrapperProperties.CUSTOM_DATABASE_DIALECT.get(this.props) ?? this.dbDialectProvider.getDialect(this.props); this.sessionStateService = new SessionStateServiceImpl(this, this.props); @@ -223,7 +219,7 @@ export class PluginServiceImpl implements PluginService, HostListProviderService } getHostInfoByStrategy(role: HostRole, strategy: string, hosts?: HostInfo[]): HostInfo | undefined { - const pluginManager = this.pluginServiceManagerContainer.pluginManager; + const pluginManager = this.servicesContainer.pluginManager; return pluginManager?.getHostInfoByStrategy(role, strategy, hosts); } @@ -285,6 +281,10 @@ export class PluginServiceImpl implements PluginService, HostListProviderService return this.dialect; } + isDialectConfirmed(): boolean { + return this._isDialectConfirmed; + } + getDriverDialect(): DriverDialect { return this.driverDialect; } @@ -293,58 +293,46 @@ export class PluginServiceImpl implements PluginService, HostListProviderService return new HostInfoBuilder({ hostAvailabilityStrategy: new HostAvailabilityStrategyFactory().create(this.props) }); } - isStaticHostListProvider(): boolean { - return false; + isDynamicHostListProvider(): boolean { + const provider = this.getHostListProvider(); + return provider !== null && typeof (provider as any).forceMonitoringRefresh === "function"; } acceptsStrategy(role: HostRole, strategy: string): boolean { - return this.pluginServiceManagerContainer.pluginManager?.acceptsStrategy(role, strategy) ?? false; + return this.servicesContainer.pluginManager?.acceptsStrategy(role, strategy) ?? false; } - async forceRefreshHostList(): Promise; - async forceRefreshHostList(targetClient: ClientWrapper): Promise; - async forceRefreshHostList(targetClient?: ClientWrapper): Promise { - const updatedHostList = targetClient - ? await this.getHostListProvider()?.forceRefresh(targetClient) - : await this.getHostListProvider()?.forceRefresh(); - if (updatedHostList && updatedHostList !== this.hosts) { - this.updateHostAvailability(updatedHostList); - await this.setHostList(this.hosts, updatedHostList); - } + async forceRefreshHostList(): Promise { + await this.forceMonitoringRefresh(false, PluginServiceImpl.DEFAULT_TOPOLOGY_QUERY_TIMEOUT_MS); } async forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise { const hostListProvider: HostListProvider = this.getHostListProvider(); - if (!this.isBlockingHostListProvider(hostListProvider)) { - logger.info(Messages.get("PluginService.requiredBlockingHostListProvider", typeof hostListProvider)); - throw new AwsWrapperError(Messages.get("PluginService.requiredBlockingHostListProvider", typeof hostListProvider)); + + if (!this.isDynamicHostListProvider()) { + const providerName = hostListProvider?.constructor.name ?? "null"; + throw new UnsupportedMethodError(Messages.get("PluginService.requiredDynamicHostListProvider", providerName)); } try { - const updatedHostList: HostInfo[] = await hostListProvider.forceMonitoringRefresh(shouldVerifyWriter, timeoutMs); + const updatedHostList: HostInfo[] = await (hostListProvider as DynamicHostListProvider).forceMonitoringRefresh(shouldVerifyWriter, timeoutMs); if (updatedHostList) { - if (updatedHostList !== this.hosts) { - this.updateHostAvailability(updatedHostList); - await this.setHostList(this.hosts, updatedHostList); - } + this.updateHostAvailability(updatedHostList); + await this.setHostList(this.hosts, updatedHostList); return true; } } catch (err) { - // Do nothing. - logger.info(Messages.get("PluginService.forceMonitoringRefreshTimeout", timeoutMs.toString())); + if (err instanceof AwsTimeoutError) { + // Do nothing. + logger.info(Messages.get("PluginService.forceMonitoringRefreshTimeout", timeoutMs.toString())); + } } return false; } - isBlockingHostListProvider(arg: any): arg is BlockingHostListProvider { - return arg != null && typeof arg.clearAll === "function" && typeof arg.forceMonitoringRefresh === "function"; - } - - async refreshHostList(): Promise; - async refreshHostList(targetClient: ClientWrapper): Promise; - async refreshHostList(targetClient?: ClientWrapper): Promise { - const updatedHostList = targetClient ? await this.getHostListProvider()?.refresh(targetClient) : await this.getHostListProvider()?.refresh(); + async refreshHostList(): Promise { + const updatedHostList = await this.getHostListProvider()?.refresh(); if (updatedHostList && updatedHostList !== this.hosts) { this.updateHostAvailability(updatedHostList); await this.setHostList(this.hosts, updatedHostList); @@ -353,9 +341,9 @@ export class PluginServiceImpl implements PluginService, HostListProviderService private updateHostAvailability(hosts: HostInfo[]) { hosts.forEach((host) => { - const availability = PluginServiceImpl.hostAvailabilityExpiringCache.get(host.url); - if (availability != null) { - host.availability = availability; + const cacheItem = this.storageService.get(HostAvailabilityCacheItem, host.url); + if (cacheItem != null) { + host.availability = cacheItem.availability; } }); } @@ -421,7 +409,7 @@ export class PluginServiceImpl implements PluginService, HostListProviderService if (changes.size > 0) { this.hosts = newHosts ? newHosts : []; - await this.pluginServiceManagerContainer.pluginManager!.notifyHostListChanged(changes); + await this.servicesContainer.pluginManager!.notifyHostListChanged(changes); } } @@ -464,14 +452,13 @@ export class PluginServiceImpl implements PluginService, HostListProviderService ]; if (hostsToChange.length === 0) { - logger.debug(Messages.get("PluginService.hostsChangeListEmpty")); return; } const changes = new Map>(); for (const host of hostsToChange) { const currentAvailability = host.getAvailability(); - PluginServiceImpl.hostAvailabilityExpiringCache.put(host.url, availability, PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + this.storageService.set(host.url, new HostAvailabilityCacheItem(availability)); if (currentAvailability !== availability) { let hostChanges = new Set(); if (availability === HostAvailability.AVAILABLE) { @@ -525,13 +512,13 @@ export class PluginServiceImpl implements PluginService, HostListProviderService connect(hostInfo: HostInfo, props: Map): Promise; connect(hostInfo: HostInfo, props: Map, pluginToSkip: ConnectionPlugin): Promise; connect(hostInfo: HostInfo, props: Map, pluginToSkip?: ConnectionPlugin): Promise { - return this.pluginServiceManagerContainer.pluginManager!.connect(hostInfo, props, false, pluginToSkip); + return this.servicesContainer.pluginManager!.connect(hostInfo, props, false, pluginToSkip); } forceConnect(hostInfo: HostInfo, props: Map): Promise; forceConnect(hostInfo: HostInfo, props: Map, pluginToSkip: ConnectionPlugin): Promise; forceConnect(hostInfo: HostInfo, props: Map, pluginToSkip?: ConnectionPlugin): Promise { - return this.pluginServiceManagerContainer.pluginManager!.forceConnect(hostInfo, props, false, pluginToSkip); + return this.servicesContainer.pluginManager!.forceConnect(hostInfo, props, false, pluginToSkip); } async setCurrentClient(newClient: ClientWrapper, hostInfo: HostInfo): Promise> { @@ -541,8 +528,8 @@ export class PluginServiceImpl implements PluginService, HostListProviderService this.sessionStateService.reset(); const changes = new Set([HostChangeOptions.INITIAL_CONNECTION]); - if (this.pluginServiceManagerContainer.pluginManager) { - await this.pluginServiceManagerContainer.pluginManager.notifyConnectionChanged(changes, null); + if (this.servicesContainer.pluginManager) { + await this.servicesContainer.pluginManager.notifyConnectionChanged(changes, null); } return changes; @@ -569,8 +556,10 @@ export class PluginServiceImpl implements PluginService, HostListProviderService } } - const pluginOpinions: Set = - await this.pluginServiceManagerContainer.pluginManager!.notifyConnectionChanged(changes, null); + const pluginOpinions: Set = await this.servicesContainer.pluginManager!.notifyConnectionChanged( + changes, + null + ); const shouldCloseConnection = changes.has(HostChangeOptions.CONNECTION_OBJECT_CHANGED) && @@ -644,11 +633,12 @@ export class PluginServiceImpl implements PluginService, HostListProviderService const originalDialect = this.dialect; this.dialect = await this.dbDialectProvider.getDialectForUpdate(targetClient, this.initialHost, this.props.get(WrapperProperties.HOST.name)); + this._isDialectConfirmed = true; if (originalDialect === this.dialect) { return; } - this._hostListProvider = this.dialect.getHostListProvider(this.props, this.props.get(WrapperProperties.HOST.name), this); + this._hostListProvider = this.dialect.getHostListProvider(this.props, this.props.get(WrapperProperties.HOST.name), this.servicesContainer); } private async updateReadOnly(statements: string[]) { @@ -691,7 +681,7 @@ export class PluginServiceImpl implements PluginService, HostListProviderService } getTelemetryFactory(): TelemetryFactory { - return this.pluginServiceManagerContainer.pluginManager!.getTelemetryFactory(); + return this.servicesContainer.pluginManager!.getTelemetryFactory(); } /* Error Handler interface implementation */ @@ -736,15 +726,12 @@ export class PluginServiceImpl implements PluginService, HostListProviderService this.allowedAndBlockedHosts = allowedAndBlockedHosts; } - static clearHostAvailabilityCache(): void { - PluginServiceImpl.hostAvailabilityExpiringCache.clear(); - } - getStatus(clazz: any, clusterBound: boolean): T; getStatus(clazz: any, key: string): T; getStatus(clazz: any, clusterBound: boolean | string): T { if (typeof clusterBound === "string") { - return PluginServiceImpl.statusesExpiringCache.get(this.getStatusCacheKey(clazz, clusterBound)); + const cacheItem = this.storageService.get(StatusCacheItem, this.getStatusCacheKey(clazz, clusterBound)); + return cacheItem ? cacheItem.status : null; } let clusterId: string = null; if (clusterBound) { @@ -767,9 +754,9 @@ export class PluginServiceImpl implements PluginService, HostListProviderService if (typeof clusterBound === "string") { const cacheKey: string = this.getStatusCacheKey(clazz, clusterBound); if (!status) { - PluginServiceImpl.statusesExpiringCache.delete(cacheKey); + this.storageService.remove(StatusCacheItem, cacheKey); } else { - PluginServiceImpl.statusesExpiringCache.put(cacheKey, status, PluginServiceImpl.DEFAULT_STATUS_CACHE_EXPIRE_NANO); + this.storageService.set(cacheKey, new StatusCacheItem(status)); } return; } @@ -786,7 +773,7 @@ export class PluginServiceImpl implements PluginService, HostListProviderService } isPluginInUse(plugin: any) { - return this.pluginServiceManagerContainer.pluginManager!.isPluginInUse(plugin); + return this.servicesContainer.pluginManager!.isPluginInUse(plugin); } isPooledClient(): boolean { diff --git a/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts b/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts index 2a56e606c..7d7891a6a 100644 --- a/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts +++ b/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts @@ -110,7 +110,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu if (writerCandidate === null || this.rdsUtils.isRdsClusterDns(writerCandidate.host)) { // Writer is not found. It seems that topology is outdated. writerCandidateClient = await connectFunc(); - await this.pluginService.forceRefreshHostList(writerCandidateClient); + await this.pluginService.forceRefreshHostList(); writerCandidate = await this.pluginService.identifyConnection(writerCandidateClient); if (writerCandidate) { @@ -132,7 +132,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu if ((await this.pluginService.getHostRole(writerCandidateClient)) !== HostRole.WRITER) { // If the new connection resolves to a reader instance, this means the topology is outdated. // Force refresh to update the topology. - await this.pluginService.forceRefreshHostList(writerCandidateClient); + await this.pluginService.forceRefreshHostList(); await this.pluginService.abortTargetClient(writerCandidateClient); await sleep(retryDelayMs); continue; @@ -177,7 +177,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu if (readerCandidate === null || this.rdsUtils.isRdsClusterDns(readerCandidate.host)) { // Reader is not found. It seems that topology is outdated. readerCandidateClient = await connectFunc(); - await this.pluginService.forceRefreshHostList(readerCandidateClient); + await this.pluginService.forceRefreshHostList(); readerCandidate = await this.pluginService.identifyConnection(readerCandidateClient); if (readerCandidate) { @@ -209,7 +209,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu if ((await this.pluginService.getHostRole(readerCandidateClient)) !== HostRole.READER) { // If the new connection resolves to a writer instance, this means the topology is outdated. // Force refresh to update the topology. - await this.pluginService.forceRefreshHostList(readerCandidateClient); + await this.pluginService.forceRefreshHostList(); if (this.hasNoReaders()) { // It seems that cluster has no readers. Simulate Aurora reader cluster endpoint logic diff --git a/common/lib/plugins/aurora_initial_connection_strategy_plugin_factory.ts b/common/lib/plugins/aurora_initial_connection_strategy_plugin_factory.ts index d13d287f7..b12e4d7b7 100644 --- a/common/lib/plugins/aurora_initial_connection_strategy_plugin_factory.ts +++ b/common/lib/plugins/aurora_initial_connection_strategy_plugin_factory.ts @@ -15,19 +15,21 @@ */ import { ConnectionPluginFactory } from "../plugin_factory"; -import { PluginService } from "../plugin_service"; import { AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; +import { FullServicesContainer } from "../utils/full_services_container"; export class AuroraInitialConnectionStrategyFactory extends ConnectionPluginFactory { private static auroraInitialConnectionStrategyPlugin: any; - async getInstance(pluginService: PluginService, props: Map) { + async getInstance(servicesContainer: FullServicesContainer, props: Map) { try { if (!AuroraInitialConnectionStrategyFactory.auroraInitialConnectionStrategyPlugin) { AuroraInitialConnectionStrategyFactory.auroraInitialConnectionStrategyPlugin = await import("./aurora_initial_connection_strategy_plugin"); } - return new AuroraInitialConnectionStrategyFactory.auroraInitialConnectionStrategyPlugin.AuroraInitialConnectionStrategyPlugin(pluginService); + return new AuroraInitialConnectionStrategyFactory.auroraInitialConnectionStrategyPlugin.AuroraInitialConnectionStrategyPlugin( + servicesContainer.pluginService + ); } catch (error: any) { throw new AwsWrapperError( Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "AuroraInitialConnectionStrategyPlugin") diff --git a/common/lib/plugins/bluegreen/blue_green_plugin.ts b/common/lib/plugins/bluegreen/blue_green_plugin.ts index 4f7b90ece..17cacb8d9 100644 --- a/common/lib/plugins/bluegreen/blue_green_plugin.ts +++ b/common/lib/plugins/bluegreen/blue_green_plugin.ts @@ -176,7 +176,7 @@ export class BlueGreenPlugin extends AbstractConnectionPlugin implements CanRele this.startTimeNano = getTimeInNanos(); - while (routing && result && !result.isPresent()) { + while (routing && (!result || !result.isPresent())) { result = await routing.apply(this, methodName, methodFunc, methodArgs, this.properties, this.pluginService); if (!result?.isPresent()) { this.bgStatus = this.pluginService.getStatus(BlueGreenStatus, this.bgdId); diff --git a/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts b/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts index f5ba48a3c..d20e00b66 100644 --- a/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts +++ b/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class BlueGreenPluginFactory extends ConnectionPluginFactory { private static blueGreenPlugin: any; - async getInstance(pluginService: PluginService, props: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, props: Map): Promise { try { if (!BlueGreenPluginFactory.blueGreenPlugin) { BlueGreenPluginFactory.blueGreenPlugin = await import("./blue_green_plugin"); } - return new BlueGreenPluginFactory.blueGreenPlugin.BlueGreenPlugin(pluginService, props); + return new BlueGreenPluginFactory.blueGreenPlugin.BlueGreenPlugin(servicesContainer.pluginService, props); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "BlueGreenPluginFactory")); } diff --git a/common/lib/plugins/bluegreen/blue_green_status.ts b/common/lib/plugins/bluegreen/blue_green_status.ts index 1c17463f3..783e28407 100644 --- a/common/lib/plugins/bluegreen/blue_green_status.ts +++ b/common/lib/plugins/bluegreen/blue_green_status.ts @@ -19,7 +19,6 @@ import { ConnectRouting } from "./routing/connect_routing"; import { ExecuteRouting } from "./routing/execute_routing"; import { BlueGreenRole } from "./blue_green_role"; import { HostInfo } from "../../host_info"; -import { Pair } from "../../utils/utils"; export class BlueGreenStatus { private readonly bgdId: string; @@ -28,7 +27,7 @@ export class BlueGreenStatus { private readonly _unmodifiableExecuteRouting: readonly ExecuteRouting[]; private readonly _roleByHost: Map; - private readonly _correspondingHosts: Map>; + private readonly _correspondingHosts: Map; constructor( bgdId: string, @@ -36,7 +35,7 @@ export class BlueGreenStatus { unmodifiableConnectRouting?: ConnectRouting[], unmodifiableExecuteRouting?: ExecuteRouting[], roleByHost?: Map, - correspondingHosts?: Map> + correspondingHosts?: Map ) { this.bgdId = bgdId; this._currentPhase = phase; @@ -62,7 +61,7 @@ export class BlueGreenStatus { return this._roleByHost; } - get correspondingHosts(): Map> { + get correspondingHosts(): Map { return this._correspondingHosts; } diff --git a/common/lib/plugins/bluegreen/blue_green_status_monitor.ts b/common/lib/plugins/bluegreen/blue_green_status_monitor.ts index f320c84f7..0bd698580 100644 --- a/common/lib/plugins/bluegreen/blue_green_status_monitor.ts +++ b/common/lib/plugins/bluegreen/blue_green_status_monitor.ts @@ -36,6 +36,7 @@ import { HostListProviderService } from "../../host_list_provider_service"; import { StatusInfo } from "./status_info"; import { DatabaseDialect } from "../../database_dialect/database_dialect"; import { AwsWrapperError } from "../../utils/errors"; +import { FullServicesContainer } from "../../utils/full_services_container"; export interface OnBlueGreenStatusChange { onBlueGreenStatusChanged(role: BlueGreenRole, interimStatus: BlueGreenInterimStatus): void; @@ -50,6 +51,7 @@ export class BlueGreenStatusMonitor { protected static readonly knownVersions: Set = new Set([BlueGreenStatusMonitor.latestKnownVersion]); protected readonly blueGreenDialect: BlueGreenDialect; + protected readonly servicesContainer: FullServicesContainer; protected readonly pluginService: PluginService; protected readonly bgdId: string; protected readonly props: Map; @@ -99,7 +101,7 @@ export class BlueGreenStatusMonitor { role: BlueGreenRole, bgdId: string, initialHostInfo: HostInfo, - pluginService: PluginService, + servicesContainer: FullServicesContainer, props: Map, statusCheckIntervalMap: Map, onBlueGreenStatusChangeFunc: OnBlueGreenStatusChange @@ -107,7 +109,8 @@ export class BlueGreenStatusMonitor { this.role = role; this.bgdId = bgdId; this.initialHostInfo = initialHostInfo; - this.pluginService = pluginService; + this.servicesContainer = servicesContainer; + this.pluginService = this.servicesContainer.pluginService; this.props = props; this.statusCheckIntervalMap = statusCheckIntervalMap; this.onBlueGreenStatusChangeFunc = onBlueGreenStatusChangeFunc; @@ -296,12 +299,7 @@ export class BlueGreenStatusMonitor { return; } - const client: ClientWrapper = this.clientWrapper; - if (await this.isConnectionClosed(client)) { - return; - } - - this.currentTopology = await this.hostListProvider.forceRefresh(client); + this.currentTopology = await this.hostListProvider.forceRefresh(); if (this.collectedTopology) { this.startTopology = this.currentTopology; } @@ -518,7 +516,7 @@ export class BlueGreenStatusMonitor { if (connectionHostInfoCopy) { this.hostListProvider = this.pluginService .getDialect() - .getHostListProvider(hostListProperties, connectionHostInfoCopy.host, this.pluginService as unknown as HostListProviderService); + .getHostListProvider(hostListProperties, connectionHostInfoCopy.host, this.servicesContainer); } else { logger.warn(Messages.get("Bgd.hostInfoNull")); } diff --git a/common/lib/plugins/bluegreen/blue_green_status_provider.ts b/common/lib/plugins/bluegreen/blue_green_status_provider.ts index 0919fc781..5a80116d9 100644 --- a/common/lib/plugins/bluegreen/blue_green_status_provider.ts +++ b/common/lib/plugins/bluegreen/blue_green_status_provider.ts @@ -20,7 +20,7 @@ import { SimpleHostAvailabilityStrategy } from "../../host_availability/simple_h import { BlueGreenStatusMonitor } from "./blue_green_status_monitor"; import { BlueGreenInterimStatus } from "./blue_green_interim_status"; import { HostInfo } from "../../host_info"; -import { convertMsToNanos, getTimeInNanos, Pair } from "../../utils/utils"; +import { convertMsToNanos, getTimeInNanos } from "../../utils/utils"; import { BlueGreenRole } from "./blue_green_role"; import { BlueGreenStatus } from "./blue_green_status"; import { BlueGreenPhase } from "./blue_green_phase"; @@ -38,14 +38,13 @@ import { SubstituteConnectRouting } from "./routing/substitute_connect_routing"; import { SuspendConnectRouting } from "./routing/suspend_connect_routing"; import { ExecuteRouting } from "./routing/execute_routing"; import { SuspendExecuteRouting } from "./routing/suspend_execute_routing"; -import { - SuspendUntilCorrespondingHostFoundConnectRouting -} from "./routing/suspend_until_corresponding_host_found_connect_routing"; +import { SuspendUntilCorrespondingHostFoundConnectRouting } from "./routing/suspend_until_corresponding_host_found_connect_routing"; import { RejectConnectRouting } from "./routing/reject_connect_routing"; import { getValueHash } from "./blue_green_utils"; +import { FullServicesContainer } from "../../utils/full_services_container"; +import { StorageService } from "../../utils/storage/storage_service"; export class BlueGreenStatusProvider { - static readonly MONITORING_PROPERTY_PREFIX = "blue_green_monitoring_"; private static readonly DEFAULT_CONNECT_TIMEOUT_MS = 10_000; // 10 seconds private static readonly DEFAULT_QUERY_TIMEOUT_MS = 10_000; // 10 seconds @@ -56,8 +55,8 @@ export class BlueGreenStatusProvider { protected interimStatuses: BlueGreenInterimStatus[] = [null, null]; protected hostIpAddresses: Map = new Map(); - // The second parameter of Pair is null when no corresponding host is found. - protected readonly correspondingHosts: Map> = new Map(); + // The second element is null when no corresponding host is found. + protected readonly correspondingHosts: Map = new Map(); // all known host names; host with no port protected readonly roleByHost: Map = new Map(); @@ -78,14 +77,18 @@ export class BlueGreenStatusProvider { protected readonly switchoverTimeoutNanos: bigint; protected readonly suspendNewBlueConnectionsWhenInProgress: boolean; + protected readonly servicesContainer: FullServicesContainer; + protected readonly storageService: StorageService; protected readonly pluginService: PluginService; protected readonly properties: Map; protected readonly bgdId: string; protected phaseTimeNanos: Map = new Map(); protected readonly rdsUtils: RdsUtils = new RdsUtils(); - constructor(pluginService: PluginService, properties: Map, bgdId: string) { - this.pluginService = pluginService; + constructor(servicesContainer: FullServicesContainer, properties: Map, bgdId: string) { + this.servicesContainer = servicesContainer; + this.pluginService = this.servicesContainer.pluginService; + this.storageService = this.servicesContainer.storageService; this.properties = properties; this.bgdId = bgdId; @@ -109,7 +112,7 @@ export class BlueGreenStatusProvider { BlueGreenRole.SOURCE, this.bgdId, this.pluginService.getCurrentHostInfo(), - this.pluginService, + this.servicesContainer, this.getMonitoringProperties(), this.statusCheckIntervalMap, { onBlueGreenStatusChanged: (role, status) => this.prepareStatus(role, status) } @@ -119,7 +122,7 @@ export class BlueGreenStatusProvider { BlueGreenRole.TARGET, this.bgdId, this.pluginService.getCurrentHostInfo(), - this.pluginService, + this.servicesContainer, this.getMonitoringProperties(), this.statusCheckIntervalMap, { onBlueGreenStatusChanged: (role, status) => this.prepareStatus(role, status) } @@ -130,7 +133,7 @@ export class BlueGreenStatusProvider { const monitoringConnProperties: Map = new Map(this.properties); for (const key of monitoringConnProperties.keys()) { - if (!key.startsWith(BlueGreenStatusProvider.MONITORING_PROPERTY_PREFIX)) { + if (!key.startsWith(WrapperProperties.BG_MONITORING_PROPERTY_PREFIX)) { continue; } @@ -236,19 +239,19 @@ export class BlueGreenStatusProvider { if (blueWriterHostInfo) { // greenWriterHostInfo can be null but that will be handled properly by corresponding routing. - this.correspondingHosts.set(blueWriterHostInfo.host, new Pair(blueWriterHostInfo, greenWriterHostInfo)); + this.correspondingHosts.set(blueWriterHostInfo.host, [blueWriterHostInfo, greenWriterHostInfo]); } if (sortedBlueReaderHostInfos?.length > 0) { if (sortedGreenReaderHostInfos?.length > 0) { let greenIndex: number = 0; sortedBlueReaderHostInfos.forEach((blueHostInfo) => { - this.correspondingHosts.set(blueHostInfo.host, new Pair(blueHostInfo, sortedGreenReaderHostInfos.at(greenIndex++))); + this.correspondingHosts.set(blueHostInfo.host, [blueHostInfo, sortedGreenReaderHostInfos.at(greenIndex++)]); greenIndex %= sortedGreenReaderHostInfos.length; }); } else { sortedBlueReaderHostInfos.forEach((blueHostInfo) => { - this.correspondingHosts.set(blueHostInfo.host, new Pair(blueHostInfo, greenWriterHostInfo)); + this.correspondingHosts.set(blueHostInfo.host, [blueHostInfo, greenWriterHostInfo]); }); } } @@ -270,10 +273,10 @@ export class BlueGreenStatusProvider { if (blueClusterHost !== null && greenClusterHost !== null) { if (!this.correspondingHosts.has(blueClusterHost)) { - this.correspondingHosts.set( - blueClusterHost, - new Pair(this.hostInfoBuilder.withHost(blueClusterHost).build(), this.hostInfoBuilder.withHost(greenClusterHost).build()) - ); + this.correspondingHosts.set(blueClusterHost, [ + this.hostInfoBuilder.withHost(blueClusterHost).build(), + this.hostInfoBuilder.withHost(greenClusterHost).build() + ]); } } @@ -290,10 +293,10 @@ export class BlueGreenStatusProvider { if (blueClusterReaderHost !== null && greenClusterReaderHost !== null) { if (!this.correspondingHosts.has(blueClusterReaderHost)) { - this.correspondingHosts.set( - blueClusterReaderHost, - new Pair(this.hostInfoBuilder.withHost(blueClusterReaderHost).build(), this.hostInfoBuilder.withHost(greenClusterReaderHost).build()) - ); + this.correspondingHosts.set(blueClusterReaderHost, [ + this.hostInfoBuilder.withHost(blueClusterReaderHost).build(), + this.hostInfoBuilder.withHost(greenClusterReaderHost).build() + ]); } } @@ -310,10 +313,10 @@ export class BlueGreenStatusProvider { }); if (greenHost) { if (!this.correspondingHosts.has(blueHost)) { - this.correspondingHosts.set( - blueHost, - new Pair(this.hostInfoBuilder.withHost(blueHost).build(), this.hostInfoBuilder.withHost(greenHost).build()) - ); + this.correspondingHosts.set(blueHost, [ + this.hostInfoBuilder.withHost(blueHost).build(), + this.hostInfoBuilder.withHost(greenHost).build() + ]); } } } @@ -485,7 +488,7 @@ export class BlueGreenStatusProvider { Array.from(this.roleByHost.entries()) .filter(([host, role]) => role === BlueGreenRole.SOURCE && this.correspondingHosts.has(host)) .forEach(([host, role]) => { - const hostSpec = this.correspondingHosts.get(host).left; + const hostSpec = this.correspondingHosts.get(host)[0]; const blueIp = this.hostIpAddresses.get(hostSpec.host); const substituteHostSpecWithIp = !blueIp ? hostSpec : this.hostInfoBuilder.copyFrom(hostSpec).withHost(blueIp).build(); @@ -629,9 +632,9 @@ export class BlueGreenStatusProvider { .forEach(([host, role]) => { const blueHost: string = host; const isBlueHostInstance: boolean = this.rdsUtils.isRdsInstance(blueHost); - const pair: Pair | undefined = this.correspondingHosts?.get(host); - const blueHostInfo: HostInfo | undefined = pair?.left; - const greenHostInfo: HostInfo | undefined = pair?.right; + const pair: [HostInfo, HostInfo | null] | undefined = this.correspondingHosts?.get(host); + const blueHostInfo: HostInfo | undefined = pair?.[0]; + const greenHostInfo: HostInfo | undefined = pair?.[1]; if (!greenHostInfo) { // A corresponding host is not found. We need to suspend this call. @@ -874,7 +877,7 @@ export class BlueGreenStatusProvider { logger.debug( "Corresponding hosts:\n" + Array.from(this.correspondingHosts.entries()) - .map(([key, value]) => ` ${key} -> ${value.right == null ? "" : value.right.hostAndPort}`) + .map(([key, value]) => ` ${key} -> ${value[1] == null ? "" : value[1].hostAndPort}`) .join("\n") ); diff --git a/common/lib/plugins/bluegreen/routing/suspend_until_corresponding_host_found_connect_routing.ts b/common/lib/plugins/bluegreen/routing/suspend_until_corresponding_host_found_connect_routing.ts index 7ce13cb1e..ecf7cab26 100644 --- a/common/lib/plugins/bluegreen/routing/suspend_until_corresponding_host_found_connect_routing.ts +++ b/common/lib/plugins/bluegreen/routing/suspend_until_corresponding_host_found_connect_routing.ts @@ -26,7 +26,7 @@ import { TelemetryFactory } from "../../../utils/telemetry/telemetry_factory"; import { TelemetryContext } from "../../../utils/telemetry/telemetry_context"; import { TelemetryTraceLevel } from "../../../utils/telemetry/telemetry_trace_level"; import { BlueGreenStatus } from "../blue_green_status"; -import { convertMsToNanos, convertNanosToMs, getTimeInNanos, Pair } from "../../../utils/utils"; +import { convertMsToNanos, convertNanosToMs, getTimeInNanos } from "../../../utils/utils"; import { WrapperProperties } from "../../../wrapper_property"; import { BlueGreenPhase } from "../blue_green_phase"; import { AwsWrapperError } from "../../../utils/errors"; @@ -60,7 +60,7 @@ export class SuspendUntilCorrespondingHostFoundConnectRouting extends BaseConnec return await telemetryContext.start(async () => { let bgStatus: BlueGreenStatus = pluginService.getStatus(BlueGreenStatus, this.bgdId); - let correspondingPair: Pair = bgStatus?.correspondingHosts.get(hostInfo.host); + let correspondingPair: [HostInfo, HostInfo] | undefined = bgStatus?.correspondingHosts.get(hostInfo.host); const timeoutNanos: bigint = convertMsToNanos(WrapperProperties.BG_CONNECT_TIMEOUT_MS.get(properties)); const suspendStartTime: bigint = getTimeInNanos(); @@ -70,7 +70,7 @@ export class SuspendUntilCorrespondingHostFoundConnectRouting extends BaseConnec getTimeInNanos() <= endTime && bgStatus != null && bgStatus.currentPhase !== BlueGreenPhase.COMPLETED && - (!correspondingPair || !correspondingPair.right) + (!correspondingPair || !correspondingPair[1]) ) { await this.delay(SuspendUntilCorrespondingHostFoundConnectRouting.SLEEP_TIME_MS, bgStatus, pluginService, this.bgdId); diff --git a/common/lib/plugins/connect_time_plugin_factory.ts b/common/lib/plugins/connect_time_plugin_factory.ts index bc333ccbc..f2362458c 100644 --- a/common/lib/plugins/connect_time_plugin_factory.ts +++ b/common/lib/plugins/connect_time_plugin_factory.ts @@ -15,15 +15,15 @@ */ import { ConnectionPluginFactory } from "../plugin_factory"; -import { PluginService } from "../plugin_service"; import { ConnectionPlugin } from "../connection_plugin"; import { AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; +import { FullServicesContainer } from "../utils/full_services_container"; export class ConnectTimePluginFactory extends ConnectionPluginFactory { private static connectTimePlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!ConnectTimePluginFactory.connectTimePlugin) { ConnectTimePluginFactory.connectTimePlugin = await import("./connect_time_plugin"); diff --git a/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin_factory.ts b/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin_factory.ts index 4eb2b6e1b..b7ce5b357 100644 --- a/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin_factory.ts +++ b/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class AuroraConnectionTrackerPluginFactory extends ConnectionPluginFactory { private static auroraConnectionTrackerPlugin: any; - async getInstance(pluginService: PluginService, props: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, props: Map): Promise { try { if (!AuroraConnectionTrackerPluginFactory.auroraConnectionTrackerPlugin) { AuroraConnectionTrackerPluginFactory.auroraConnectionTrackerPlugin = await import("./aurora_connection_tracker_plugin"); } - return new AuroraConnectionTrackerPluginFactory.auroraConnectionTrackerPlugin.AuroraConnectionTrackerPlugin(pluginService); + return new AuroraConnectionTrackerPluginFactory.auroraConnectionTrackerPlugin.AuroraConnectionTrackerPlugin(servicesContainer.pluginService); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "AuroraConnectionTrackerPlugin")); } diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts index 0e945c2ea..23a3ee502 100644 --- a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts @@ -15,19 +15,19 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class CustomEndpointPluginFactory extends ConnectionPluginFactory { private static customEndpointPlugin: any; - async getInstance(pluginService: PluginService, props: Map) { + async getInstance(servicesContainer: FullServicesContainer, props: Map) { try { if (!CustomEndpointPluginFactory.customEndpointPlugin) { CustomEndpointPluginFactory.customEndpointPlugin = await import("./custom_endpoint_plugin"); } - return new CustomEndpointPluginFactory.customEndpointPlugin.CustomEndpointPlugin(pluginService, props); + return new CustomEndpointPluginFactory.customEndpointPlugin.CustomEndpointPlugin(servicesContainer.pluginService, props); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "CustomEndpointPlugin")); } diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts index b4e2abec5..5b89d8d55 100644 --- a/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts @@ -15,17 +15,16 @@ */ export enum CustomEndpointRoleType { - ANY, - WRITER, - READER + ANY = "ANY", + WRITER = "WRITER", + READER = "READER", + UNKNOWN = "UNKNOWN" } -const nameToValue = new Map([ - ["ANY", CustomEndpointRoleType.ANY], - ["WRITER", CustomEndpointRoleType.WRITER], - ["READER", CustomEndpointRoleType.READER] -]); - -export function customEndpointRoleTypeFromValue(name: string): CustomEndpointRoleType { - return nameToValue.get(name.toUpperCase()) ?? CustomEndpointRoleType.ANY; +export function customEndpointRoleTypeFromValue(value: string | null | undefined): CustomEndpointRoleType { + if (!value) { + return CustomEndpointRoleType.UNKNOWN; + } + const normalized = value.toUpperCase(); + return Object.values(CustomEndpointRoleType).find((v) => v === normalized) ?? CustomEndpointRoleType.UNKNOWN; } diff --git a/common/lib/plugins/default_plugin.ts b/common/lib/plugins/default_plugin.ts index 3af9b3fc2..2c3b45d8a 100644 --- a/common/lib/plugins/default_plugin.ts +++ b/common/lib/plugins/default_plugin.ts @@ -30,15 +30,18 @@ import { HostAvailability } from "../host_availability/host_availability"; import { ClientWrapper } from "../client_wrapper"; import { TelemetryTraceLevel } from "../utils/telemetry/telemetry_trace_level"; import { ConnectionInfo } from "../connection_info"; +import { FullServicesContainer } from "../utils/full_services_container"; export class DefaultPlugin extends AbstractConnectionPlugin { id: string = uniqueId("_defaultPlugin"); + private readonly servicesContainer: FullServicesContainer; private readonly pluginService: PluginService; private readonly connectionProviderManager: ConnectionProviderManager; - constructor(pluginService: PluginService, connectionProviderManager: ConnectionProviderManager) { + constructor(servicesContainer: FullServicesContainer, connectionProviderManager: ConnectionProviderManager) { super(); - this.pluginService = pluginService; + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.pluginService; this.connectionProviderManager = connectionProviderManager; } diff --git a/common/lib/plugins/dev/developer_connection_plugin_factory.ts b/common/lib/plugins/dev/developer_connection_plugin_factory.ts index ef942d673..e03dfa328 100644 --- a/common/lib/plugins/dev/developer_connection_plugin_factory.ts +++ b/common/lib/plugins/dev/developer_connection_plugin_factory.ts @@ -15,21 +15,25 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { RdsUtils } from "../../utils/rds_utils"; import { Messages } from "../../utils/messages"; import { AwsWrapperError } from "../../utils/errors"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class DeveloperConnectionPluginFactory extends ConnectionPluginFactory { private static developerPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!DeveloperConnectionPluginFactory.developerPlugin) { DeveloperConnectionPluginFactory.developerPlugin = await import("./developer_connection_plugin"); } - return new DeveloperConnectionPluginFactory.developerPlugin.DeveloperConnectionPlugin(pluginService, properties, new RdsUtils()); + return new DeveloperConnectionPluginFactory.developerPlugin.DeveloperConnectionPlugin( + servicesContainer.pluginService, + properties, + new RdsUtils() + ); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "DeveloperConnectionPlugin")); } diff --git a/common/lib/plugins/efm/host_monitoring_connection_plugin.ts b/common/lib/plugins/efm/host_monitoring_connection_plugin.ts index e835622f0..753e2f695 100644 --- a/common/lib/plugins/efm/host_monitoring_connection_plugin.ts +++ b/common/lib/plugins/efm/host_monitoring_connection_plugin.ts @@ -14,12 +14,7 @@ limitations under the License. */ -import { - HostInfo, - AwsWrapperError, - UnavailableHostError, - HostAvailability -} from "../../"; +import { HostInfo, AwsWrapperError, UnavailableHostError, HostAvailability } from "../../"; import { PluginService } from "../../plugin_service"; import { HostChangeOptions } from "../../host_change_options"; import { OldConnectionSuggestionAction } from "../../old_connection_suggestion_action"; diff --git a/common/lib/plugins/efm/host_monitoring_plugin_factory.ts b/common/lib/plugins/efm/host_monitoring_plugin_factory.ts index 62b59cd6a..d6b430954 100644 --- a/common/lib/plugins/efm/host_monitoring_plugin_factory.ts +++ b/common/lib/plugins/efm/host_monitoring_plugin_factory.ts @@ -15,21 +15,25 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { RdsUtils } from "../../utils/rds_utils"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class HostMonitoringPluginFactory extends ConnectionPluginFactory { private static hostMonitoringPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!HostMonitoringPluginFactory.hostMonitoringPlugin) { HostMonitoringPluginFactory.hostMonitoringPlugin = await import("./host_monitoring_connection_plugin"); } - return new HostMonitoringPluginFactory.hostMonitoringPlugin.HostMonitoringConnectionPlugin(pluginService, properties, new RdsUtils()); + return new HostMonitoringPluginFactory.hostMonitoringPlugin.HostMonitoringConnectionPlugin( + servicesContainer.pluginService, + properties, + new RdsUtils() + ); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "HostMonitoringPlugin")); } diff --git a/common/lib/plugins/efm2/host_monitoring2_plugin_factory.ts b/common/lib/plugins/efm2/host_monitoring2_plugin_factory.ts index 6763b0285..0ac0319f4 100644 --- a/common/lib/plugins/efm2/host_monitoring2_plugin_factory.ts +++ b/common/lib/plugins/efm2/host_monitoring2_plugin_factory.ts @@ -15,21 +15,25 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { RdsUtils } from "../../utils/rds_utils"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class HostMonitoring2PluginFactory extends ConnectionPluginFactory { private static hostMonitoring2Plugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!HostMonitoring2PluginFactory.hostMonitoring2Plugin) { HostMonitoring2PluginFactory.hostMonitoring2Plugin = await import("./host_monitoring2_connection_plugin"); } - return new HostMonitoring2PluginFactory.hostMonitoring2Plugin.HostMonitoring2ConnectionPlugin(pluginService, properties, new RdsUtils()); + return new HostMonitoring2PluginFactory.hostMonitoring2Plugin.HostMonitoring2ConnectionPlugin( + servicesContainer.pluginService, + properties, + new RdsUtils() + ); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "HostMonitoringPlugin")); } diff --git a/common/lib/plugins/execute_time_plugin_factory.ts b/common/lib/plugins/execute_time_plugin_factory.ts index 36e2885e4..3087ad204 100644 --- a/common/lib/plugins/execute_time_plugin_factory.ts +++ b/common/lib/plugins/execute_time_plugin_factory.ts @@ -15,15 +15,15 @@ */ import { ConnectionPluginFactory } from "../plugin_factory"; -import { PluginService } from "../plugin_service"; import { ConnectionPlugin } from "../connection_plugin"; import { AwsWrapperError } from "../utils/errors"; import { Messages } from "../utils/messages"; +import { FullServicesContainer } from "../utils/full_services_container"; export class ExecuteTimePluginFactory extends ConnectionPluginFactory { private static executeTimePlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!ExecuteTimePluginFactory.executeTimePlugin) { ExecuteTimePluginFactory.executeTimePlugin = await import("./execute_time_plugin"); diff --git a/common/lib/plugins/failover/failover_mode.ts b/common/lib/plugins/failover/failover_mode.ts index e6522335d..42444f6cc 100644 --- a/common/lib/plugins/failover/failover_mode.ts +++ b/common/lib/plugins/failover/failover_mode.ts @@ -15,19 +15,16 @@ */ export enum FailoverMode { - STRICT_WRITER, - STRICT_READER, - READER_OR_WRITER, - UNKNOWN + STRICT_WRITER = "strict-writer", + STRICT_READER = "strict-reader", + READER_OR_WRITER = "reader-or-writer", + UNKNOWN = "unknown" } -const nameToValue = new Map([ - ["strict-writer", FailoverMode.STRICT_WRITER], - ["strict-reader", FailoverMode.STRICT_READER], - ["reader-or-writer", FailoverMode.READER_OR_WRITER], - ["unknown", FailoverMode.UNKNOWN] -]); - -export function failoverModeFromValue(name: string): FailoverMode { - return nameToValue.get(name.toLowerCase()) ?? FailoverMode.UNKNOWN; +export function failoverModeFromValue(value: string | null | undefined): FailoverMode { + if (!value) { + return FailoverMode.UNKNOWN; + } + const normalized = value.toLowerCase(); + return Object.values(FailoverMode).find((v) => v === normalized) ?? FailoverMode.UNKNOWN; } diff --git a/common/lib/plugins/failover/failover_plugin.ts b/common/lib/plugins/failover/failover_plugin.ts index 4145e5196..bbb81e73e 100644 --- a/common/lib/plugins/failover/failover_plugin.ts +++ b/common/lib/plugins/failover/failover_plugin.ts @@ -205,6 +205,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { return ( this.enableFailoverSetting && this._rdsUrlType !== RdsUrlType.RDS_PROXY && + this._rdsUrlType !== RdsUrlType.RDS_PROXY_ENDPOINT && this.pluginService.getAllHosts() && this.pluginService.getAllHosts().length > 0 ); @@ -541,7 +542,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { this._readerFailoverHandler.setEnableFailoverStrictReader(this.failoverMode === FailoverMode.STRICT_READER); - logger.debug(Messages.get("Failover.parameterValue", "failoverMode", FailoverMode[this.failoverMode])); + logger.debug(Messages.get("Failover.parameterValue", "failoverMode", String(this.failoverMode))); } } } diff --git a/common/lib/plugins/failover/failover_plugin_factory.ts b/common/lib/plugins/failover/failover_plugin_factory.ts index 5f8a5c41c..cab67b400 100644 --- a/common/lib/plugins/failover/failover_plugin_factory.ts +++ b/common/lib/plugins/failover/failover_plugin_factory.ts @@ -15,21 +15,21 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { RdsUtils } from "../../utils/rds_utils"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class FailoverPluginFactory extends ConnectionPluginFactory { private static failoverPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!FailoverPluginFactory.failoverPlugin) { FailoverPluginFactory.failoverPlugin = await import("./failover_plugin"); } - return new FailoverPluginFactory.failoverPlugin.FailoverPlugin(pluginService, properties, new RdsUtils()); + return new FailoverPluginFactory.failoverPlugin.FailoverPlugin(servicesContainer.pluginService, properties, new RdsUtils()); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "FailoverPlugin")); } diff --git a/common/lib/plugins/failover/writer_failover_handler.ts b/common/lib/plugins/failover/writer_failover_handler.ts index 70d46eab2..b8780b888 100644 --- a/common/lib/plugins/failover/writer_failover_handler.ts +++ b/common/lib/plugins/failover/writer_failover_handler.ts @@ -236,7 +236,7 @@ class ReconnectToWriterHandlerTask { const props = new Map(this.initialConnectionProps); props.set(WrapperProperties.HOST.name, this.originalWriterHost.host); this.currentClient = await this.pluginService.forceConnect(this.originalWriterHost, props); - await this.pluginService.forceRefreshHostList(this.currentClient); + await this.pluginService.forceRefreshHostList(); latestTopology = this.pluginService.getAllHosts(); } catch (error) { // Propagate errors that are not caused by network errors. @@ -382,7 +382,7 @@ class WaitForNewWriterHandlerTask { while (this.pluginService.getCurrentClient() && Date.now() < this.endTime && !this.failoverCompleted) { try { if (this.currentReaderTargetClient) { - await this.pluginService.forceRefreshHostList(this.currentReaderTargetClient); + await this.pluginService.forceRefreshHostList(); } const topology = this.pluginService.getAllHosts(); diff --git a/common/lib/plugins/failover2/failover2_plugin.ts b/common/lib/plugins/failover2/failover2_plugin.ts index 3b4a7f044..f9e04fe6c 100644 --- a/common/lib/plugins/failover2/failover2_plugin.ts +++ b/common/lib/plugins/failover2/failover2_plugin.ts @@ -38,46 +38,47 @@ import { ClientWrapper } from "../../client_wrapper"; import { HostAvailability } from "../../host_availability/host_availability"; import { TelemetryTraceLevel } from "../../utils/telemetry/telemetry_trace_level"; import { HostRole } from "../../host_role"; -import { CanReleaseResources } from "../../can_release_resources"; import { ReaderFailoverResult } from "../failover/reader_failover_result"; -import { BlockingHostListProvider, HostListProvider } from "../../host_list_provider/host_list_provider"; import { logTopology } from "../../utils/utils"; +import { FullServicesContainer } from "../../utils/full_services_container"; -export class Failover2Plugin extends AbstractConnectionPlugin implements CanReleaseResources { +export class Failover2Plugin extends AbstractConnectionPlugin { private static readonly TELEMETRY_WRITER_FAILOVER = "failover to writer instance"; private static readonly TELEMETRY_READER_FAILOVER = "failover to reader"; private static readonly METHOD_END = "end"; private static readonly SUBSCRIBED_METHODS: Set = new Set(["initHostProvider", "connect", "query"]); private readonly _staleDnsHelper: StaleDnsHelper; - private readonly _properties: Map; - private readonly pluginService: PluginService; - private readonly _rdsHelper: RdsUtils; - private readonly failoverWriterTriggeredCounter: TelemetryCounter; - private readonly failoverWriterSuccessCounter: TelemetryCounter; - private readonly failoverWriterFailedCounter: TelemetryCounter; - private readonly failoverReaderTriggeredCounter: TelemetryCounter; - private readonly failoverReaderSuccessCounter: TelemetryCounter; - private readonly failoverReaderFailedCounter: TelemetryCounter; - private telemetryFailoverAdditionalTopTraceSetting: boolean = false; - private _rdsUrlType: RdsUrlType | null = null; - private _isInTransaction: boolean = false; + protected readonly properties: Map; + private readonly servicesContainer: FullServicesContainer; + protected readonly pluginService: PluginService; + protected readonly rdsHelper: RdsUtils; + protected readonly failoverWriterTriggeredCounter: TelemetryCounter; + protected readonly failoverWriterSuccessCounter: TelemetryCounter; + protected readonly failoverWriterFailedCounter: TelemetryCounter; + protected readonly failoverReaderTriggeredCounter: TelemetryCounter; + protected readonly failoverReaderSuccessCounter: TelemetryCounter; + protected readonly failoverReaderFailedCounter: TelemetryCounter; + protected telemetryFailoverAdditionalTopTraceSetting: boolean = false; + protected rdsUrlType: RdsUrlType | null = null; + protected _isInTransaction: boolean = false; private _lastError: any; failoverMode: FailoverMode = FailoverMode.UNKNOWN; - private hostListProviderService?: HostListProviderService; + protected hostListProviderService?: HostListProviderService; protected enableFailoverSetting: boolean = WrapperProperties.ENABLE_CLUSTER_AWARE_FAILOVER.defaultValue; - private readonly failoverTimeoutSettingMs: number = WrapperProperties.FAILOVER_TIMEOUT_MS.defaultValue; - private readonly failoverReaderHostSelectorStrategy: string = WrapperProperties.FAILOVER_READER_HOST_SELECTOR_STRATEGY.defaultValue; + protected readonly failoverTimeoutSettingMs: number = WrapperProperties.FAILOVER_TIMEOUT_MS.defaultValue; + protected readonly failoverReaderHostSelectorStrategy: string = WrapperProperties.FAILOVER_READER_HOST_SELECTOR_STRATEGY.defaultValue; - constructor(pluginService: PluginService, properties: Map, rdsHelper: RdsUtils) { + constructor(servicesContainer: FullServicesContainer, properties: Map, rdsHelper: RdsUtils) { super(); - this._properties = properties; - this.pluginService = pluginService; - this._rdsHelper = rdsHelper; + this.properties = properties; + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.pluginService; + this.rdsHelper = rdsHelper; this._staleDnsHelper = new StaleDnsHelper(this.pluginService); - this.enableFailoverSetting = WrapperProperties.ENABLE_CLUSTER_AWARE_FAILOVER.get(this._properties); - this.failoverTimeoutSettingMs = WrapperProperties.FAILOVER_TIMEOUT_MS.get(this._properties); - this.failoverReaderHostSelectorStrategy = WrapperProperties.FAILOVER_READER_HOST_SELECTOR_STRATEGY.get(this._properties); + this.enableFailoverSetting = WrapperProperties.ENABLE_CLUSTER_AWARE_FAILOVER.get(this.properties); + this.failoverTimeoutSettingMs = WrapperProperties.FAILOVER_TIMEOUT_MS.get(this.properties); + this.failoverReaderHostSelectorStrategy = WrapperProperties.FAILOVER_READER_HOST_SELECTOR_STRATEGY.get(this.properties); const telemetryFactory = this.pluginService.getTelemetryFactory(); this.failoverWriterTriggeredCounter = telemetryFactory.createCounter("writerFailover.triggered.count"); @@ -104,21 +105,13 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele } initHostProviderFunc(); - - this.failoverMode = failoverModeFromValue(WrapperProperties.FAILOVER_MODE.get(props)); - this._rdsUrlType = this._rdsHelper.identifyRdsType(hostInfo.host); - - if (this.failoverMode === FailoverMode.UNKNOWN) { - this.failoverMode = this._rdsUrlType === RdsUrlType.RDS_READER_CLUSTER ? FailoverMode.READER_OR_WRITER : FailoverMode.STRICT_WRITER; - } - - logger.debug(Messages.get("Failover.parameterValue", "failoverMode", FailoverMode[this.failoverMode])); } - private isFailoverEnabled(): boolean { + protected isFailoverEnabled(): boolean { return ( this.enableFailoverSetting && - this._rdsUrlType !== RdsUrlType.RDS_PROXY && + this.rdsUrlType !== RdsUrlType.RDS_PROXY && + this.rdsUrlType !== RdsUrlType.RDS_PROXY_ENDPOINT && this.pluginService.getAllHosts() && this.pluginService.getAllHosts().length > 0 ); @@ -130,6 +123,8 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele isInitialConnection: boolean, connectFunc: () => Promise ): Promise { + this.initFailoverMode(); + if ( // Failover is not enabled, does not require additional processing. !this.enableFailoverSetting || @@ -188,7 +183,7 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele } if (isInitialConnection) { - await this.pluginService.refreshHostList(client); + await this.pluginService.refreshHostList(); } return client; @@ -231,6 +226,10 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele await this.failoverReader(); } + this.throwFailoverSuccessException(); + } + + protected throwFailoverSuccessException(): void { if (this._isInTransaction || this.pluginService.isInTransaction()) { // "Transaction resolution unknown. Please re-configure session state if required and try // restarting transaction." @@ -243,7 +242,7 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele } } - async failoverReader() { + private async failoverReader() { const telemetryFactory = this.pluginService.getTelemetryFactory(); const telemetryContext = telemetryFactory.openTelemetryContext(Failover2Plugin.TELEMETRY_READER_FAILOVER, TelemetryTraceLevel.NESTED); this.failoverReaderTriggeredCounter.inc(); @@ -367,7 +366,7 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele throw new InternalQueryTimeoutError(Messages.get("Failover.timeoutError")); } - async failoverWriter() { + private async failoverWriter() { const telemetryFactory = this.pluginService.getTelemetryFactory(); const telemetryContext = telemetryFactory.openTelemetryContext(Failover2Plugin.TELEMETRY_WRITER_FAILOVER, TelemetryTraceLevel.NESTED); this.failoverWriterTriggeredCounter.inc(); @@ -430,7 +429,7 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele } private async createConnectionForHost(hostInfo: HostInfo): Promise { - const copyProps = new Map(this._properties); + const copyProps = new Map(this.properties); copyProps.set(WrapperProperties.HOST.name, hostInfo.host); return await this.pluginService.connect(hostInfo, copyProps, this); } @@ -464,6 +463,22 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele return methodName === Failover2Plugin.METHOD_END; } + protected initFailoverMode(): void { + if (this.rdsUrlType) { + return; + } + + this.failoverMode = failoverModeFromValue(WrapperProperties.FAILOVER_MODE.get(this.properties)); + const initialHostInfo: HostInfo | undefined | null = this.hostListProviderService?.getInitialConnectionHostInfo(); + this.rdsUrlType = this.rdsHelper.identifyRdsType(initialHostInfo?.host); + + if (this.failoverMode === FailoverMode.UNKNOWN) { + this.failoverMode = this.rdsUrlType === RdsUrlType.RDS_READER_CLUSTER ? FailoverMode.READER_OR_WRITER : FailoverMode.STRICT_WRITER; + } + + logger.debug(Messages.get("Failover.parameterValue", "failoverMode", String(this.failoverMode))); + } + private shouldErrorTriggerClientSwitch(error: any): boolean { if (!this.isFailoverEnabled()) { logger.debug(Messages.get("Failover.failoverDisabled")); @@ -486,11 +501,4 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele this.failoverWriterFailedCounter.inc(); throw new FailoverFailedError(errorMessage); } - - async releaseResources(): Promise { - const hostListProvider: HostListProvider = this.pluginService.getHostListProvider(); - if (this.hostListProviderService.isBlockingHostListProvider(hostListProvider)) { - await (hostListProvider as BlockingHostListProvider).clearAll(); - } - } } diff --git a/common/lib/plugins/failover2/failover2_plugin_factory.ts b/common/lib/plugins/failover2/failover2_plugin_factory.ts index 64b2bcdbf..d60687601 100644 --- a/common/lib/plugins/failover2/failover2_plugin_factory.ts +++ b/common/lib/plugins/failover2/failover2_plugin_factory.ts @@ -15,21 +15,21 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { RdsUtils } from "../../utils/rds_utils"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class Failover2PluginFactory extends ConnectionPluginFactory { private static failover2Plugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!Failover2PluginFactory.failover2Plugin) { Failover2PluginFactory.failover2Plugin = await import("./failover2_plugin"); } - return new Failover2PluginFactory.failover2Plugin.Failover2Plugin(pluginService, properties, new RdsUtils()); + return new Failover2PluginFactory.failover2Plugin.Failover2Plugin(servicesContainer, properties, new RdsUtils()); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "Failover2Plugin")); } diff --git a/common/lib/plugins/federated_auth/federated_auth_plugin_factory.ts b/common/lib/plugins/federated_auth/federated_auth_plugin_factory.ts index c77895020..ac391d642 100644 --- a/common/lib/plugins/federated_auth/federated_auth_plugin_factory.ts +++ b/common/lib/plugins/federated_auth/federated_auth_plugin_factory.ts @@ -15,16 +15,16 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class FederatedAuthPluginFactory extends ConnectionPluginFactory { private static federatedAuthPlugin: any; private static adfsCredentialsProvider: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!FederatedAuthPluginFactory.federatedAuthPlugin) { FederatedAuthPluginFactory.federatedAuthPlugin = await import("./federated_auth_plugin"); @@ -34,6 +34,7 @@ export class FederatedAuthPluginFactory extends ConnectionPluginFactory { FederatedAuthPluginFactory.adfsCredentialsProvider = await import("./adfs_credentials_provider_factory"); } + const pluginService = servicesContainer.pluginService; const adfsCredentialsProviderFactory = new FederatedAuthPluginFactory.adfsCredentialsProvider.AdfsCredentialsProviderFactory(pluginService); return new FederatedAuthPluginFactory.federatedAuthPlugin.FederatedAuthPlugin(pluginService, adfsCredentialsProviderFactory); } catch (error: any) { diff --git a/common/lib/plugins/federated_auth/okta_auth_plugin_factory.ts b/common/lib/plugins/federated_auth/okta_auth_plugin_factory.ts index c4b80147b..b3fc53329 100644 --- a/common/lib/plugins/federated_auth/okta_auth_plugin_factory.ts +++ b/common/lib/plugins/federated_auth/okta_auth_plugin_factory.ts @@ -15,16 +15,16 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class OktaAuthPluginFactory extends ConnectionPluginFactory { private static oktaAuthPlugin: any; private static oktaCredentialsProviderFactory: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!OktaAuthPluginFactory.oktaAuthPlugin) { OktaAuthPluginFactory.oktaAuthPlugin = await import("./okta_auth_plugin"); @@ -33,6 +33,7 @@ export class OktaAuthPluginFactory extends ConnectionPluginFactory { OktaAuthPluginFactory.oktaCredentialsProviderFactory = await import("./okta_credentials_provider_factory"); } + const pluginService = servicesContainer.pluginService; const oktaCredentialsProviderFactory = new OktaAuthPluginFactory.oktaCredentialsProviderFactory.OktaCredentialsProviderFactory(pluginService); return new OktaAuthPluginFactory.oktaAuthPlugin.OktaAuthPlugin(pluginService, oktaCredentialsProviderFactory); } catch (error: any) { diff --git a/common/lib/plugins/gdb_failover/global_db_failover_mode.ts b/common/lib/plugins/gdb_failover/global_db_failover_mode.ts new file mode 100644 index 000000000..975372eca --- /dev/null +++ b/common/lib/plugins/gdb_failover/global_db_failover_mode.ts @@ -0,0 +1,34 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export enum GlobalDbFailoverMode { + STRICT_WRITER = "strict-writer", + STRICT_HOME_READER = "strict-home-reader", + STRICT_OUT_OF_HOME_READER = "strict-out-of-home-reader", + STRICT_ANY_READER = "strict-any-reader", + HOME_READER_OR_WRITER = "home-reader-or-writer", + OUT_OF_HOME_READER_OR_WRITER = "out-of-home-reader-or-writer", + ANY_READER_OR_WRITER = "any-reader-or-writer", + UNKNOWN = "unknown" +} + +export function globalDbFailoverModeFromValue(value: string | null | undefined): GlobalDbFailoverMode { + if (!value) { + return GlobalDbFailoverMode.UNKNOWN; + } + const normalized = value.toLowerCase(); + return Object.values(GlobalDbFailoverMode).find((v) => v === normalized) ?? GlobalDbFailoverMode.UNKNOWN; +} diff --git a/common/lib/plugins/gdb_failover/global_db_failover_plugin.ts b/common/lib/plugins/gdb_failover/global_db_failover_plugin.ts new file mode 100644 index 000000000..dd366c1c5 --- /dev/null +++ b/common/lib/plugins/gdb_failover/global_db_failover_plugin.ts @@ -0,0 +1,367 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { RdsUtils } from "../../utils/rds_utils"; +import { GlobalDbFailoverMode, globalDbFailoverModeFromValue } from "./global_db_failover_mode"; +import { HostInfo } from "../../host_info"; +import { WrapperProperties } from "../../wrapper_property"; +import { RdsUrlType } from "../../utils/rds_url_type"; +import { logger } from "../../../logutils"; +import { Messages } from "../../utils/messages"; +import { AwsTimeoutError, AwsWrapperError, FailoverFailedError, FailoverSuccessError, UnsupportedMethodError } from "../../utils/errors"; +import { ClientWrapper } from "../../client_wrapper"; +import { HostAvailability } from "../../host_availability/host_availability"; +import { TelemetryTraceLevel } from "../../utils/telemetry/telemetry_trace_level"; +import { HostRole } from "../../host_role"; +import { ReaderFailoverResult } from "../failover/reader_failover_result"; +import { containsHostAndPort, convertNanosToMs, equalsIgnoreCase, getTimeInNanos, getWriter, logTopology, sleep } from "../../utils/utils"; +import { Failover2Plugin } from "../failover2/failover2_plugin"; +import { FullServicesContainer } from "../../utils/full_services_container"; + +export class GlobalDbFailoverPlugin extends Failover2Plugin { + private static readonly TELEMETRY_FAILOVER = "failover"; + + protected activeHomeFailoverMode: GlobalDbFailoverMode = GlobalDbFailoverMode.UNKNOWN; + protected inactiveHomeFailoverMode: GlobalDbFailoverMode = GlobalDbFailoverMode.UNKNOWN; + protected homeRegion: string | null = null; + + constructor(servicesContainer: FullServicesContainer, properties: Map, rdsHelper: RdsUtils) { + super(servicesContainer, properties, rdsHelper); + } + + protected initFailoverMode(): void { + if (this.rdsUrlType !== null) { + return; + } + + const initialHostInfo = this.hostListProviderService?.getInitialConnectionHostInfo(); + if (!initialHostInfo) { + throw new AwsWrapperError(Messages.get("GlobalDbFailoverPlugin.missingInitialHost")); + } + + this.rdsUrlType = this.rdsHelper.identifyRdsType(initialHostInfo.host); + + this.homeRegion = WrapperProperties.FAILOVER_HOME_REGION.get(this.properties) ?? null; + if (!this.homeRegion) { + if (!this.rdsUrlType.hasRegion) { + throw new AwsWrapperError(Messages.get("GlobalDbFailoverPlugin.missingHomeRegion")); + } + this.homeRegion = this.rdsHelper.getRdsRegion(initialHostInfo.host); + if (!this.homeRegion) { + throw new AwsWrapperError(Messages.get("GlobalDbFailoverPlugin.missingHomeRegion")); + } + } + + logger.debug(Messages.get("Failover.parameterValue", "failoverHomeRegion", this.homeRegion)); + + const activeHomeMode = WrapperProperties.ACTIVE_HOME_FAILOVER_MODE.get(this.properties); + const inactiveHomeMode = WrapperProperties.INACTIVE_HOME_FAILOVER_MODE.get(this.properties); + + this.activeHomeFailoverMode = globalDbFailoverModeFromValue(activeHomeMode); + this.inactiveHomeFailoverMode = globalDbFailoverModeFromValue(inactiveHomeMode); + + if (this.activeHomeFailoverMode === GlobalDbFailoverMode.UNKNOWN) { + switch (this.rdsUrlType) { + case RdsUrlType.RDS_WRITER_CLUSTER: + case RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER: + this.activeHomeFailoverMode = GlobalDbFailoverMode.STRICT_WRITER; + break; + default: + this.activeHomeFailoverMode = GlobalDbFailoverMode.HOME_READER_OR_WRITER; + } + } + + if (this.inactiveHomeFailoverMode === GlobalDbFailoverMode.UNKNOWN) { + switch (this.rdsUrlType) { + case RdsUrlType.RDS_WRITER_CLUSTER: + case RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER: + this.inactiveHomeFailoverMode = GlobalDbFailoverMode.STRICT_WRITER; + break; + default: + this.inactiveHomeFailoverMode = GlobalDbFailoverMode.HOME_READER_OR_WRITER; + } + } + + logger.debug(Messages.get("Failover.parameterValue", "activeHomeFailoverMode", this.activeHomeFailoverMode)); + logger.debug(Messages.get("Failover.parameterValue", "inactiveHomeFailoverMode", this.inactiveHomeFailoverMode)); + } + + override async failover(): Promise { + const telemetryFactory = this.pluginService.getTelemetryFactory(); + const telemetryContext = telemetryFactory.openTelemetryContext(GlobalDbFailoverPlugin.TELEMETRY_FAILOVER, TelemetryTraceLevel.NESTED); + + const failoverStartTimeNs = getTimeInNanos(); + const failoverEndTimeNs = failoverStartTimeNs + BigInt(this.failoverTimeoutSettingMs) * BigInt(1_000_000); + + try { + await telemetryContext.start(async () => { + logger.info(Messages.get("GlobalDbFailoverPlugin.startFailover")); + + // Force refresh host list and wait for topology to stabilize + const refreshResult = await this.pluginService.forceMonitoringRefresh(true, this.failoverTimeoutSettingMs); + if (!refreshResult) { + this.failoverWriterTriggeredCounter.inc(); + this.failoverWriterFailedCounter.inc(); + logger.error(Messages.get("Failover.unableToRefreshHostList")); + throw new FailoverFailedError(Messages.get("Failover.unableToRefreshHostList")); + } + + const updatedHosts = this.pluginService.getAllHosts(); + const writerCandidate = getWriter(updatedHosts); + + if (!writerCandidate) { + this.failoverWriterTriggeredCounter.inc(); + this.failoverWriterFailedCounter.inc(); + const message = logTopology(updatedHosts, Messages.get("Failover.unableToDetermineWriter")); + logger.error(message); + throw new FailoverFailedError(message); + } + + // Check writer region to determine failover mode + const writerRegion = this.rdsHelper.getRdsRegion(writerCandidate.host); + const isHomeRegion = equalsIgnoreCase(this.homeRegion, writerRegion); + logger.debug(Messages.get("GlobalDbFailoverPlugin.isHomeRegion", String(isHomeRegion))); + + const currentFailoverMode = isHomeRegion ? this.activeHomeFailoverMode : this.inactiveHomeFailoverMode; + logger.debug(Messages.get("GlobalDbFailoverPlugin.currentFailoverMode", String(currentFailoverMode))); + + switch (currentFailoverMode) { + case GlobalDbFailoverMode.STRICT_WRITER: + await this.failoverToWriter(writerCandidate); + break; + case GlobalDbFailoverMode.STRICT_HOME_READER: + await this.failoverToAllowedHost( + () => this.pluginService.getHosts().filter((x) => x.role === HostRole.READER && this.isHostInHomeRegion(x)), + HostRole.READER, + failoverEndTimeNs + ); + break; + case GlobalDbFailoverMode.STRICT_OUT_OF_HOME_READER: + await this.failoverToAllowedHost( + () => this.pluginService.getHosts().filter((x) => x.role === HostRole.READER && !this.isHostInHomeRegion(x)), + HostRole.READER, + failoverEndTimeNs + ); + break; + case GlobalDbFailoverMode.STRICT_ANY_READER: + await this.failoverToAllowedHost( + () => this.pluginService.getHosts().filter((x) => x.role === HostRole.READER), + HostRole.READER, + failoverEndTimeNs + ); + break; + case GlobalDbFailoverMode.HOME_READER_OR_WRITER: + await this.failoverToAllowedHost( + () => + this.pluginService.getHosts().filter((x) => x.role === HostRole.WRITER || (x.role === HostRole.READER && this.isHostInHomeRegion(x))), + null, + failoverEndTimeNs + ); + break; + case GlobalDbFailoverMode.OUT_OF_HOME_READER_OR_WRITER: + await this.failoverToAllowedHost( + () => + this.pluginService + .getHosts() + .filter((x) => x.role === HostRole.WRITER || (x.role === HostRole.READER && !this.isHostInHomeRegion(x))), + null, + failoverEndTimeNs + ); + break; + case GlobalDbFailoverMode.ANY_READER_OR_WRITER: + await this.failoverToAllowedHost(() => [...this.pluginService.getHosts()], null, failoverEndTimeNs); + break; + case GlobalDbFailoverMode.UNKNOWN: + default: + throw new UnsupportedMethodError(`Unsupported failover mode: ${currentFailoverMode}`); + } + + logger.debug(Messages.get("Failover.establishedConnection", this.pluginService.getCurrentHostInfo()?.host ?? "unknown")); + this.throwFailoverSuccessException(); + }); + } finally { + logger.debug(Messages.get("GlobalDbFailoverPlugin.failoverElapsed", String(convertNanosToMs(getTimeInNanos() - failoverStartTimeNs)))); + + if (this.telemetryFailoverAdditionalTopTraceSetting && telemetryContext) { + await telemetryFactory.postCopy(telemetryContext, TelemetryTraceLevel.FORCE_TOP_LEVEL); + } + } + } + + private isHostInHomeRegion(host: HostInfo): boolean { + const hostRegion = this.rdsHelper.getRdsRegion(host.host); + return equalsIgnoreCase(hostRegion, this.homeRegion); + } + + protected async failoverToWriter(writerCandidate: HostInfo): Promise { + this.failoverWriterTriggeredCounter.inc(); + let writerCandidateConn: ClientWrapper | null = null; + + try { + const allowedHosts = this.pluginService.getHosts(); + if (!containsHostAndPort(allowedHosts, writerCandidate.hostAndPort)) { + this.failoverWriterFailedCounter.inc(); + const topologyString = logTopology(allowedHosts, ""); + logger.error(Messages.get("Failover.newWriterNotAllowed", writerCandidate.url, topologyString)); + throw new FailoverFailedError(Messages.get("Failover.newWriterNotAllowed", writerCandidate.url, topologyString)); + } + + try { + writerCandidateConn = await this.pluginService.connect(writerCandidate, this.properties, this); + } catch (error) { + this.failoverWriterFailedCounter.inc(); + logger.error(Messages.get("Failover.unableToConnectToWriterDueToError", writerCandidate.host, error.message)); + throw new FailoverFailedError(Messages.get("Failover.unableToConnectToWriterDueToError", writerCandidate.host, error.message)); + } + + const role = await this.pluginService.getHostRole(writerCandidateConn); + if (role !== HostRole.WRITER) { + await writerCandidateConn?.abort(); + writerCandidateConn = null; + this.failoverWriterFailedCounter.inc(); + logger.error(Messages.get("Failover.unexpectedReaderRole", writerCandidate.host)); + throw new FailoverFailedError(Messages.get("Failover.unexpectedReaderRole", writerCandidate.host)); + } + + await this.pluginService.setCurrentClient(writerCandidateConn, writerCandidate); + writerCandidateConn = null; // Prevent connection from being closed in finally block + + this.failoverWriterSuccessCounter.inc(); + } catch (ex) { + if (!(ex instanceof FailoverFailedError)) { + // Counter has already been incremented in Failover2Plugin before throwing the FailoverFailedError. + // So no need to increment again here. + this.failoverWriterFailedCounter.inc(); + } + throw ex; + } finally { + if (writerCandidateConn && this.pluginService.getCurrentClient().targetClient !== writerCandidateConn) { + await writerCandidateConn.abort(); + } + } + } + + protected async failoverToAllowedHost(getAllowedHosts: () => HostInfo[], verifyRole: HostRole | null, failoverEndTimeNs: bigint): Promise { + this.failoverReaderTriggeredCounter.inc(); + + let result: ReaderFailoverResult | null = null; + try { + try { + result = await this.getAllowedFailoverConnection(getAllowedHosts, verifyRole, failoverEndTimeNs); + await this.pluginService.setCurrentClient(result.client!, result.newHost!); + result = null; + } catch (e) { + if (e instanceof AwsTimeoutError) { + logger.error(Messages.get("Failover.unableToConnectToReader")); + throw new FailoverFailedError(Messages.get("Failover.unableToConnectToReader")); + } + throw e; + } + + logger.info(Messages.get("Failover.establishedConnection", this.pluginService.getCurrentHostInfo()?.host ?? "unknown")); + this.throwFailoverSuccessException(); + } catch (ex) { + if (ex instanceof FailoverSuccessError) { + this.failoverReaderSuccessCounter.inc(); + } else { + this.failoverReaderFailedCounter.inc(); + } + throw ex; + } finally { + if (result?.client && result?.client !== this.pluginService.getCurrentClient().targetClient) { + await result?.client.abort(); + } + } + } + + protected async getAllowedFailoverConnection( + getAllowedHosts: () => HostInfo[], + verifyRole: HostRole | null, + failoverEndTimeNs: bigint + ): Promise { + do { + await this.pluginService.refreshHostList(); + let updatedAllowedHosts = getAllowedHosts(); + + // Make a copy of hosts and set their availability + updatedAllowedHosts = updatedAllowedHosts.map((x) => + this.pluginService.getHostInfoBuilder().copyFrom(x).withAvailability(HostAvailability.AVAILABLE).build() + ); + + const remainingAllowedHosts = [...updatedAllowedHosts]; + + if (remainingAllowedHosts.length === 0) { + await sleep(100); + continue; + } + + while (remainingAllowedHosts.length > 0 && getTimeInNanos() < failoverEndTimeNs) { + let candidateHost: HostInfo | undefined; + try { + candidateHost = this.pluginService.getHostInfoByStrategy(verifyRole, this.failoverReaderHostSelectorStrategy, remainingAllowedHosts); + } catch { + // Strategy can't get a host according to requested conditions. + // Do nothing + } + + if (!candidateHost) { + logger.debug( + logTopology( + remainingAllowedHosts, + `${Messages.get("GlobalDbFailoverPlugin.unableToFindCandidateWithMatchingRole", String(verifyRole), this.failoverReaderHostSelectorStrategy)}` + ) + ); + await sleep(100); + break; + } + + let candidateConn: ClientWrapper | null = null; + try { + candidateConn = await this.pluginService.connect(candidateHost, this.properties, this); + // Since the roles in the host list might not be accurate, we execute a query to check the instance's role + const role = verifyRole === null ? null : await this.pluginService.getHostRole(candidateConn); + + if (verifyRole === null || verifyRole === role) { + const updatedHostSpec = this.pluginService + .getHostInfoBuilder() + .copyFrom(candidateHost) + .withRole(role ?? candidateHost.role) + .build(); + return new ReaderFailoverResult(candidateConn, updatedHostSpec, true); + } + + // The role is not as expected, so the connection is not valid + const index = remainingAllowedHosts.findIndex((h) => h.hostAndPort === candidateHost!.hostAndPort); + if (index !== -1) { + remainingAllowedHosts.splice(index, 1); + } + await candidateConn.abort(); + candidateConn = null; + } catch { + const index = remainingAllowedHosts.findIndex((h) => h.hostAndPort === candidateHost!.hostAndPort); + if (index !== -1) { + remainingAllowedHosts.splice(index, 1); + } + if (candidateConn) { + await candidateConn.abort(); + } + } + } + } while (getTimeInNanos() < failoverEndTimeNs); // All hosts failed. Keep trying until we hit the timeout. + + throw new AwsTimeoutError(Messages.get("Failover.failoverReaderTimeout")); + } +} diff --git a/common/lib/plugins/gdb_failover/global_db_failover_plugin_factory.ts b/common/lib/plugins/gdb_failover/global_db_failover_plugin_factory.ts new file mode 100644 index 000000000..259505096 --- /dev/null +++ b/common/lib/plugins/gdb_failover/global_db_failover_plugin_factory.ts @@ -0,0 +1,38 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { ConnectionPluginFactory } from "../../plugin_factory"; +import { PluginService } from "../../plugin_service"; +import { ConnectionPlugin } from "../../connection_plugin"; +import { RdsUtils } from "../../utils/rds_utils"; +import { AwsWrapperError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; + +export class GlobalDbFailoverPluginFactory extends ConnectionPluginFactory { + private static globalDbFailoverPlugin: any; + + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { + try { + if (!GlobalDbFailoverPluginFactory.globalDbFailoverPlugin) { + GlobalDbFailoverPluginFactory.globalDbFailoverPlugin = await import("./global_db_failover_plugin"); + } + return new GlobalDbFailoverPluginFactory.globalDbFailoverPlugin.GlobalDbFailoverPlugin(servicesContainer, properties, new RdsUtils()); + } catch (error: any) { + throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "GlobalDbFailoverPlugin")); + } + } +} diff --git a/common/lib/plugins/limitless/limitless_connection_plugin_factory.ts b/common/lib/plugins/limitless/limitless_connection_plugin_factory.ts index 330c02e63..b32677664 100644 --- a/common/lib/plugins/limitless/limitless_connection_plugin_factory.ts +++ b/common/lib/plugins/limitless/limitless_connection_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class LimitlessConnectionPluginFactory implements ConnectionPluginFactory { private static limitlessPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!LimitlessConnectionPluginFactory.limitlessPlugin) { LimitlessConnectionPluginFactory.limitlessPlugin = await import("./limitless_connection_plugin"); } - return new LimitlessConnectionPluginFactory.limitlessPlugin.LimitlessConnectionPlugin(pluginService, properties); + return new LimitlessConnectionPluginFactory.limitlessPlugin.LimitlessConnectionPlugin(servicesContainer.pluginService, properties); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "LimitlessConnectionPlugin")); } diff --git a/common/lib/plugins/read_write_splitting/abstract_read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting/abstract_read_write_splitting_plugin.ts index 654822114..14008e4aa 100644 --- a/common/lib/plugins/read_write_splitting/abstract_read_write_splitting_plugin.ts +++ b/common/lib/plugins/read_write_splitting/abstract_read_write_splitting_plugin.ts @@ -282,7 +282,6 @@ export abstract class AbstractReadWriteSplittingPlugin extends AbstractConnectio } async closeIdleClients() { - logger.debug(Messages.get("ReadWriteSplittingPlugin.closingInternalClients")); await this.closeReaderClientIfIdle(); await this.closeWriterClientIfIdle(); } diff --git a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts index 397d5dfbb..6ade242fa 100644 --- a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts +++ b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts @@ -61,7 +61,7 @@ export class ReadWriteSplittingPlugin extends AbstractReadWriteSplittingPlugin { } const result = await connectFunc(); - if (!isInitialConnection || this._hostListProviderService?.isStaticHostListProvider()) { + if (!isInitialConnection || !this._hostListProviderService?.isDynamicHostListProvider()) { return result; } const currentRole = this.pluginService.getCurrentHostInfo()?.role; diff --git a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin_factory.ts b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin_factory.ts index 62485db1e..ac7fd53bd 100644 --- a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin_factory.ts +++ b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class ReadWriteSplittingPluginFactory extends ConnectionPluginFactory { private static readWriteSplittingPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!ReadWriteSplittingPluginFactory.readWriteSplittingPlugin) { ReadWriteSplittingPluginFactory.readWriteSplittingPlugin = await import("./read_write_splitting_plugin"); } - return new ReadWriteSplittingPluginFactory.readWriteSplittingPlugin.ReadWriteSplittingPlugin(pluginService, properties); + return new ReadWriteSplittingPluginFactory.readWriteSplittingPlugin.ReadWriteSplittingPlugin(servicesContainer.pluginService, properties); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "readWriteSplittingPlugin")); } diff --git a/common/lib/plugins/stale_dns/stale_dns_helper.ts b/common/lib/plugins/stale_dns/stale_dns_helper.ts index 501fd5bfc..d1b715350 100644 --- a/common/lib/plugins/stale_dns/stale_dns_helper.ts +++ b/common/lib/plugins/stale_dns/stale_dns_helper.ts @@ -14,28 +14,26 @@ limitations under the License. */ -import { logger } from "../../../logutils"; +import { levels, logger } from "../../../logutils"; import { HostInfo } from "../../host_info"; import { HostListProviderService } from "../../host_list_provider_service"; import { HostRole } from "../../host_role"; import { PluginService } from "../../plugin_service"; import { Messages } from "../../utils/messages"; import { RdsUtils } from "../../utils/rds_utils"; -import { lookup, LookupAddress } from "dns"; -import { promisify } from "util"; -import { AwsWrapperError } from "../../utils/errors"; import { HostChangeOptions } from "../../host_change_options"; import { WrapperProperties } from "../../wrapper_property"; import { ClientWrapper } from "../../client_wrapper"; -import { getWriter, logTopology } from "../../utils/utils"; +import { containsHostAndPort, getWriter, logTopology } from "../../utils/utils"; import { TelemetryFactory } from "../../utils/telemetry/telemetry_factory"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { RdsUrlType } from "../../utils/rds_url_type"; +import { AwsWrapperError } from "../../utils/errors"; export class StaleDnsHelper { private readonly pluginService: PluginService; private readonly rdsUtils: RdsUtils = new RdsUtils(); private writerHostInfo: HostInfo | null = null; - private writerHostAddress: string = ""; private readonly telemetryFactory: TelemetryFactory; private readonly staleDNSDetectedCounter: TelemetryCounter; @@ -53,39 +51,44 @@ export class StaleDnsHelper { props: Map, connectFunc: () => Promise ): Promise { - if (!this.rdsUtils.isWriterClusterDns(host)) { - return connectFunc(); - } + const type: RdsUrlType = this.rdsUtils.identifyRdsType(host); - const currentTargetClient = await connectFunc(); - - let clusterInetAddress = ""; - try { - const lookupResult = await this.lookupResult(host); - clusterInetAddress = lookupResult.address; - } catch (error) { - // ignore + if (type !== RdsUrlType.RDS_WRITER_CLUSTER && type !== RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER) { + return connectFunc(); } - const hostInetAddress = clusterInetAddress; - logger.debug(Messages.get("StaleDnsHelper.clusterEndpointDns", hostInetAddress)); - - if (!clusterInetAddress) { - return currentTargetClient; + if (type === RdsUrlType.RDS_WRITER_CLUSTER) { + const writer = getWriter(this.pluginService.getAllHosts()); + if (writer != null && this.rdsUtils.isRdsInstance(writer.host)) { + if ( + isInitialConnection && + WrapperProperties.SKIP_INACTIVE_WRITER_CLUSTER_CHECK.get(props) && + !this.rdsUtils.isSameRegion(writer.host, host) + ) { + // The cluster writer endpoint belongs to a different region than the current writer region. + // It means that the cluster is Aurora Global Database and cluster writer endpoint is in secondary region. + // In this case the cluster writer endpoint is in inactive state and doesn't represent the current writer + // so any connection check should be skipped. + // Continue with a normal workflow. + return connectFunc(); + } + } else { + // No writer is available. It could be the case with the first connection when topology isn't yet available. + // Continue with a normal workflow. + return connectFunc(); + } } - const currentHostInfo = this.pluginService.getCurrentHostInfo(); - if (!currentHostInfo) { - throw new AwsWrapperError("Stale DNS Helper: Current hostInfo was null."); - } + const currentTargetClient = await connectFunc(); - if (currentHostInfo && currentHostInfo.role === HostRole.READER) { + const isConnectedToReader: boolean = (await this.pluginService.getHostRole(currentTargetClient)) === HostRole.READER; + if (isConnectedToReader) { // This is if-statement is only reached if the connection url is a writer cluster endpoint. // If the new connection resolves to a reader instance, this means the topology is outdated. // Force refresh to update the topology. - await this.pluginService.forceRefreshHostList(currentTargetClient); + await this.pluginService.forceRefreshHostList(); } else { - await this.pluginService.refreshHostList(currentTargetClient); + await this.pluginService.refreshHostList(); } logger.debug(logTopology(this.pluginService.getAllHosts(), "[StaleDnsHelper.getVerifiedConnection] ")); @@ -104,27 +107,18 @@ export class StaleDnsHelper { return currentTargetClient; } - if (!this.writerHostAddress) { - try { - const lookupResult = await this.lookupResult(this.writerHostInfo.host); - this.writerHostAddress = lookupResult.address; - } catch (error) { - // ignore - } - } - - logger.debug(Messages.get("StaleDnsHelper.writerInetAddress", this.writerHostAddress)); - - if (!this.writerHostAddress) { - return currentTargetClient; - } + if (isConnectedToReader) { + // Reconnect to writer host if current connection is reader. - if (this.writerHostAddress !== clusterInetAddress) { - // DNS resolves a cluster endpoint to a wrong writer - // opens a connection to a proper writer host logger.debug(Messages.get("StaleDnsHelper.staleDnsDetected", this.writerHostInfo.host)); this.staleDNSDetectedCounter.inc(); + const allowedHosts: HostInfo[] = this.pluginService.getHosts(); + + if (!containsHostAndPort(allowedHosts, this.writerHostInfo.hostAndPort)) { + throw new AwsWrapperError(Messages.get("StaleDnsHelper.currentWriterNotAllowed", this.writerHostInfo.host, logTopology(allowedHosts, ""))); + } + let targetClient = null; try { const newProps = new Map(props); @@ -149,7 +143,7 @@ export class StaleDnsHelper { } for (const [key, values] of changes.entries()) { - if (logger.level === "debug") { + if (levels[logger.level] <= levels.debug) { const valStr = Array.from(values) .map((x) => HostChangeOptions[x]) .join(", "); @@ -159,14 +153,9 @@ export class StaleDnsHelper { if (key === this.writerHostInfo.url && values.has(HostChangeOptions.PROMOTED_TO_READER)) { logger.debug(Messages.get("StaleDnsHelper.reset")); this.writerHostInfo = null; - this.writerHostAddress = ""; } } } return Promise.resolve(); } - - lookupResult(host: string): Promise { - return promisify(lookup)(host, {}); - } } diff --git a/common/lib/plugins/stale_dns/stale_dns_plugin_factory.ts b/common/lib/plugins/stale_dns/stale_dns_plugin_factory.ts index 8669e59d4..2c6d67403 100644 --- a/common/lib/plugins/stale_dns/stale_dns_plugin_factory.ts +++ b/common/lib/plugins/stale_dns/stale_dns_plugin_factory.ts @@ -15,20 +15,20 @@ */ import { ConnectionPluginFactory } from "../../plugin_factory"; -import { PluginService } from "../../plugin_service"; import { ConnectionPlugin } from "../../connection_plugin"; import { AwsWrapperError } from "../../utils/errors"; import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class StaleDnsPluginFactory extends ConnectionPluginFactory { private static staleDnsPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!StaleDnsPluginFactory.staleDnsPlugin) { StaleDnsPluginFactory.staleDnsPlugin = await import("./stale_dns_plugin"); } - return new StaleDnsPluginFactory.staleDnsPlugin.StaleDnsPlugin(pluginService, properties); + return new StaleDnsPluginFactory.staleDnsPlugin.StaleDnsPlugin(servicesContainer.pluginService, properties); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "StaleDnsPlugin")); } diff --git a/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts b/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts index 05146d74b..9518fdfd9 100644 --- a/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts +++ b/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts @@ -15,20 +15,23 @@ */ import { ConnectionPluginFactory } from "../../../plugin_factory"; -import { PluginService } from "../../../plugin_service"; import { ConnectionPlugin } from "../../../connection_plugin"; import { AwsWrapperError } from "../../../utils/errors"; import { Messages } from "../../../utils/messages"; +import { FullServicesContainer } from "../../../utils/full_services_container"; export class FastestResponseStrategyPluginFactory extends ConnectionPluginFactory { private static fastestResponseStrategyPlugin: any; - async getInstance(pluginService: PluginService, properties: Map): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { try { if (!FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin) { FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin = await import("./fastest_response_strategy_plugin"); } - return new FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin.FastestResponseStrategyPlugin(pluginService, properties); + return new FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin.FastestResponseStrategyPlugin( + servicesContainer.pluginService, + properties + ); } catch (error: any) { throw new AwsWrapperError( Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "FastestResponseStrategyPluginFactory") diff --git a/common/lib/random_host_selector.ts b/common/lib/random_host_selector.ts index d38d985c2..5d35f0ebf 100644 --- a/common/lib/random_host_selector.ts +++ b/common/lib/random_host_selector.ts @@ -25,7 +25,9 @@ export class RandomHostSelector implements HostSelector { public static STRATEGY_NAME = "random"; getHost(hosts: HostInfo[], role: HostRole, props?: Map): HostInfo { - const eligibleHosts = hosts.filter((hostInfo: HostInfo) => hostInfo.role === role && hostInfo.getAvailability() === HostAvailability.AVAILABLE); + const eligibleHosts = hosts.filter( + (hostInfo: HostInfo) => (role === null || hostInfo.role === role) && hostInfo.getAvailability() === HostAvailability.AVAILABLE + ); if (eligibleHosts.length === 0) { throw new AwsWrapperError(Messages.get("HostSelector.noHostsMatchingRole", role)); } diff --git a/common/lib/round_robin_host_selector.ts b/common/lib/round_robin_host_selector.ts index bc53a251c..d6fcd6216 100644 --- a/common/lib/round_robin_host_selector.ts +++ b/common/lib/round_robin_host_selector.ts @@ -31,7 +31,7 @@ export class RoundRobinHostSelector implements HostSelector { getHost(hosts: HostInfo[], role: HostRole, props?: Map): HostInfo { const eligibleHosts: HostInfo[] = hosts - .filter((host: HostInfo) => host.role === role && host.availability === HostAvailability.AVAILABLE) + .filter((host: HostInfo) => (role === null || host.role === role) && host.availability === HostAvailability.AVAILABLE) .sort((hostA: HostInfo, hostB: HostInfo) => { const hostAHostName = hostA.host.toLowerCase(); const hostBHostName = hostB.host.toLowerCase(); diff --git a/common/lib/session_state_client.ts b/common/lib/session_state_client.ts index 0816fe4df..cd660158e 100644 --- a/common/lib/session_state_client.ts +++ b/common/lib/session_state_client.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,11 +19,11 @@ import { TransactionIsolationLevel } from "./utils/transaction_isolation_level"; export interface SessionStateClient { setReadOnly(readOnly: boolean): Promise; - isReadOnly(): boolean; + isReadOnly(): boolean | undefined; setAutoCommit(autoCommit: boolean): Promise; - getAutoCommit(): boolean; + getAutoCommit(): boolean | undefined; setTransactionIsolation(level: TransactionIsolationLevel): Promise; @@ -31,9 +31,9 @@ export interface SessionStateClient { setSchema(schema: any): Promise; - getSchema(): string; + getSchema(): string | undefined; setCatalog(catalog: string): Promise; - getCatalog(): string; + getCatalog(): string | undefined; } diff --git a/common/lib/types.ts b/common/lib/types.ts index 3332b0c94..a89e6af4b 100644 --- a/common/lib/types.ts +++ b/common/lib/types.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ /** * Type representing a constructor for any class. diff --git a/common/lib/utils/cache_map.ts b/common/lib/utils/cache_map.ts index 19485421f..69650202d 100644 --- a/common/lib/utils/cache_map.ts +++ b/common/lib/utils/cache_map.ts @@ -17,7 +17,7 @@ import { getTimeInNanos } from "./utils"; export class CacheItem { - private readonly item: V; + readonly item: V; private _expirationTimeNs: bigint; constructor(item: V, expirationTime: bigint) { diff --git a/common/lib/utils/core_services_container.ts b/common/lib/utils/core_services_container.ts index 8ae9f27f3..c3bed54d5 100644 --- a/common/lib/utils/core_services_container.ts +++ b/common/lib/utils/core_services_container.ts @@ -1,20 +1,23 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { StorageService, StorageServiceImpl } from "./storage/storage_service"; +import { MonitorService, MonitorServiceImpl } from "./monitoring/monitor_service"; +import { EventPublisher } from "./events/event"; +import { BatchingEventPublisher } from "./events/batching_event_publisher"; /** * A singleton container object used to instantiate and access core universal services. This class should be used @@ -26,27 +29,25 @@ import { StorageService, StorageServiceImpl } from "./storage/storage_service"; export class CoreServicesContainer { private static readonly INSTANCE = new CoreServicesContainer(); - // private readonly monitorService: MonitorService; // TODO: implement monitor service - private readonly storageService: StorageService; + readonly monitorService: MonitorService; + readonly storageService: StorageService; + readonly eventPublisher: EventPublisher; private constructor() { - this.storageService = new StorageServiceImpl(); - // this.monitorService = new MonitorServiceImpl(); + this.eventPublisher = new BatchingEventPublisher(); + this.storageService = new StorageServiceImpl(this.eventPublisher); + this.monitorService = new MonitorServiceImpl(this.eventPublisher); } static getInstance(): CoreServicesContainer { return CoreServicesContainer.INSTANCE; } - getStorageService(): StorageService { - return this.storageService; - } - - // getMonitorService(): MonitorService { - // return this.monitorService; - // } - - static releaseResources(): void { - CoreServicesContainer.INSTANCE.storageService.releaseResources(); + static async releaseResources(): Promise { + await CoreServicesContainer.INSTANCE.storageService.releaseResources(); + await CoreServicesContainer.INSTANCE.monitorService.releaseResources(); + if (CoreServicesContainer.INSTANCE.eventPublisher instanceof BatchingEventPublisher) { + CoreServicesContainer.INSTANCE.eventPublisher.releaseResources(); + } } } diff --git a/common/lib/utils/errors.ts b/common/lib/utils/errors.ts index 4d854c0c5..247bbba0a 100644 --- a/common/lib/utils/errors.ts +++ b/common/lib/utils/errors.ts @@ -50,6 +50,8 @@ export class TransactionResolutionUnknownError extends FailoverError {} export class LoginError extends AwsWrapperError {} -export class InternalQueryTimeoutError extends AwsWrapperError {} +export class AwsTimeoutError extends AwsWrapperError {} + +export class InternalQueryTimeoutError extends AwsTimeoutError {} export class UnavailableHostError extends AwsWrapperError {} diff --git a/common/lib/utils/events/batching_event_publisher.ts b/common/lib/utils/events/batching_event_publisher.ts new file mode 100644 index 000000000..750afdeef --- /dev/null +++ b/common/lib/utils/events/batching_event_publisher.ts @@ -0,0 +1,99 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Event, EventClass, EventPublisher, EventSubscriber } from "./event"; +import { Messages } from "../messages"; +import { logger } from "../../../logutils"; + +const DEFAULT_MESSAGE_INTERVAL_MS = 30_000; // 30 seconds + +/** + * An event publisher that periodically publishes a batch of all unique events + * encountered during the latest time interval. + */ +export class BatchingEventPublisher implements EventPublisher { + protected readonly subscribersMap = new Map>(); + protected readonly pendingEvents = new Set(); + protected publishingInterval?: ReturnType; + + constructor(messageIntervalMs: number = DEFAULT_MESSAGE_INTERVAL_MS) { + this.initPublishingInterval(messageIntervalMs); + } + + protected initPublishingInterval(messageIntervalMs: number): void { + this.publishingInterval = setInterval(() => this.sendMessages(), messageIntervalMs); + // Unref the timer to prevent this background task from blocking the application from gracefully exiting. + this.publishingInterval.unref(); + } + + protected async sendMessages(): Promise { + for (const event of this.pendingEvents) { + this.pendingEvents.delete(event); + await this.deliverEvent(event); + } + } + + protected async deliverEvent(event: Event): Promise { + const subscribers = this.subscribersMap.get(event.constructor as EventClass); + if (!subscribers) { + return; + } + + for (const subscriber of subscribers) { + await subscriber.processEvent(event); + } + } + + subscribe(subscriber: EventSubscriber, eventClasses: Set): void { + for (const eventClass of eventClasses) { + let subscribers = this.subscribersMap.get(eventClass); + if (!subscribers) { + subscribers = new Set(); + this.subscribersMap.set(eventClass, subscribers); + } + subscribers.add(subscriber); + } + } + + unsubscribe(subscriber: EventSubscriber, eventClasses: Set): void { + for (const eventClass of eventClasses) { + const subscribers = this.subscribersMap.get(eventClass); + if (subscribers) { + subscribers.delete(subscriber); + if (subscribers.size === 0) { + this.subscribersMap.delete(eventClass); + } + } + } + } + + publish(event: Event): void { + if (event.isImmediateDelivery) { + this.deliverEvent(event).catch((err) => { + logger.debug(Messages.get("BatchingEventPublisher.errorDeliveringImmediateEvent", err?.message ?? String(err))); + }); + } else { + this.pendingEvents.add(event); + } + } + + releaseResources(): void { + if (this.publishingInterval) { + clearInterval(this.publishingInterval); + this.publishingInterval = undefined; + } + } +} diff --git a/common/lib/utils/events/data_access_event.ts b/common/lib/utils/events/data_access_event.ts new file mode 100644 index 000000000..bb58a168d --- /dev/null +++ b/common/lib/utils/events/data_access_event.ts @@ -0,0 +1,35 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Event } from "./event"; + +/** + * A class defining a data access event. The class specifies the class of the data + * that was accessed and the key for the data. + * + * Used by StorageService to notify MonitorService when data is accessed, + * allowing monitors to extend their expiration time. + */ +export class DataAccessEvent implements Event { + readonly isImmediateDelivery = false; + readonly dataClass: new (...args: any[]) => any; + readonly key: unknown; + + constructor(dataClass: new (...args: any[]) => any, key: unknown) { + this.dataClass = dataClass; + this.key = key; + } +} diff --git a/common/lib/utils/events/event.ts b/common/lib/utils/events/event.ts new file mode 100644 index 000000000..8a69c643f --- /dev/null +++ b/common/lib/utils/events/event.ts @@ -0,0 +1,64 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Constructor } from "../../types"; + +export type EventClass = Constructor; + +/** + * An interface for events that need to be communicated between different components. + */ +export interface Event { + readonly isImmediateDelivery: boolean; +} + +/** + * An event subscriber. Subscribers can subscribe to a publisher's events. + */ +export interface EventSubscriber { + /** + * Processes an event. This method will only be called on this subscriber + * if it has subscribed to the event class. + * @param event the event to process. + */ + processEvent(event: Event): Promise; +} + +/** + * An event publisher that publishes events to subscribers. + * Subscribers can specify which types of events they would like to receive. + */ +export interface EventPublisher { + /** + * Registers the given subscriber for the given event classes. + * @param subscriber the subscriber to be notified when the given event classes occur. + * @param eventClasses the classes of events that the subscriber should be notified of. + */ + subscribe(subscriber: EventSubscriber, eventClasses: Set): void; + + /** + * Unsubscribes the given subscriber from the given event classes. + * @param subscriber the subscriber to unsubscribe from the given event classes. + * @param eventClasses the classes of events that the subscriber wants to unsubscribe from. + */ + unsubscribe(subscriber: EventSubscriber, eventClasses: Set): void; + + /** + * Publishes an event. All subscribers to the given event class will be notified of the event. + * @param event the event to publish. + */ + publish(event: Event): void; +} diff --git a/common/lib/utils/events/monitor_reset_event.ts b/common/lib/utils/events/monitor_reset_event.ts new file mode 100644 index 000000000..2acf3bf9d --- /dev/null +++ b/common/lib/utils/events/monitor_reset_event.ts @@ -0,0 +1,32 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Event } from "./event"; + +/** + * Event indicating that a monitor should be reset with new endpoints. + * Used by ClusterTopologyMonitorImpl to reset monitoring when cluster topology changes. + */ +export class MonitorResetEvent implements Event { + readonly isImmediateDelivery = true; + readonly clusterId: string; + readonly endpoints: Set; + + constructor(clusterId: string, endpoints: Set) { + this.clusterId = clusterId; + this.endpoints = endpoints; + } +} diff --git a/common/lib/utils/events/monitor_stop_event.ts b/common/lib/utils/events/monitor_stop_event.ts new file mode 100644 index 000000000..029fa35ed --- /dev/null +++ b/common/lib/utils/events/monitor_stop_event.ts @@ -0,0 +1,33 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Event } from "./event"; +import { Monitor } from "../monitoring/monitor"; + +/** + * Event indicating that a monitor should be stopped. + * Used by MonitorService to stop and remove monitors. + */ +export class MonitorStopEvent implements Event { + readonly isImmediateDelivery = true; + readonly monitorClass: new (...args: any[]) => Monitor; + readonly key: unknown; + + constructor(monitorClass: new (...args: any[]) => Monitor, key: unknown) { + this.monitorClass = monitorClass; + this.key = key; + } +} diff --git a/common/lib/utils/full_services_container.ts b/common/lib/utils/full_services_container.ts new file mode 100644 index 000000000..59902f882 --- /dev/null +++ b/common/lib/utils/full_services_container.ts @@ -0,0 +1,67 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { PluginService } from "../plugin_service"; +import { HostListProviderService } from "../host_list_provider_service"; +import { PluginManager } from "../index"; +import { ConnectionProvider } from "../connection_provider"; +import { TelemetryFactory } from "./telemetry/telemetry_factory"; +import { StorageService } from "./storage/storage_service"; +import { MonitorService } from "./monitoring/monitor_service"; +import { EventPublisher } from "./events/event"; +import { ImportantEventService } from "./important_event_service"; + +/** + * Container for services used throughout the wrapper. + */ +export interface FullServicesContainer { + storageService: StorageService; + monitorService: MonitorService; + eventPublisher: EventPublisher; + readonly defaultConnectionProvider: ConnectionProvider; + telemetryFactory: TelemetryFactory; + pluginManager: PluginManager; + hostListProviderService: HostListProviderService; + pluginService: PluginService; + importantEventService: ImportantEventService; +} + +export class FullServicesContainerImpl implements FullServicesContainer { + storageService: StorageService; + monitorService: MonitorService; + eventPublisher: EventPublisher; + readonly defaultConnectionProvider: ConnectionProvider; + telemetryFactory: TelemetryFactory; + pluginManager!: PluginManager; + hostListProviderService!: HostListProviderService; + pluginService!: PluginService; + importantEventService: ImportantEventService; + + constructor( + storageService: StorageService, + monitorService: MonitorService, + eventPublisher: EventPublisher, + defaultConnProvider: ConnectionProvider, + telemetryFactory: TelemetryFactory + ) { + this.storageService = storageService; + this.monitorService = monitorService; + this.eventPublisher = eventPublisher; + this.defaultConnectionProvider = defaultConnProvider; + this.telemetryFactory = telemetryFactory; + this.importantEventService = new ImportantEventService(); + } +} diff --git a/common/lib/utils/gdb_region_utils.ts b/common/lib/utils/gdb_region_utils.ts index 6b1fa5a4e..815cc64ee 100644 --- a/common/lib/utils/gdb_region_utils.ts +++ b/common/lib/utils/gdb_region_utils.ts @@ -66,10 +66,8 @@ export class GDBRegionUtils extends RegionUtils { const response = await rdsClient.send(command); return this.extractWriterClusterArn(response.GlobalClusters); } catch (error) { - if (error instanceof Error) { - logger.debug(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN")); - throw new AwsWrapperError(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN")); - } + logger.debug(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN")); + throw new AwsWrapperError(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN")); } finally { rdsClient.destroy(); } diff --git a/common/lib/utils/important_event_service.ts b/common/lib/utils/important_event_service.ts new file mode 100644 index 000000000..5ae19fd18 --- /dev/null +++ b/common/lib/utils/important_event_service.ts @@ -0,0 +1,88 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export class ImportantEvent { + readonly timestamp: Date; + readonly description: string; + + constructor(timestamp: Date, description: string) { + this.timestamp = timestamp; + this.description = description; + } +} + +export class ImportantEventService { + private static readonly DEFAULT_EVENT_QUEUE_MS = 60000; + + private readonly events: ImportantEvent[] = []; + private readonly eventQueueMs: number; + private readonly isEnabled: boolean; + + constructor(isEnabled: boolean = true, eventQueueMs: number = ImportantEventService.DEFAULT_EVENT_QUEUE_MS) { + this.isEnabled = isEnabled; + this.eventQueueMs = eventQueueMs; + } + + clear(): void { + this.events.length = 0; + } + + registerEvent(descriptionSupplier: () => string): void { + if (!this.isEnabled) { + return; + } + + this.removeExpiredEvents(); + + this.events.push(new ImportantEvent(new Date(), descriptionSupplier())); + } + + getEvents(): ImportantEvent[] { + if (!this.isEnabled) { + return []; + } + + this.removeExpiredEvents(); + return [...this.events]; + } + + private removeExpiredEvents(): void { + if (!this.isEnabled || this.events.length === 0) { + return; + } + + const current = Date.now(); + const cutoffTime = current - this.eventQueueMs; + + while (this.events.length > 0 && this.events[0].timestamp.getTime() <= cutoffTime) { + this.events.shift(); + } + } +} + +export class DriverImportantEventService { + private static readonly INSTANCE = new ImportantEventService(true, 60000); + + private constructor() {} + + static getInstance(): ImportantEventService { + return DriverImportantEventService.INSTANCE; + } + + static clear(): void { + DriverImportantEventService.INSTANCE.clear(); + } +} diff --git a/common/lib/utils/messages.ts b/common/lib/utils/messages.ts index 1f46e6249..353c00ded 100644 --- a/common/lib/utils/messages.ts +++ b/common/lib/utils/messages.ts @@ -89,7 +89,11 @@ const MESSAGES: Record = { "Failover.unableToConnectToWriter": "Unable to establish SQL connection to the writer instance.", "Failover.unableToConnectToWriterDueToError": "Unable to establish SQL connection to the writer instance: %s due to error: %s.", "Failover.unableToConnectToReader": "Unable to establish SQL connection to the reader instance.", + "Failover.unableToRefreshHostList": "The request to discover the new topology timed out or was unsuccessful.", "Failover.unableToDetermineWriter": "Unable to determine the current writer instance.", + "Failover.unexpectedReaderRole": "The new writer was identified to be '%s', but querying the instance for its role returned a role of %s.", + "Failover.strictReaderUnknownHostRole": + "Unable to determine host role for '%s'. Since failover mode is set to STRICT_READER and the host may be a writer, it will not be selected for reader failover.", "Failover.detectedError": "[Failover] Detected an error while executing a command: %s", "Failover.failoverDisabled": "Cluster-aware failover is disabled.", "Failover.establishedConnection": "[Failover] Connected to %s", @@ -99,6 +103,7 @@ const MESSAGES: Record = { "Failover.noOperationsAfterConnectionClosed": "No operations allowed after client ended.", "Failover.transactionResolutionUnknownError": "Unknown transaction resolution error occurred during failover.", "Failover.connectionExplicitlyClosed": "Unable to failover on an explicitly closed connection.", + "Failover.failoverReaderTimeout": "The reader failover process was not able to establish a connection before timing out.", "Failover.timeoutError": "Internal failover task has timed out.", "Failover.newWriterNotAllowed": "The failover process identified the new writer but the host is not in the list of allowed hosts. New writer host: '%s'. Allowed hosts: '%s'.", @@ -109,12 +114,9 @@ const MESSAGES: Record = { "StaleDnsHelper.staleDnsDetected": "Stale DNS data detected. Opening a connection to '%s'.", "StaleDnsHelper.reset": "Reset stored writer host.", "StaleDnsPlugin.requireDynamicProvider": "Dynamic host list provider is required.", + "StaleDnsHelper.currentWriterNotAllowed": "The current writer is not in the list of allowed hosts. Current host: '%s'. Allowed hosts: %s", "Client.methodNotSupported": "Method '%s' not supported.", "Client.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%s'.", - "AuroraStaleDnsHelper.clusterEndpointDns": "Cluster endpoint resolves to '%s'.", - "AuroraStaleDnsHelper.writerHostSpec": "Writer host: '%s'.", - "AuroraStaleDnsHelper.writerInetAddress": "Writer host address: '%s'", - "AuroraStaleDnsHelper.staleDnsDetected": "Stale DNS data detected. Opening a connection to '%s'.", "ReadWriteSplittingPlugin.setReadOnlyOnClosedClient": "setReadOnly cannot be called on a closed client '%s'.", "ReadWriteSplittingPlugin.errorSwitchingToCachedReader": "An error occurred while trying to switch to a cached reader client: '%s'. Error message: '%s'. The driver will attempt to establish a new reader client.", @@ -138,7 +140,8 @@ const MESSAGES: Record = { "ReadWriteSplittingPlugin.failoverErrorWhileExecutingCommand": "Detected a failover error while executing a command: '%s'", "ReadWriteSplittingPlugin.noReadersAvailable": "The plugin was unable to establish a reader client to any reader instance.", "ReadWriteSplittingPlugin.successfullyConnectedToReader": "Successfully connected to a new reader host: '%s'", - "ReadWriteSplittingPlugin.previousReaderNotAllowed": "The previous reader connection cannot be used because it is no longer in the list of allowed hosts. Previous reader: %s. Allowed hosts: %s", + "ReadWriteSplittingPlugin.previousReaderNotAllowed": + "The previous reader connection cannot be used because it is no longer in the list of allowed hosts. Previous reader: %s. Allowed hosts: %s", "ReadWriteSplittingPlugin.failedToConnectToReader": "Failed to connect to reader host: '%s'", "ReadWriteSplittingPlugin.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%s' specified in plugin configuration parameter 'readerHostSelectorStrategy'. Please visit the Read/Write Splitting Plugin documentation for all supported strategies.", @@ -190,19 +193,31 @@ const MESSAGES: Record = { "MonitorImpl.stopMonitoringTaskNewContext": "Stop monitoring task for checking new contexts for '%s'", "MonitorService.startMonitoringNullMonitor": "Start monitoring called but could not find monitor for host: '%s'.", "MonitorService.emptyAliasSet": "Empty alias set passed for '%s'. Set should not be empty.", + "MonitorService.monitorClassMismatch": + "The monitor stored at '%s' did not have the expected type. The expected type was '%s', but the monitor '%s' had a type of '%s'.", + "MonitorService.monitorStuck": "Monitor '%s' has not been updated within the inactive timeout of %s milliseconds. The monitor will be stopped.", + "MonitorService.monitorTypeNotRegistered": + "The given monitor class '%s' is not registered. Please register the monitor class before running monitors of that class with the monitor service.", + "MonitorService.recreatingMonitor": "Recreating monitor: '%s'.", + "MonitorService.removedErrorMonitor": "Removed monitor in error state: '%s'.", + "MonitorService.removedExpiredMonitor": "Removed expired monitor: '%s'.", + "MonitorService.stopAndRemoveMissingMonitorType": + "The monitor service received a request to stop a monitor with type '%s' and key '%s', but the monitor service does not have any monitors registered under the given type. Please ensure monitors are registered under the correct type.", + "MonitorService.stopAndRemoveMonitorsMissingType": + "The monitor service received a request to stop all monitors with type '%s', but the monitor service does not have any monitors registered under the given type. Please ensure monitors are registered under the correct type.", + "MonitorService.cleanupTaskInterrupted": "Monitor service cleanup task interrupted.", "PluginService.hostListEmpty": "Current host list is empty.", "PluginService.releaseResources": "Releasing resources.", - "PluginService.hostsChangeListEmpty": "There are no changes in the hosts' availability.", "PluginService.failedToRetrieveHostPort": "Could not retrieve Host:Port for connection.", "PluginService.nonEmptyAliases": "fillAliases called when HostInfo already contains the following aliases: '%s'.", "PluginService.forceMonitoringRefreshTimeout": "A timeout error occurred after waiting '%s' ms for refreshed topology.", "PluginService.requiredBlockingHostListProvider": "The detected host list provider is not a BlockingHostListProvider. A BlockingHostListProvider is required to force refresh the host list. Detected host list provider: '%s'.", + "PluginService.requiredDynamicHostListProvider": + "The forceMonitoringRefresh method requires a DynamicHostListProvider. The current host list provider '%s' does not support this operation.", "PluginService.currentHostNotAllowed": "The current host is not in the list of allowed hosts. Current host: '%s'. Allowed hosts: '%s'.", "PluginService.currentHostNotDefined": "The current host is undefined.", - "MonitoringHostListProvider.requiresMonitor": - "The MonitoringRdsHostListProvider could not retrieve or initialize a ClusterTopologyMonitor for refreshing the topology.", - "MonitoringHostListProvider.errorForceRefresh": "The MonitoringRdsHostListProvider could not refresh the topology, caught error: '%s'", + "PartialPluginService.unexpectedMethodCall": "Unexpected method call: '%s'. This method is not supported by PartialPluginService.", "HostMonitoringConnectionPlugin.activatedMonitoring": "Executing method '%s', monitoring is activated.", "HostMonitoringConnectionPlugin.unableToIdentifyConnection": "Unable to identify the given connection: '%s', please ensure the correct host list provider is specified. The host list provider in use is: '%s'.", @@ -295,14 +310,20 @@ const MESSAGES: Record = { "An error occurred while attempting to obtain the writer id because the query was invalid. Please ensure you are connecting to an Aurora or RDS DB cluster. Error: '%s'", "ClusterTopologyMonitor.unableToConnect": "Could not connect to initial host: '%s'.", "ClusterTopologyMonitor.openedMonitoringConnection": "Opened monitoring connection to: '%s'.", - "ClusterTopologyMonitor.startMonitoring": "Start cluster monitoring task.", + "ClusterTopologyMonitor.startMonitoring": "[clusterId: '%s'] Start cluster topology monitoring for '%s'.", + "ClusterTopologyMonitor.startingHostMonitoringTasks": "Starting host monitoring tasks.", + "ClusterTopologyMonitor.stopHostMonitoringTask": "Stop cluster topology monitoring task for '%s'.", "ClusterTopologyMonitor.errorDuringMonitoring": "Error thrown during cluster topology monitoring: '%s'.", "ClusterTopologyMonitor.endMonitoring": "Stop cluster topology monitoring.", + "ClusterTopologyMonitor.matchingReaderTopologies": "Reader topologies have been consistent for '%s' ms. Updating topology cache.", + "ClusterTopologyMonitor.reset": "[clusterId: '%s'] Resetting cluster topology monitor for '%s'.", + "ClusterTopologyMonitor.resetEventReceived": "MonitorResetEvent received.", "HostMonitor.startMonitoring": "Host monitor '%s' started.", - "HostMonitor.detectedWriter": "Detected writer: '%s' - '%s'.", - "HostMonitor.endMonitoring": "Host monitor '%s' completed in '%s'.", + "HostMonitor.detectedWriter": "Detected writer: '%s'.", + "HostMonitor.endMonitoring": "Host monitor '%s' completed in '%s' ms.", "HostMonitor.writerHostChanged": "Writer host has changed from '%s' to '%s'.", "HostMonitor.writerIsStale": "Connected writer instance '%s' is stale.", + "HostMonitor.loginErrorDuringMonitoring": "Login error detected during monitoring.", "SlidingExpirationCacheWithCleanupTask.cleaningUp": "Cleanup interval of '%s' minutes has passed, cleaning up sliding expiration cache '%s'.", "SlidingExpirationCacheWithCleanupTask.cleanUpTaskInterrupted": "Sliding expiration cache '%s' cleanup task has been interrupted and is exiting.", "SlidingExpirationCacheWithCleanupTask.cleanUpTaskStopped": "Sliding expiration cache '%s' cleanup task has been stopped and is exiting.", @@ -384,7 +405,24 @@ const MESSAGES: Record = { "TopologyUtils.instanceIdRequired": "InstanceId must not be en empty string.", "TopologyUtils.errorGettingHostRole": "An error occurred while trying to get the host role.", "GlobalTopologyUtils.missingRegion": "Host '%s' is missing region information in the topology query result.", - "GlobalTopologyUtils.missingTemplateForRegion": "No cluster instance template found for region '%s' when processing host '%s'." + "GlobalTopologyUtils.missingTemplateForRegion": "No cluster instance template found for region '%s' when processing host '%s'.", + "Utils.globalClusterInstanceHostPatternsRequired": "The 'globalClusterInstanceHostPatterns' property is required for Global Aurora Databases.", + "Utils.invalidPatternFormat": + "Invalid pattern format '%s'. Expected format: 'region:host-pattern' (e.g., 'us-east-1:?.cluster-xyz.us-east-1.rds.amazonaws.com').", + "GlobalAuroraTopologyMonitor.cannotFindRegionTemplate": "Cannot find cluster instance template for region '%s'.", + "GlobalAuroraTopologyMonitor.invalidTopologyUtils": "TopologyUtils must implement GdbTopologyUtils for GlobalAuroraTopologyMonitor.", + "GlobalDbFailoverPlugin.missingHomeRegion": + "The 'failoverHomeRegion' property is required when connecting to a Global Aurora Database without a region in the URL.", + "GlobalDbFailoverPlugin.missingInitialHost": "Unable to determine the initial connection host.", + "GlobalDbFailoverPlugin.startFailover": "Starting Global DB failover procedure.", + "GlobalDbFailoverPlugin.isHomeRegion": "Is home region: %s", + "GlobalDbFailoverPlugin.currentFailoverMode": "Current Global DB failover mode: %s", + "GlobalDbFailoverPlugin.failoverElapsed": "Global DB failover elapsed time: %s ms", + "GlobalDbFailoverPlugin.unableToFindCandidateWithMatchingRole": + "Unable to find a candidate host with the expected role (%s) based on the given host selection strategy: %s", + "GlobalDbFailoverPlugin.unableToConnect": "Unable to establish a connection during Global DB failover.", + "BatchingEventPublisher.errorDeliveringImmediateEvent": "Error delivering immediate event: %s", + "WrapperProperty.invalidValue": "Invalid value '%s' for property '%s'. Allowed values: %s" }; export class Messages { diff --git a/common/lib/utils/monitoring/monitor.ts b/common/lib/utils/monitoring/monitor.ts new file mode 100644 index 000000000..81dadea43 --- /dev/null +++ b/common/lib/utils/monitoring/monitor.ts @@ -0,0 +1,128 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { FullServicesContainer } from "../full_services_container"; + +const DEFAULT_CLEANUP_INTERVAL_NANOS = BigInt(60_000_000_000); // 1 minute + +export enum MonitorState { + RUNNING, + STOPPED, + ERROR +} + +export enum MonitorErrorResponse { + NO_ACTION, + RECREATE +} + +export class MonitorSettings { + expirationTimeoutNanos: bigint; + inactiveTimeoutNanos: bigint; + errorResponses: Set; + + constructor(expirationTimeoutNanos: bigint, inactiveTimeoutNanos: bigint, errorResponses: Set) { + this.expirationTimeoutNanos = expirationTimeoutNanos; + this.inactiveTimeoutNanos = inactiveTimeoutNanos; + this.errorResponses = errorResponses; + } +} + +export interface Monitor { + start(): Promise; + + monitor(): Promise; + + stop(): Promise; + + close(): Promise; + + getLastActivityTimestampNanos(): bigint; + + getState(): MonitorState; + + canDispose(): boolean; +} + +export interface MonitorInitializer { + createMonitor(servicesContainer: FullServicesContainer): Monitor; +} + +export abstract class AbstractMonitor implements Monitor { + protected _stop = false; + protected terminationTimeoutMs: number; + protected lastActivityTimestampNanos: bigint; + protected state: MonitorState; + protected monitorPromise?: Promise; + + protected constructor(terminationTimeoutSec: number) { + this.terminationTimeoutMs = terminationTimeoutSec * 1000; + this.lastActivityTimestampNanos = BigInt(Date.now() * 1_000_000); + this.state = MonitorState.STOPPED; + } + + async start(): Promise { + this.monitorPromise = this.run(); + } + + protected async run(): Promise { + try { + this.state = MonitorState.RUNNING; + this.lastActivityTimestampNanos = BigInt(Date.now() * 1_000_000); + await this.monitor(); + } catch (error) { + this.state = MonitorState.ERROR; + } finally { + await this.close(); + } + } + + abstract monitor(): Promise; + + async stop(): Promise { + this._stop = true; + + if (this.monitorPromise) { + let timeoutId: ReturnType | undefined; + const timeout = new Promise((resolve) => { + timeoutId = setTimeout(resolve, this.terminationTimeoutMs); + }); + await Promise.race([this.monitorPromise, timeout]); + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + } + + await this.close(); + this.state = MonitorState.STOPPED; + } + + async close(): Promise { + // Do nothing + } + + getLastActivityTimestampNanos(): bigint { + return this.lastActivityTimestampNanos; + } + + getState(): MonitorState { + return this.state; + } + + canDispose(): boolean { + return true; + } +} diff --git a/common/lib/utils/monitoring/monitor_service.ts b/common/lib/utils/monitoring/monitor_service.ts new file mode 100644 index 000000000..138b429b5 --- /dev/null +++ b/common/lib/utils/monitoring/monitor_service.ts @@ -0,0 +1,450 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Monitor, MonitorErrorResponse, MonitorInitializer, MonitorSettings, MonitorState } from "./monitor"; +import { Constructor } from "../../types"; +import { FullServicesContainer } from "../full_services_container"; +import { logger } from "../../../logutils"; +import { Messages } from "../messages"; +import { AwsWrapperError } from "../errors"; +import { ClusterTopologyMonitorImpl } from "../../host_list_provider/monitoring/cluster_topology_monitor"; +import { Topology } from "../../host_list_provider/topology"; +import { Event, EventPublisher, EventSubscriber } from "../events/event"; +import { DataAccessEvent } from "../events/data_access_event"; +import { MonitorStopEvent } from "../events/monitor_stop_event"; +import { convertNanosToMs, getTimeInNanos, sleepWithAbort } from "../utils"; +import { CacheItem } from "../cache_map"; + +const DEFAULT_CLEANUP_INTERVAL_NS = BigInt(60_000_000_000); // 1 minute +const FIFTEEN_MINUTES_NS = BigInt(15 * 60 * 1_000_000_000); +const THREE_MINUTES_NS = BigInt(3 * 60 * 1_000_000_000); + +export interface MonitorService { + registerMonitorTypeIfAbsent( + monitorClass: Constructor, + expirationTimeoutNanos: bigint, + inactiveTimeoutNanos: bigint, + errorResponses: Set, + producedDataClass?: Constructor + ): void; + + runIfAbsent( + monitorClass: Constructor, + key: unknown, + servicesContainer: FullServicesContainer, + originalProps: Map, + initializer: MonitorInitializer + ): Promise; + + get(monitorClass: Constructor, key: unknown): T | null; + + remove(monitorClass: Constructor, key: unknown): T | null; + + stopAndRemove(monitorClass: Constructor, key: unknown): Promise; + + stopAndRemoveMonitors(monitorClass: Constructor): Promise; + + stopAndRemoveAll(): Promise; + + releaseResources(): Promise; +} + +/** + * A container object that holds a monitor together with the supplier used to generate the monitor. + * The supplier can be used to recreate the monitor if it encounters an error or becomes stuck. + */ +class MonitorItem { + private readonly monitorSupplier: () => Monitor; + private readonly _monitor: Monitor; + + constructor(monitorSupplier: () => Monitor) { + this.monitorSupplier = monitorSupplier; + this._monitor = monitorSupplier(); + } + + getMonitorSupplier(): () => Monitor { + return this.monitorSupplier; + } + + getMonitor(): Monitor { + return this._monitor; + } +} + +/** + * A container that holds a cache of monitors of a given type with the related settings and info for that type. + */ +class CacheContainer { + private readonly settings: MonitorSettings; + private readonly cache: Map>; + private readonly producedDataClass: Constructor | null; + + constructor(settings: MonitorSettings, producedDataClass: Constructor | null) { + this.settings = settings; + this.producedDataClass = producedDataClass; + this.cache = new Map>(); + } + + getSettings(): MonitorSettings { + return this.settings; + } + + getCache(): Map> { + return this.cache; + } + + getProducedDataClass(): Constructor | null { + return this.producedDataClass; + } +} + +export class MonitorServiceImpl implements MonitorService, EventSubscriber { + private static defaultSuppliers: Map, () => CacheContainer> | null = null; + + // Lazy initialization for the default suppliers to avoid circular dependencies. + private static getDefaultSuppliers(): Map, () => CacheContainer> { + if (!MonitorServiceImpl.defaultSuppliers) { + const recreateOnError = new Set([MonitorErrorResponse.RECREATE]); + const defaultSettings = new MonitorSettings(FIFTEEN_MINUTES_NS, THREE_MINUTES_NS, recreateOnError); + + MonitorServiceImpl.defaultSuppliers = new Map([[ClusterTopologyMonitorImpl, () => new CacheContainer(defaultSettings, Topology)]]); + } + return MonitorServiceImpl.defaultSuppliers; + } + + protected readonly publisher: EventPublisher; + protected readonly monitorCaches = new Map, CacheContainer>(); + // Use a pending promise map to prevent race conditions when creating monitors. + private readonly pendingMonitors = new Map>(); + private cleanupTask: Promise | null = null; + private interruptCleanupTask: (() => void) | null = null; + private isInitialized: boolean = false; + + constructor(publisher: EventPublisher, cleanupIntervalNs: bigint = DEFAULT_CLEANUP_INTERVAL_NS) { + this.publisher = publisher; + this.publisher.subscribe(this, new Set([DataAccessEvent, MonitorStopEvent])); + this.initCleanupTask(cleanupIntervalNs); + } + + protected initCleanupTask(cleanupIntervalNs: bigint): void { + this.isInitialized = true; + this.cleanupTask = this.runCleanupLoop(cleanupIntervalNs); + } + + private async runCleanupLoop(cleanupIntervalNs: bigint): Promise { + while (this.isInitialized) { + const [sleepPromise, abortSleepFunc] = sleepWithAbort( + convertNanosToMs(cleanupIntervalNs), + Messages.get("MonitorService.cleanupTaskInterrupted") + ); + this.interruptCleanupTask = abortSleepFunc; + try { + await sleepPromise; + } catch { + // Sleep has been interrupted, exit cleanup task. + return; + } + + await this.checkMonitors(); + } + } + + protected async checkMonitors(): Promise { + for (const container of this.monitorCaches.values()) { + const cache = container.getCache(); + const keysToProcess = Array.from(cache.keys()); + + for (const key of keysToProcess) { + const cacheItem = cache.get(key); + if (!cacheItem) { + continue; + } + + const monitorItem = cacheItem.get(true); + if (!monitorItem) { + continue; + } + + const monitor = monitorItem.getMonitor(); + const monitorSettings = container.getSettings(); + + // Check for stopped monitors + if (monitor.getState() === MonitorState.STOPPED) { + cache.delete(key); + await monitor.stop(); + continue; + } + + // Check for error state monitors + if (monitor.getState() === MonitorState.ERROR) { + cache.delete(key); + logger.debug(Messages.get("MonitorService.removedErrorMonitor", JSON.stringify(monitor))); + await this.handleMonitorError(container, key, monitorItem); + continue; + } + + // Check for inactive/stuck monitors + const inactiveTimeoutNs = monitorSettings.inactiveTimeoutNanos; + if (getTimeInNanos() - monitor.getLastActivityTimestampNanos() > inactiveTimeoutNs) { + cache.delete(key); + logger.info(Messages.get("MonitorService.monitorStuck", JSON.stringify(monitor), convertNanosToMs(inactiveTimeoutNs).toString())); + await this.handleMonitorError(container, key, monitorItem); + continue; + } + + // Check for expired monitors that can be disposed + if (cacheItem.isExpired() && monitor.canDispose()) { + cache.delete(key); + logger.info(Messages.get("MonitorService.removedExpiredMonitor", JSON.stringify(monitor))); + await monitor.stop(); + } + } + } + } + + protected async handleMonitorError(cacheContainer: CacheContainer, key: unknown, errorMonitorItem: MonitorItem): Promise { + const monitor = errorMonitorItem.getMonitor(); + await monitor.stop(); + + const errorResponses = cacheContainer.getSettings().errorResponses; + if (errorResponses && errorResponses.has(MonitorErrorResponse.RECREATE)) { + if (!cacheContainer.getCache().has(key)) { + logger.info(Messages.get("MonitorService.recreatingMonitor", JSON.stringify(monitor))); + const newMonitorItem = new MonitorItem(errorMonitorItem.getMonitorSupplier()); + const expirationNs = cacheContainer.getSettings().expirationTimeoutNanos; + cacheContainer.getCache().set(key, new CacheItem(newMonitorItem, getTimeInNanos() + expirationNs)); + await newMonitorItem.getMonitor().start(); + } + } + } + + registerMonitorTypeIfAbsent( + monitorClass: Constructor, + expirationTimeoutNanos: bigint, + inactiveTimeoutNanos: bigint, + errorResponses: Set, + producedDataClass?: Constructor + ): void { + if (this.monitorCaches.has(monitorClass)) { + return; + } + + const settings = new MonitorSettings(expirationTimeoutNanos, inactiveTimeoutNanos, errorResponses); + const cacheContainer = new CacheContainer(settings, producedDataClass ?? null); + this.monitorCaches.set(monitorClass, cacheContainer); + } + + async runIfAbsent( + monitorClass: Constructor, + key: unknown, + servicesContainer: FullServicesContainer, + _originalProps: Map, + initializer: MonitorInitializer + ): Promise { + let cacheContainer = this.monitorCaches.get(monitorClass); + + if (!cacheContainer) { + const supplier = MonitorServiceImpl.getDefaultSuppliers().get(monitorClass as Constructor); + if (!supplier) { + throw new AwsWrapperError(Messages.get("MonitorService.monitorTypeNotRegistered", monitorClass.name)); + } + + cacheContainer = supplier(); + this.monitorCaches.set(monitorClass, cacheContainer); + } + + const cache = cacheContainer.getCache(); + const existingCacheItem = cache.get(key); + if (existingCacheItem) { + const existingMonitorItem = existingCacheItem.get(true); + if (existingMonitorItem) { + existingCacheItem.updateExpiration(cacheContainer.getSettings().expirationTimeoutNanos); + return existingMonitorItem.getMonitor() as T; + } + } + + const pendingKey = `${monitorClass.name}:${JSON.stringify(key)}`; + + // Check if the monitor is already being created by another async task. + const pendingPromise = this.pendingMonitors.get(pendingKey); + if (pendingPromise) { + return (await pendingPromise) as T; + } + + // Use the pending promise pattern to create monitors. This prevents race condition. + const createPromise = (async (): Promise => { + try { + const recheckCacheItem = cache.get(key); + if (recheckCacheItem) { + const recheckMonitorItem = recheckCacheItem.get(true); + if (recheckMonitorItem) { + recheckCacheItem.updateExpiration(cacheContainer.getSettings().expirationTimeoutNanos); + return recheckMonitorItem.getMonitor(); + } + } + + const monitorItem = new MonitorItem(() => initializer.createMonitor(servicesContainer)); + const expirationNs = cacheContainer.getSettings().expirationTimeoutNanos; + cache.set(key, new CacheItem(monitorItem, getTimeInNanos() + expirationNs)); + await monitorItem.getMonitor().start(); + + return monitorItem.getMonitor(); + } finally { + // Delete the key once monitor has been successfully created. + this.pendingMonitors.delete(pendingKey); + } + })(); + + this.pendingMonitors.set(pendingKey, createPromise); + return (await createPromise) as T; + } + + get(monitorClass: Constructor, key: unknown): T | null { + const cacheContainer = this.monitorCaches.get(monitorClass); + if (!cacheContainer) { + return null; + } + + const cacheItem = cacheContainer.getCache().get(key); + if (!cacheItem) { + return null; + } + + const monitorItem = cacheItem.get(true); + if (!monitorItem) { + return null; + } + + const monitor = monitorItem.getMonitor(); + if (monitor instanceof monitorClass) { + return monitor as T; + } + + logger.info(Messages.get("MonitorService.monitorClassMismatch", JSON.stringify(key), monitorClass.name, JSON.stringify(monitor))); + return null; + } + + remove(monitorClass: Constructor, key: unknown): T | null { + const cacheContainer = this.monitorCaches.get(monitorClass); + if (!cacheContainer) { + return null; + } + + const cache = cacheContainer.getCache(); + const cacheItem = cache.get(key); + if (!cacheItem) { + return null; + } + + const monitorItem = cacheItem.get(true); + if (!monitorItem) { + return null; + } + + const monitor = monitorItem.getMonitor(); + if (monitor instanceof monitorClass) { + cache.delete(key); + return monitor as T; + } + + return null; + } + + async stopAndRemove(monitorClass: Constructor, key: unknown): Promise { + const cacheContainer = this.monitorCaches.get(monitorClass); + if (!cacheContainer) { + logger.info(Messages.get("MonitorService.stopAndRemoveMissingMonitorType", monitorClass.name, String(key))); + return; + } + + const cache = cacheContainer.getCache(); + const cacheItem = cache.get(key); + if (cacheItem) { + cache.delete(key); + await cacheItem.get(true)?.getMonitor().stop(); + } + } + + async stopAndRemoveMonitors(monitorClass: Constructor): Promise { + const cacheContainer = this.monitorCaches.get(monitorClass); + if (!cacheContainer) { + logger.info(Messages.get("MonitorService.stopAndRemoveMonitorsMissingType", monitorClass.name)); + return; + } + + const cache = cacheContainer.getCache(); + for (const [key, cacheItem] of cache.entries()) { + cache.delete(key); + await cacheItem.get(true)?.getMonitor().stop(); + } + } + + async stopAndRemoveAll(): Promise { + for (const monitorClass of this.monitorCaches.keys()) { + await this.stopAndRemoveMonitors(monitorClass); + } + } + + async releaseResources(): Promise { + // Stop cleanup task + this.isInitialized = false; + this.interruptCleanupTask?.(); + if (this.cleanupTask) { + await this.cleanupTask; + } + + await this.stopAndRemoveAll(); + } + + async processEvent(event: Event): Promise { + if (event instanceof DataAccessEvent) { + for (const container of this.monitorCaches.values()) { + if (!container.getProducedDataClass() || event.dataClass !== container.getProducedDataClass()) { + continue; + } + + // The data produced by the monitor in this cache with this key has been accessed recently, + // so we extend the monitor's expiration. + container.getCache().get(event.key)?.updateExpiration(container.getSettings().expirationTimeoutNanos); + } + return; + } + + if (event instanceof MonitorStopEvent) { + await this.stopAndRemove(event.monitorClass, event.key); + return; + } + + // Other event types should be propagated to monitors + for (const container of this.monitorCaches.values()) { + for (const cacheItem of container.getCache().values()) { + const monitorItem = cacheItem.get(true); + if (!monitorItem) { + continue; + } + + const monitor = monitorItem.getMonitor(); + if (this.isEventSubscriber(monitor)) { + await (monitor as unknown as EventSubscriber).processEvent(event); + } + } + } + } + + private isEventSubscriber(obj: unknown): obj is EventSubscriber { + return typeof obj === "object" && obj !== null && "processEvent" in obj && typeof (obj as EventSubscriber).processEvent === "function"; + } +} diff --git a/common/lib/utils/rds_url_type.ts b/common/lib/utils/rds_url_type.ts index 46300354d..64955089a 100644 --- a/common/lib/utils/rds_url_type.ts +++ b/common/lib/utils/rds_url_type.ts @@ -20,6 +20,7 @@ export class RdsUrlType { public static readonly RDS_READER_CLUSTER = new RdsUrlType(true, true, true); public static readonly RDS_CUSTOM_CLUSTER = new RdsUrlType(true, false, true); public static readonly RDS_PROXY = new RdsUrlType(true, false, true); + public static readonly RDS_PROXY_ENDPOINT = new RdsUrlType(true, false, true); public static readonly RDS_INSTANCE = new RdsUrlType(true, false, true); public static readonly RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = new RdsUrlType(true, false, true); public static readonly RDS_GLOBAL_WRITER_CLUSTER = new RdsUrlType(true, true, false); diff --git a/common/lib/utils/rds_utils.ts b/common/lib/utils/rds_utils.ts index 665433490..95c894d71 100644 --- a/common/lib/utils/rds_utils.ts +++ b/common/lib/utils/rds_utils.ts @@ -22,12 +22,13 @@ export class RdsUtils { // can be found at // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Overview.Endpoints.html // - // Details how to use RDS Proxy endpoints can be found at + // Details how to use RDS Proxy endpoints can be found at // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/rds-proxy-endpoints.html // - // Values like "<...>" depend on particular Aurora cluster. + // Values like "<...>" depend on particular Aurora cluster. // For example: "" // + // // Cluster (Writer) Endpoint: .cluster-..rds.amazonaws.com // Example: test-postgres.cluster-123456789012.us-east-2.rds.amazonaws.com // @@ -41,7 +42,10 @@ export class RdsUtils { // Example: test-postgres-instance-1.123456789012.us-east-2.rds.amazonaws.com // // + // // Similar endpoints for China regions have different structure and are presented below. + // https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-Ningxia.html + // https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-Beijing.html // // Cluster (Writer) Endpoint: .cluster-.rds..amazonaws.com.cn // Example: test-postgres.cluster-123456789012.rds.cn-northwest-1.amazonaws.com.cn @@ -59,52 +63,51 @@ export class RdsUtils { // Governmental endpoints // https://aws.amazon.com/compliance/fips/#FIPS_Endpoints_by_Service // https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/services/s3/model/Region.html - + // + // + // Aurora Global Database // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Concepts.Aurora_Fea_Regions_DB-eng.Feature.GlobalDatabase.html + // Global Database Endpoint: .global-.global.rds.amazonaws.com + // Example: test-global-db-name.global-123456789012.global.rds.amazonaws.com + // + // + // RDS Proxy + // RDS Proxy Endpoint: .proxy-..rds.amazonaws.com + // Example: test-rds-proxy-name.proxy-123456789012.us-east-2.rds.amazonaws.com + // + // RDS Proxy Custom Endpoint: .endpoint.proxy-..rds.amazonaws.com + // Example: test-custom-endpoint-name.endpoint.proxy-123456789012.us-east-2.rds.amazonaws.com + private static readonly AURORA_GLOBAL_WRITER_DNS_PATTERN = /^(?.+)\.(?global-)?(?[a-zA-Z0-9]+\.global\.rds\.amazonaws\.com\.?)$/i; private static readonly AURORA_DNS_PATTERN = - /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com)$/i; - private static readonly AURORA_INSTANCE_PATTERN = /^(?.+)\.(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com)$/i; + /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.(rds|rds-fips)\.amazonaws\.(com|au|eu|uk)\.?)$/i; private static readonly AURORA_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com)$/i; - private static readonly AURORA_CUSTOM_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-custom-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com)$/i; + /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.(rds|rds-fips)\.amazonaws\.(com|au|eu|uk)\.?)$/i; private static readonly AURORA_LIMITLESS_CLUSTER_PATTERN = - /^(?.+)\.(?shardgrp-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.(amazonaws\.com(\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov))$/i; - private static readonly AURORA_PROXY_DNS_PATTERN = - /^(?.+)\.(?proxy-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com)$/i; + /^(?.+)\.(?shardgrp-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.(rds|rds-fips)\.(amazonaws\.com\.?|amazonaws\.eu\.?|amazonaws\.au\.?|amazonaws\.uk\.?|amazonaws\.com\.cn\.?|sc2s\.sgov\.gov\.?|c2s\.ic\.gov\.?))$/i; private static readonly AURORA_CHINA_DNS_PATTERN = - /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn)$/i; + /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(rds|rds-fips)\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn\.?)$/i; private static readonly AURORA_OLD_CHINA_DNS_PATTERN = - /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_CHINA_INSTANCE_PATTERN = - /^(?.+)\.(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_OLD_CHINA_INSTANCE_PATTERN = - /^(?.+)\.(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn)$/i; + /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.(rds|rds-fips)\.amazonaws\.com\.cn\.?)$/i; private static readonly AURORA_CHINA_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_CHINA_LIMITLESS_CLUSTER_PATTERN = - /^(?.+)\.(?shardgrp-)?(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn)$/i; + /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(rds|rds-fips)\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn\.?)$/i; private static readonly AURORA_OLD_CHINA_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_OLD_CHINA_LIMITLESS_CLUSTER_PATTERN = - /^(?.+)\.(?shardgrp-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_CHINA_CUSTOM_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-custom-)+(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_OLD_CHINA_CUSTOM_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-custom-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_CHINA_PROXY_DNS_PATTERN = - /^(?.+)\.(?proxy-)+(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-])+\.amazonaws\.com\.cn)$/i; - private static readonly AURORA_OLD_CHINA_PROXY_DNS_PATTERN = - /^(?.+)\.(?proxy-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-])+\.rds\.amazonaws\.com\.cn)$/i; - + /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.(rds|rds-fips)\.amazonaws\.com\.cn\.?)$/i; private static readonly AURORA_GOV_DNS_PATTERN = - /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$/i; + /^(?.+)\.(?proxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?(?[a-zA-Z0-9]+\.(rds|rds-fips)\.(?[a-zA-Z0-9-]+)\.(amazonaws\.com\.?|c2s\.ic\.gov\.?|sc2s\.sgov\.gov\.?))$/i; private static readonly AURORA_GOV_CLUSTER_PATTERN = - /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$/i; + /^(?.+)\.(?cluster-|cluster-ro-)+(?[a-zA-Z0-9]+\.(rds|rds-fips)\.(?[a-zA-Z0-9-]+)\.(amazonaws\.com\.?|c2s\.ic\.gov\.?|sc2s\.sgov\.gov\.?))$/i; + + // RDS Proxy Custom Endpoint: .endpoint.proxy-..rds.amazonaws.com + private static readonly RDS_PROXY_ENDPOINT_DNS_PATTERN = + /^(?.+)\.endpoint\.(?proxy-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.?)$/i; + private static readonly RDS_PROXY_ENDPOINT_CHINA_DNS_PATTERN = + /^(?.+)\.endpoint\.(?proxy-)+(?[a-zA-Z0-9]+\.rds\.(?[a-zA-Z0-9-]+)\.amazonaws\.com\.cn\.?)$/i; + private static readonly RDS_PROXY_ENDPOINT_OLD_CHINA_DNS_PATTERN = + /^(?.+)\.endpoint\.(?proxy-)?(?[a-zA-Z0-9]+\.(?[a-zA-Z0-9-]+)\.rds\.amazonaws\.com\.cn\.?)$/i; private static readonly ELB_PATTERN = /^(?.+)\.elb\.((?[a-zA-Z0-9-]+)\.amazonaws\.com)$/i; private static readonly IP_V4 = @@ -121,20 +124,24 @@ export class RdsUtils { private static readonly cachedPatterns = new Map(); private static readonly cachedDnsPatterns = new Map(); + private static prepareHostFunc?: (host: string) => string; public isRdsClusterDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return equalsIgnoreCase(dnsGroup, "cluster-") || equalsIgnoreCase(dnsGroup, "cluster-ro-"); } public isRdsCustomClusterDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return equalsIgnoreCase(dnsGroup, "cluster-custom-"); } public isRdsDns(host: string): boolean { + const preparedHost = RdsUtils.getPreparedHost(host); const matcher = this.cacheMatcher( - host, + preparedHost, RdsUtils.AURORA_DNS_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, @@ -143,24 +150,46 @@ export class RdsUtils { const group = this.getRegexGroup(matcher, RdsUtils.DNS_GROUP); if (group) { - RdsUtils.cachedDnsPatterns.set(host, group); + RdsUtils.cachedDnsPatterns.set(preparedHost, group); } return matcher != null; } public isRdsInstance(host: string): boolean { - return !this.getDnsGroup(host) && this.isRdsDns(host); + const preparedHost = RdsUtils.getPreparedHost(host); + return !this.getDnsGroup(preparedHost) && this.isRdsDns(preparedHost); } isRdsProxyDns(host: string) { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return dnsGroup && dnsGroup.startsWith("proxy-"); } + isRdsProxyEndpointDns(host: string): boolean { + if (!host) { + return false; + } + + const preparedHost = RdsUtils.getPreparedHost(host); + const matcher = this.cacheMatcher( + preparedHost, + RdsUtils.RDS_PROXY_ENDPOINT_DNS_PATTERN, + RdsUtils.RDS_PROXY_ENDPOINT_CHINA_DNS_PATTERN, + RdsUtils.RDS_PROXY_ENDPOINT_OLD_CHINA_DNS_PATTERN + ); + if (this.getRegexGroup(matcher, RdsUtils.DNS_GROUP) !== null) { + return this.getRegexGroup(matcher, RdsUtils.INSTANCE_GROUP) !== null; + } + + return false; + } + getRdsClusterId(host: string): string | null { + const preparedHost = RdsUtils.getPreparedHost(host); const matcher = this.cacheMatcher( - host, + preparedHost, RdsUtils.AURORA_DNS_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, @@ -179,8 +208,9 @@ export class RdsUtils { return null; } + const preparedHost = RdsUtils.getPreparedHost(host); const matcher = this.cacheMatcher( - host, + preparedHost, RdsUtils.AURORA_DNS_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, @@ -198,8 +228,9 @@ export class RdsUtils { return "?"; } + const preparedHost = RdsUtils.getPreparedHost(host); const matcher = this.cacheMatcher( - host, + preparedHost, RdsUtils.AURORA_DNS_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, @@ -214,8 +245,9 @@ export class RdsUtils { return null; } + const preparedHost = RdsUtils.getPreparedHost(host); const matcher = this.cacheMatcher( - host, + preparedHost, RdsUtils.AURORA_DNS_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, @@ -227,7 +259,7 @@ export class RdsUtils { return group; } - const elbMatcher = host.match(RdsUtils.ELB_PATTERN); + const elbMatcher = preparedHost.match(RdsUtils.ELB_PATTERN); if (elbMatcher && elbMatcher.length > 0) { return this.getRegexGroup(elbMatcher, RdsUtils.REGION_GROUP); } @@ -235,23 +267,36 @@ export class RdsUtils { return null; } + public isSameRegion(host1: string | null, host2: string | null): boolean { + if (!host1 || !host2) { + return false; + } + const host1Region = this.getRdsRegion(host1); + const host2Region = this.getRdsRegion(host2); + return host1Region !== null && equalsIgnoreCase(host1Region, host2Region); + } + public isGlobalDbWriterClusterDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return equalsIgnoreCase(dnsGroup, "global-"); } public isWriterClusterDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return equalsIgnoreCase(dnsGroup, "cluster-"); } public isReaderClusterDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); return equalsIgnoreCase(dnsGroup, "cluster-ro-"); } public isLimitlessDbShardGroupDns(host: string): boolean { - const dnsGroup = this.getDnsGroup(host); + const preparedHost = RdsUtils.getPreparedHost(host); + const dnsGroup = this.getDnsGroup(preparedHost); if (!dnsGroup) { return false; } @@ -263,25 +308,26 @@ export class RdsUtils { return null; } - const matcher = host.match(RdsUtils.AURORA_CLUSTER_PATTERN); + const preparedHost = RdsUtils.getPreparedHost(host); + const matcher = preparedHost.match(RdsUtils.AURORA_CLUSTER_PATTERN); if (matcher) { - return host.replace(RdsUtils.AURORA_CLUSTER_PATTERN, "$.cluster-$"); + return preparedHost.replace(RdsUtils.AURORA_CLUSTER_PATTERN, "$.cluster-$"); } - const limitlessMatcher = host.match(RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN); + const limitlessMatcher = preparedHost.match(RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN); if (limitlessMatcher) { - return host.replace(RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN, "$.cluster-$"); + return preparedHost.replace(RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN, "$.cluster-$"); } - const chinaMatcher = host.match(RdsUtils.AURORA_CHINA_CLUSTER_PATTERN); + const chinaMatcher = preparedHost.match(RdsUtils.AURORA_CHINA_CLUSTER_PATTERN); if (chinaMatcher) { - return host.replace(RdsUtils.AURORA_CHINA_CLUSTER_PATTERN, "$.cluster-$"); + return preparedHost.replace(RdsUtils.AURORA_CHINA_CLUSTER_PATTERN, "$.cluster-$"); } - const oldChinaMatcher = host.match(RdsUtils.AURORA_OLD_CHINA_CLUSTER_PATTERN); + const oldChinaMatcher = preparedHost.match(RdsUtils.AURORA_OLD_CHINA_CLUSTER_PATTERN); if (oldChinaMatcher) { - return host.replace(RdsUtils.AURORA_OLD_CHINA_CLUSTER_PATTERN, "$.cluster-$"); + return preparedHost.replace(RdsUtils.AURORA_OLD_CHINA_CLUSTER_PATTERN, "$.cluster-$"); } - const govMatcher = host.match(RdsUtils.AURORA_GOV_CLUSTER_PATTERN); + const govMatcher = preparedHost.match(RdsUtils.AURORA_GOV_CLUSTER_PATTERN); if (govMatcher) { - return host.replace(RdsUtils.AURORA_GOV_CLUSTER_PATTERN, "$.cluster-$"); + return preparedHost.replace(RdsUtils.AURORA_GOV_CLUSTER_PATTERN, "$.cluster-$"); } return null; } @@ -307,21 +353,24 @@ export class RdsUtils { return RdsUrlType.OTHER; } - if (this.isIPv4(host) || this.isIPv6(host)) { + const preparedHost = RdsUtils.getPreparedHost(host); + if (this.isIPv4(preparedHost) || this.isIPv6(preparedHost)) { return RdsUrlType.IP_ADDRESS; - } else if (this.isGlobalDbWriterClusterDns(host)) { + } else if (this.isGlobalDbWriterClusterDns(preparedHost)) { return RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER; - } else if (this.isWriterClusterDns(host)) { + } else if (this.isWriterClusterDns(preparedHost)) { return RdsUrlType.RDS_WRITER_CLUSTER; - } else if (this.isReaderClusterDns(host)) { + } else if (this.isReaderClusterDns(preparedHost)) { return RdsUrlType.RDS_READER_CLUSTER; - } else if (this.isRdsCustomClusterDns(host)) { + } else if (this.isRdsCustomClusterDns(preparedHost)) { return RdsUrlType.RDS_CUSTOM_CLUSTER; - } else if (this.isLimitlessDbShardGroupDns(host)) { + } else if (this.isLimitlessDbShardGroupDns(preparedHost)) { return RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP; - } else if (this.isRdsProxyDns(host)) { + } else if (this.isRdsProxyDns(preparedHost)) { return RdsUrlType.RDS_PROXY; - } else if (this.isRdsDns(host)) { + } else if (this.isRdsProxyEndpointDns(preparedHost)) { + return RdsUrlType.RDS_PROXY_ENDPOINT; + } else if (this.isRdsDns(preparedHost)) { return RdsUrlType.RDS_INSTANCE; } else { // ELB URLs will also be classified as other @@ -330,23 +379,27 @@ export class RdsUtils { } public isGreenInstance(host: string) { - return host && RdsUtils.BG_GREEN_HOST_PATTERN.test(host); + const preparedHost = RdsUtils.getPreparedHost(host); + return preparedHost && RdsUtils.BG_GREEN_HOST_PATTERN.test(preparedHost); } public isOldInstance(host: string): boolean { - return !!host && RdsUtils.BG_OLD_HOST_PATTERN.test(host); + const preparedHost = RdsUtils.getPreparedHost(host); + return !!preparedHost && RdsUtils.BG_OLD_HOST_PATTERN.test(preparedHost); } public isNotOldInstance(host: string): boolean { if (!host) { return true; } - return !RdsUtils.BG_OLD_HOST_PATTERN.test(host); + const preparedHost = RdsUtils.getPreparedHost(host); + return !RdsUtils.BG_OLD_HOST_PATTERN.test(preparedHost); } // Verify that provided host is a blue host name and contains neither green prefix nor old prefix. public isNotGreenAndOldPrefixInstance(host: string): boolean { - return !!host && !RdsUtils.BG_GREEN_HOST_PATTERN.test(host) && !RdsUtils.BG_OLD_HOST_PATTERN.test(host); + const preparedHost = RdsUtils.getPreparedHost(host); + return !!preparedHost && !RdsUtils.BG_GREEN_HOST_PATTERN.test(preparedHost) && !RdsUtils.BG_OLD_HOST_PATTERN.test(preparedHost); } public removeGreenInstancePrefix(host: string): string { @@ -354,7 +407,8 @@ export class RdsUtils { return host; } - const matcher = host.match(RdsUtils.BG_GREEN_HOST_PATTERN); + const preparedHost = RdsUtils.getPreparedHost(host); + const matcher = preparedHost.match(RdsUtils.BG_GREEN_HOST_PATTERN); if (!matcher || matcher.length === 0) { return host; } @@ -427,4 +481,20 @@ export class RdsUtils { RdsUtils.cachedPatterns.clear(); RdsUtils.cachedDnsPatterns.clear(); } + + static setPrepareHostFunc(func?: (host: string) => string) { + RdsUtils.prepareHostFunc = func; + } + + static resetPrepareHostFunc() { + RdsUtils.prepareHostFunc = undefined; + } + + private static getPreparedHost(host: string): string { + const func = RdsUtils.prepareHostFunc; + if (!func) { + return host; + } + return func(host) ?? host; + } } diff --git a/common/lib/utils/service_utils.ts b/common/lib/utils/service_utils.ts new file mode 100644 index 000000000..b7bda542f --- /dev/null +++ b/common/lib/utils/service_utils.ts @@ -0,0 +1,124 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { FullServicesContainer, FullServicesContainerImpl } from "./full_services_container"; +import { StorageService } from "./storage/storage_service"; +import { PluginServiceImpl } from "../plugin_service"; +import { PluginManager } from "../plugin_manager"; +import { ConnectionProviderManager } from "../connection_provider_manager"; +import { DriverConnectionProvider } from "../driver_connection_provider"; +import { WrapperProperties } from "../wrapper_property"; +import { ConnectionProvider } from "../connection_provider"; +import { AwsClient } from "../aws_client"; +import { DatabaseDialect, DatabaseType } from "../database_dialect/database_dialect"; +import { DatabaseDialectCodes } from "../database_dialect/database_dialect_codes"; +import { DriverDialect } from "../driver_dialect/driver_dialect"; +import { MonitorService } from "./monitoring/monitor_service"; +import { TelemetryFactory } from "./telemetry/telemetry_factory"; +import { EventPublisher } from "./events/event"; +import { PartialPluginService } from "../partial_plugin_service"; +import { ConnectionUrlParser } from "./connection_url_parser"; + +export class ServiceUtils { + private static readonly _instance: ServiceUtils = new ServiceUtils(); + + static get instance(): ServiceUtils { + return this._instance; + } + + createStandardServiceContainer( + storageService: StorageService, + monitorService: MonitorService, + eventPublisher: EventPublisher, + client: AwsClient, + props: Map, + dbType: DatabaseType, + knownDialectsByCode: Map, + driverDialect: DriverDialect, + telemetryFactory: TelemetryFactory, + connectionProvider: ConnectionProvider | null + ): FullServicesContainer { + const servicesContainer: FullServicesContainer = new FullServicesContainerImpl( + storageService, + monitorService, + eventPublisher, + connectionProvider, + telemetryFactory + ); + + const pluginService = new PluginServiceImpl(servicesContainer, client, dbType, knownDialectsByCode, props, driverDialect); + const pluginManager = new PluginManager( + servicesContainer, + props, + new ConnectionProviderManager(connectionProvider ?? new DriverConnectionProvider(), WrapperProperties.CONNECTION_PROVIDER.get(props)), + telemetryFactory + ); + + servicesContainer.pluginService = pluginService; + servicesContainer.pluginManager = pluginManager; + servicesContainer.hostListProviderService = pluginService; + + return servicesContainer; + } + + createMinimalServiceContainer( + storageService: StorageService, + monitorService: MonitorService, + eventPublisher: EventPublisher, + props: Map, + dialect: DatabaseDialect, + driverDialect: DriverDialect, + telemetryFactory: TelemetryFactory, + connectionProvider: ConnectionProvider | null, + connectionUrlParser: ConnectionUrlParser + ): FullServicesContainer { + const servicesContainer: FullServicesContainer = new FullServicesContainerImpl( + storageService, + monitorService, + eventPublisher, + connectionProvider, + telemetryFactory + ); + + const pluginService = new PartialPluginService(servicesContainer, props, dialect, driverDialect, connectionUrlParser); + const pluginManager = new PluginManager( + servicesContainer, + props, + new ConnectionProviderManager(connectionProvider ?? new DriverConnectionProvider(), WrapperProperties.CONNECTION_PROVIDER.get(props)), + telemetryFactory + ); + + servicesContainer.pluginService = pluginService; + servicesContainer.pluginManager = pluginManager; + servicesContainer.hostListProviderService = pluginService; + + return servicesContainer; + } + + createMinimalServiceContainerFrom(servicesContainer: FullServicesContainer, props: Map): FullServicesContainer { + return this.createMinimalServiceContainer( + servicesContainer.storageService, + servicesContainer.monitorService, + servicesContainer.eventPublisher, + props, + servicesContainer.pluginService.getDialect(), + servicesContainer.pluginService.getDriverDialect(), + servicesContainer.telemetryFactory, + servicesContainer.defaultConnectionProvider, + servicesContainer.pluginService.getConnectionUrlParser() + ); + } +} diff --git a/common/lib/utils/status_cache_item.ts b/common/lib/utils/status_cache_item.ts new file mode 100644 index 000000000..6e7489654 --- /dev/null +++ b/common/lib/utils/status_cache_item.ts @@ -0,0 +1,23 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export class StatusCacheItem { + readonly status: T; + + constructor(status: T) { + this.status = status; + } +} diff --git a/common/lib/utils/storage/expiration_cache.ts b/common/lib/utils/storage/expiration_cache.ts index 5afd57e9f..72e3b7d1c 100644 --- a/common/lib/utils/storage/expiration_cache.ts +++ b/common/lib/utils/storage/expiration_cache.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { getTimeInNanos } from "../utils"; import { ItemDisposalFunc, ShouldDisposeFunc } from "../../types"; diff --git a/common/lib/utils/storage/storage_service.ts b/common/lib/utils/storage/storage_service.ts index 06f6ca00a..dc2bd71d0 100644 --- a/common/lib/utils/storage/storage_service.ts +++ b/common/lib/utils/storage/storage_service.ts @@ -1,26 +1,33 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { Constructor, ItemDisposalFunc, ShouldDisposeFunc } from "../../types"; import { ExpirationCache } from "./expiration_cache"; import { Topology } from "../../host_list_provider/topology"; import { AwsWrapperError } from "../errors"; import { Messages } from "../messages"; +import { EventPublisher } from "../events/event"; +import { DataAccessEvent } from "../events/data_access_event"; +import { AllowedAndBlockedHosts } from "../../allowed_and_blocked_hosts"; +import { BlueGreenStatus } from "../../plugins/bluegreen/blue_green_status"; +import { HostAvailabilityCacheItem } from "../../host_availability/host_availability_cache_item"; +import { StatusCacheItem } from "../status_cache_item"; const DEFAULT_CLEANUP_INTERVAL_NANOS = 5 * 60 * 1_000_000_000; // 5 minutes +const SIXTY_MINUTES_NANOS = BigInt(60 * 60 * 1_000_000_000); // 60 minutes /** * Interface for a storage service that manages items with expiration and disposal logic. @@ -62,9 +69,10 @@ export interface StorageService { * * @param itemClass The expected constructor/class of the item being retrieved * @param key The key for the item, e.g., "custom-endpoint.cluster-custom-XYZ.us-east-2.rds.amazonaws.com:5432" + * @param registerDataAccess Whether to register a data access event. Defaults to true. * @returns The item stored at the given key for the given item class, or null/undefined if not found */ - get(itemClass: Constructor, key?: unknown): V | null; + get(itemClass: Constructor, key?: unknown, registerDataAccess?: boolean): V | null; /** * Indicates whether an item exists under the given item class and key. @@ -108,31 +116,44 @@ export interface StorageService { * Cleanup method to stop the cleanup interval timer. * Should be called when the service is no longer needed. */ - releaseResources(): void; + releaseResources(): Promise; } type CacheSupplier = () => ExpirationCache; export class StorageServiceImpl implements StorageService { - private static readonly defaultCacheSuppliers: Map = new Map([[Topology, () => new ExpirationCache()]]); + private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = BigInt(5 * 60_000_000_000); // 5 minutes + + private static readonly defaultCacheSuppliers: Map = (() => { + const suppliers = new Map(); + suppliers.set(Topology, () => new ExpirationCache()); + suppliers.set(AllowedAndBlockedHosts, () => new ExpirationCache()); + suppliers.set(BlueGreenStatus, () => new ExpirationCache(false, SIXTY_MINUTES_NANOS, null, null)); + suppliers.set( + HostAvailabilityCacheItem, + () => new ExpirationCache(true, StorageServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO, null, null) + ); + suppliers.set(StatusCacheItem, () => new ExpirationCache(false, SIXTY_MINUTES_NANOS, null, null)); + return suppliers; + })(); protected readonly caches: Map> = new Map(); protected cleanupIntervalHandle?: NodeJS.Timeout; + protected readonly publisher: EventPublisher; - constructor(cleanupIntervalNanos: number = DEFAULT_CLEANUP_INTERVAL_NANOS) { - this.initCleanupThread(cleanupIntervalNanos); + constructor(publisher: EventPublisher, cleanupIntervalNanos: number = DEFAULT_CLEANUP_INTERVAL_NANOS) { + this.publisher = publisher; + this.initCleanupTask(cleanupIntervalNanos); } - protected initCleanupThread(cleanupIntervalNanos: number): void { + protected initCleanupTask(cleanupIntervalNanos: number): void { const intervalMs = cleanupIntervalNanos / 1_000_000; this.cleanupIntervalHandle = setInterval(() => { this.removeExpiredItems(); }, intervalMs); - // Allow Node.js to exit even if this timer is active - if (this.cleanupIntervalHandle.unref) { - this.cleanupIntervalHandle.unref(); - } + // Unref the timer to prevent this background cleanup task from blocking the application from gracefully exiting. + this.cleanupIntervalHandle.unref(); } protected removeExpiredItems(): void { @@ -179,7 +200,7 @@ export class StorageServiceImpl implements StorageService { } } - get(itemClass: Constructor, key?: unknown): V | null { + get(itemClass: Constructor, key?: unknown, registerDataAccess: boolean = true): V | null { const cache = this.caches.get(itemClass); if (!cache) { return null; @@ -191,6 +212,10 @@ export class StorageServiceImpl implements StorageService { } if (value instanceof itemClass) { + if (registerDataAccess) { + const event = new DataAccessEvent(itemClass, key); + this.publisher.publish(event); + } return value as V; } @@ -223,7 +248,6 @@ export class StorageServiceImpl implements StorageService { for (const cache of this.caches.values()) { cache.clear(); } - this.caches.clear(); } @@ -235,15 +259,7 @@ export class StorageServiceImpl implements StorageService { return cache.size(); } - /** - * Registers a default cache supplier for a specific item class. - * This allows automatic cache creation when items of this class are stored. - */ - static registerDefaultCacheSupplier(itemClass: Constructor, supplier: CacheSupplier): void { - StorageServiceImpl.defaultCacheSuppliers.set(itemClass, supplier); - } - - releaseResources(): void { + async releaseResources(): Promise { if (this.cleanupIntervalHandle) { clearInterval(this.cleanupIntervalHandle); this.cleanupIntervalHandle = undefined; diff --git a/common/lib/utils/utils.ts b/common/lib/utils/utils.ts index 1e159510c..4091b6a66 100644 --- a/common/lib/utils/utils.ts +++ b/common/lib/utils/utils.ts @@ -20,16 +20,26 @@ import { WrapperProperties } from "../wrapper_property"; import { HostRole } from "../host_role"; import { logger } from "../../logutils"; import { AwsWrapperError, InternalQueryTimeoutError } from "./errors"; -import { TopologyAwareDatabaseDialect } from "../database_dialect/topology_aware_database_dialect"; export function sleep(ms: number) { return new Promise((resolve) => setTimeout(resolve, ms)); } +/** + * Creates a sleep promise that can be aborted before completion. + * + * @param ms - Duration to sleep in milliseconds + * @param message - Error message when aborted + * @returns A tuple of [sleepPromise, abortFunction] + * - sleepPromise: Resolves after ms milliseconds, or rejects if aborted + * - abortFunction: Call to cancel the sleep and reject the promise + */ export function sleepWithAbort(ms: number, message?: string) { let abortSleep; const promise = new Promise((resolve, reject) => { const timeout = setTimeout(resolve, ms); + // Unref the timer to prevent this background task from blocking the application from gracefully exiting. + timeout.unref(); abortSleep = () => { clearTimeout(timeout); reject(new AwsWrapperError(message)); @@ -65,7 +75,7 @@ export function logTopology(hosts: HostInfo[], msgPrefix: string) { return `${msgPrefix}${Messages.get("Utils.topology", msg)}`; } -export function getTimeInNanos() { +export function getTimeInNanos(): bigint { return process.hrtime.bigint(); } @@ -113,32 +123,47 @@ export function equalsIgnoreCase(value1: string | null, value2: string | null): return value1 != null && value2 != null && value1.localeCompare(value2, undefined, { sensitivity: "accent" }) === 0; } -export function isDialectTopologyAware(dialect: any): dialect is TopologyAwareDatabaseDialect { - return dialect; -} - export function containsHostAndPort(hosts: HostInfo[] | null | undefined, hostAndPort: string): boolean { - if (hosts?.length === 0) { + if (!hosts || hosts.length === 0) { return false; } return hosts.some((host) => host.hostAndPort === hostAndPort); } -export class Pair { - private readonly _left: K; - private readonly _right: V; - - constructor(value1: K, value2: V) { - this._left = value1; - this._right = value2; +export function parseInstanceTemplates( + instanceTemplatesString: string | null, + hostValidator: (hostPattern: string) => void, + hostInfoBuilderFunc: () => { withHost(host: string): { build(): HostInfo } } +): Map { + if (!instanceTemplatesString) { + throw new AwsWrapperError(Messages.get("Utils.globalClusterInstanceHostPatternsRequired")); } - get left(): K { - return this._left; - } + const instanceTemplates = new Map(); + const patterns = instanceTemplatesString.split(","); + + for (const pattern of patterns) { + const trimmedPattern = pattern.trim(); + const colonIndex = trimmedPattern.indexOf(":"); + if (colonIndex === -1) { + throw new AwsWrapperError(Messages.get("Utils.invalidPatternFormat", trimmedPattern)); + } - get right(): V { - return this._right; + const region = trimmedPattern.substring(0, colonIndex).trim(); + const hostPattern = trimmedPattern.substring(colonIndex + 1).trim(); + + if (!region || !hostPattern) { + throw new AwsWrapperError(Messages.get("Utils.invalidPatternFormat", trimmedPattern)); + } + + hostValidator(hostPattern); + + const hostInfo = hostInfoBuilderFunc().withHost(hostPattern).build(); + instanceTemplates.set(region, hostInfo); } + + logger.debug(`Detected Global Database patterns: ${JSON.stringify(Array.from(instanceTemplates.entries()))}`); + + return instanceTemplates; } diff --git a/common/lib/wrapper_property.ts b/common/lib/wrapper_property.ts index 40f739884..2a997577e 100644 --- a/common/lib/wrapper_property.ts +++ b/common/lib/wrapper_property.ts @@ -16,18 +16,20 @@ import { ConnectionProvider } from "./connection_provider"; import { DatabaseDialect } from "./database_dialect/database_dialect"; -import { ClusterTopologyMonitorImpl } from "./host_list_provider/monitoring/cluster_topology_monitor"; -import { BlueGreenStatusProvider } from "./plugins/bluegreen/blue_green_status_provider"; +import { AwsWrapperError } from "./utils/errors"; +import { Messages } from "./utils/messages"; export class WrapperProperty { name: string; description: string; defaultValue: any; + allowedValues?: T[]; - constructor(name: string, description: string, defaultValue?: any) { + constructor(name: string, description: string, defaultValue?: any, allowedValues?: T[]) { this.name = name; this.description = description; this.defaultValue = defaultValue; + this.allowedValues = allowedValues; } get(props: Map): T { @@ -36,16 +38,29 @@ export class WrapperProperty { return this.defaultValue; } + if (val != null && this.allowedValues?.length > 0) { + if (!this.allowedValues.includes(val)) { + throw new AwsWrapperError(Messages.get("WrapperProperty.invalidValue", String(val), this.name, this.allowedValues.join(", "))); + } + } + return val; } set(props: Map, val: T) { + if (val != null && this.allowedValues?.length > 0) { + if (!this.allowedValues.includes(val)) { + throw new AwsWrapperError(Messages.get("WrapperProperty.invalidValue", String(val), this.name, this.allowedValues.join(", "))); + } + } props.set(this.name, val); } } export class WrapperProperties { static readonly MONITORING_PROPERTY_PREFIX: string = "monitoring_"; + static readonly TOPOLOGY_MONITORING_PROPERTY_PREFIX: string = "topology_monitoring_"; + static readonly BG_MONITORING_PROPERTY_PREFIX: string = "blue_green_monitoring_"; static readonly DEFAULT_PLUGINS = "auroraConnectionTracker,failover,efm2"; static readonly DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60; @@ -210,6 +225,38 @@ export class WrapperProperties { ); static readonly FAILOVER_MODE = new WrapperProperty("failoverMode", "Set host role to follow during failover.", ""); + static readonly FAILOVER_HOME_REGION = new WrapperProperty("failoverHomeRegion", "Set home region for GDB failover.", null); + + static readonly ACTIVE_HOME_FAILOVER_MODE = new WrapperProperty( + "activeHomeFailoverMode", + "Set host role to follow during failover when GDB primary region is in home region.", + null, + [ + "strict-writer", + "strict-home-reader", + "strict-out-of-home-reader", + "strict-any-reader", + "home-reader-or-writer", + "out-of-home-reader-or-writer", + "any-reader-or-writer" + ] + ); + + static readonly INACTIVE_HOME_FAILOVER_MODE = new WrapperProperty( + "inactiveHomeFailoverMode", + "Set host role to follow during failover when GDB primary region is not in home region.", + null, + [ + "strict-writer", + "strict-home-reader", + "strict-out-of-home-reader", + "strict-any-reader", + "home-reader-or-writer", + "out-of-home-reader-or-writer", + "any-reader-or-writer" + ] + ); + static readonly FAILOVER_READER_HOST_SELECTOR_STRATEGY = new WrapperProperty( "failoverReaderHostSelectorStrategy", "The strategy that should be used to select a new reader host while opening a new connection.", @@ -244,6 +291,16 @@ export class WrapperProperties { "clusters. Otherwise, if unspecified, the pattern will be automatically created for AWS RDS clusters." ); + static readonly GLOBAL_CLUSTER_INSTANCE_HOST_PATTERNS = new WrapperProperty( + "globalClusterInstanceHostPatterns", + "Comma-separated list of the cluster instance DNS patterns that will be used to " + + "build complete instance endpoints. " + + 'A "?" character in these patterns should be used as a placeholder for cluster instance names. ' + + "This parameter is required for Global Aurora Databases. " + + "Each region in the Global Aurora Database should be specified in the list. " + + "Format: region1:pattern1,region2:pattern2" + ); + static readonly SINGLE_WRITER_CONNECTION_STRING = new WrapperProperty( "singleWriterConnectionString", "Set to true if you are providing a connection string with multiple comma-delimited hosts and your cluster has only one writer. The writer must be the first host in the connection string", @@ -477,11 +534,37 @@ export class WrapperProperties { "Default value 0 means the Wrapper will keep reusing the same cached reader connection.", 0 ); + static readonly SKIP_INACTIVE_WRITER_CLUSTER_CHECK = new WrapperProperty( + "skipInactiveWriterClusterEndpointCheck", + "Allows to avoid connection check for inactive cluster writer endpoint.", + false + ); + + static readonly INACTIVE_CLUSTER_WRITER_SUBSTITUTION_ROLE = new WrapperProperty( + "inactiveClusterWriterEndpointSubstitutionRole", + "Defines whether or not the inactive cluster writer endpoint in the initial connection URL should be replaced with a writer instance URL from the topology info when available.", + "writer", + ["writer", "none"] + ); + + static readonly VERIFY_OPENED_CONNECTION_ROLE = new WrapperProperty( + "verifyOpenedConnectionType", + "Defines whether an opened connection should be verified to be a writer or reader, or if no role verification should be performed.", + null, + ["writer", "reader", "none"] + ); + + static readonly VERIFY_INACTIVE_CLUSTER_WRITER_CONNECTION_ROLE = new WrapperProperty( + "verifyInactiveClusterWriterEndpointConnectionType", + "Defines whether inactive cluster writer connection should be verified to be a writer, or if no role verification should be performed.", + "writer", + ["writer", "none"] + ); private static readonly PREFIXES = [ WrapperProperties.MONITORING_PROPERTY_PREFIX, - ClusterTopologyMonitorImpl.MONITORING_PROPERTY_PREFIX, - BlueGreenStatusProvider.MONITORING_PROPERTY_PREFIX + WrapperProperties.TOPOLOGY_MONITORING_PROPERTY_PREFIX, + WrapperProperties.BG_MONITORING_PROPERTY_PREFIX ]; private static startsWithPrefix(key: string): boolean { diff --git a/docs/using-the-nodejs-wrapper/UsingTheConnectionPool.md b/docs/using-the-nodejs-wrapper/UsingTheConnectionPool.md index 339fb3123..579ead16f 100644 --- a/docs/using-the-nodejs-wrapper/UsingTheConnectionPool.md +++ b/docs/using-the-nodejs-wrapper/UsingTheConnectionPool.md @@ -407,7 +407,7 @@ const result = await pool.query("SELECT NOW()"); ### Resources Cleanup -Throughout the application lifetime, some plugins like the Aurora Connection Tracker Plugin or the Host Monitoring Connection Plugin may create background threads shared by all connections. +Throughout the application lifetime, some plugins like the Aurora Connection Tracker Plugin or the Host Monitoring Connection Plugin may create background tasks shared by all connections. At the end of your application, call `PluginManager.releaseResources()` to clean up these shared resources. diff --git a/eslint.config.js b/eslint.config.js index b02eb491d..85a4dfb9a 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -35,7 +35,7 @@ const compat = new FlatCompat({ }); export default defineConfig([ - globalIgnores(["**/node_modules", "**/gradle", "**/dist", "**/coverage"]), + globalIgnores(["**/node_modules", "**/gradle", "**/dist", "**/coverage", "**/*.js"]), { extends: compat.extends( "eslint:recommended", diff --git a/examples/typescript_example/src/index.ts b/examples/typescript_example/src/index.ts index b90e34a67..f2b496009 100644 --- a/examples/typescript_example/src/index.ts +++ b/examples/typescript_example/src/index.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/mysql/lib/client.ts b/mysql/lib/client.ts index 85d3e496f..d2b0e71d9 100644 --- a/mysql/lib/client.ts +++ b/mysql/lib/client.ts @@ -39,15 +39,17 @@ import { ClientUtils } from "../../common/lib/utils/client_utils"; import { RdsMultiAZClusterMySQLDatabaseDialect } from "./dialect/rds_multi_az_mysql_database_dialect"; import { TelemetryTraceLevel } from "../../common/lib/utils/telemetry/telemetry_trace_level"; import { MySQL2DriverDialect } from "./dialect/mysql2_driver_dialect"; -import { isDialectTopologyAware } from "../../common/lib/utils/utils"; +import { isDialectTopologyAware } from "../../common/lib/database_dialect/topology_aware_database_dialect"; import { MySQLClient, MySQLPoolClient } from "./mysql_client"; import { DriverConnectionProvider } from "../../common/lib/driver_connection_provider"; +import { GlobalAuroraMySQLDatabaseDialect } from "./dialect/global_aurora_mysql_database_dialect"; class BaseAwsMySQLClient extends AwsClient implements MySQLClient { private static readonly knownDialectsByCode: Map = new Map([ [DatabaseDialectCodes.MYSQL, new MySQLDatabaseDialect()], [DatabaseDialectCodes.RDS_MYSQL, new RdsMySQLDatabaseDialect()], [DatabaseDialectCodes.AURORA_MYSQL, new AuroraMySQLDatabaseDialect()], + [DatabaseDialectCodes.GLOBAL_AURORA_MYSQL, new GlobalAuroraMySQLDatabaseDialect()], [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL, new RdsMultiAZClusterMySQLDatabaseDialect()] ]); @@ -113,7 +115,7 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { return result; } - isReadOnly(): boolean { + isReadOnly(): boolean | undefined { return this.pluginService.getSessionStateService().getReadOnly(); } @@ -129,7 +131,7 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { return result; } - getAutoCommit(): boolean { + getAutoCommit(): boolean | undefined { return this.pluginService.getSessionStateService().getAutoCommit(); } @@ -142,7 +144,7 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { this.pluginService.getSessionStateService().setCatalog(catalog); } - getCatalog(): string { + getCatalog(): string | undefined { return this.pluginService.getSessionStateService().getCatalog(); } @@ -150,7 +152,7 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "setSchema")); } - getSchema(): string { + getSchema(): string | undefined { throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "getSchema")); } @@ -181,7 +183,7 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { this.pluginService.getSessionStateService().setTransactionIsolation(level); } - getTransactionIsolation(): TransactionIsolationLevel { + getTransactionIsolation(): TransactionIsolationLevel | undefined { return this.pluginService.getSessionStateService().getTransactionIsolation(); } @@ -197,6 +199,10 @@ class BaseAwsMySQLClient extends AwsClient implements MySQLClient { this.properties, "end", () => { + if (!this.targetClient) { + return Promise.resolve(undefined); + } + this.pluginService.removeErrorListener(this.targetClient); const res = ClientUtils.queryWithTimeout(this.targetClient.end(), this.properties); this.targetClient = undefined; diff --git a/mysql/lib/dialect/aurora_mysql_database_dialect.ts b/mysql/lib/dialect/aurora_mysql_database_dialect.ts index 3ebfa2240..a4f3c69a5 100644 --- a/mysql/lib/dialect/aurora_mysql_database_dialect.ts +++ b/mysql/lib/dialect/aurora_mysql_database_dialect.ts @@ -15,20 +15,16 @@ */ import { MySQLDatabaseDialect } from "./mysql_database_dialect"; -import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { RdsHostListProvider } from "../../../common/lib/host_list_provider/rds_host_list_provider"; import { TopologyAwareDatabaseDialect } from "../../../common/lib/database_dialect/topology_aware_database_dialect"; import { HostRole } from "../../../common/lib/host_role"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes"; -import { WrapperProperties } from "../../../common/lib/wrapper_property"; -import { - MonitoringRdsHostListProvider -} from "../../../common/lib/host_list_provider/monitoring/monitoring_host_list_provider"; -import { PluginService } from "../../../common/lib/plugin_service"; import { BlueGreenDialect, BlueGreenResult } from "../../../common/lib/database_dialect/blue_green_dialect"; -import { TopologyQueryResult, TopologyUtils } from "../../../common/lib/host_list_provider/topology_utils"; +import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { AuroraTopologyUtils } from "../../../common/lib/host_list_provider/aurora_topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class AuroraMySQLDatabaseDialect extends MySQLDatabaseDialect implements TopologyAwareDatabaseDialect, BlueGreenDialect { private static readonly TOPOLOGY_QUERY: string = @@ -50,18 +46,9 @@ export class AuroraMySQLDatabaseDialect extends MySQLDatabaseDialect implements private static readonly TOPOLOGY_TABLE_EXIST_QUERY: string = "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'"; - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - const topologyUtils: TopologyUtils = new TopologyUtils(this, hostListProviderService.getHostInfoBuilder()); - if (WrapperProperties.PLUGINS.get(props).includes("failover2")) { - return new MonitoringRdsHostListProvider( - props, - originalUrl, - topologyUtils, - hostListProviderService, - (hostListProviderService) - ); - } - return new RdsHostListProvider(props, originalUrl, topologyUtils, hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + const topologyUtils = new AuroraTopologyUtils(this, servicesContainer.hostListProviderService.getHostInfoBuilder()); + return new RdsHostListProvider(props, originalUrl, topologyUtils, servicesContainer); } async queryForTopology(targetClient: ClientWrapper): Promise { @@ -142,7 +129,7 @@ export class AuroraMySQLDatabaseDialect extends MySQLDatabaseDialect implements } getDialectUpdateCandidates(): string[] { - return [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL]; + return [DatabaseDialectCodes.GLOBAL_AURORA_MYSQL, DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL]; } async isBlueGreenStatusAvailable(clientWrapper: ClientWrapper): Promise { diff --git a/mysql/lib/dialect/global_aurora_mysql_database_dialect.ts b/mysql/lib/dialect/global_aurora_mysql_database_dialect.ts index 1b51b4e11..faf89a3ca 100644 --- a/mysql/lib/dialect/global_aurora_mysql_database_dialect.ts +++ b/mysql/lib/dialect/global_aurora_mysql_database_dialect.ts @@ -1,23 +1,27 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { AuroraMySQLDatabaseDialect } from "./aurora_mysql_database_dialect"; import { GlobalAuroraTopologyDialect } from "../../../common/lib/database_dialect/topology_aware_database_dialect"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; +import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; +import { GlobalAuroraHostListProvider } from "../../../common/lib/host_list_provider/global_aurora_host_list_provider"; +import { GlobalTopologyUtils } from "../../../common/lib/host_list_provider/global_topology_utils"; export class GlobalAuroraMySQLDatabaseDialect extends AuroraMySQLDatabaseDialect implements GlobalAuroraTopologyDialect { private static readonly GLOBAL_STATUS_TABLE_EXISTS_QUERY = @@ -29,8 +33,8 @@ export class GlobalAuroraMySQLDatabaseDialect extends AuroraMySQLDatabaseDialect " upper(table_schema) = 'INFORMATION_SCHEMA' AND upper(table_name) = 'AURORA_GLOBAL_DB_INSTANCE_STATUS'"; private static readonly GLOBAL_TOPOLOGY_QUERY = - "SELECT SERVER_ID, CASE WHEN SESSION_ID = 'MASTER_SESSION_ID' THEN TRUE ELSE FALSE END AS IS_WRITER, " + - "VISIBILITY_LAG_IN_MSEC, AWS_REGION " + + "SELECT server_id, CASE WHEN SESSION_ID = 'MASTER_SESSION_ID' THEN TRUE ELSE FALSE END AS is_writer, " + + "visibility_lag_in_msec, aws_region " + "FROM information_schema.aurora_global_db_instance_status"; private static readonly REGION_COUNT_QUERY = "SELECT count(1) FROM information_schema.aurora_global_db_status"; @@ -68,7 +72,14 @@ export class GlobalAuroraMySQLDatabaseDialect extends AuroraMySQLDatabaseDialect return []; } - // TODO: implement GetHostListProvider once GDBHostListProvider is implemented + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + return new GlobalAuroraHostListProvider( + props, + originalUrl, + new GlobalTopologyUtils(this, servicesContainer.pluginService.getHostInfoBuilder()), + servicesContainer + ); + } async queryForTopology(targetClient: ClientWrapper): Promise { const res = await targetClient.query(GlobalAuroraMySQLDatabaseDialect.GLOBAL_TOPOLOGY_QUERY); diff --git a/mysql/lib/dialect/mysql2_driver_dialect.ts b/mysql/lib/dialect/mysql2_driver_dialect.ts index c9e3e5841..3e7b9f3af 100644 --- a/mysql/lib/dialect/mysql2_driver_dialect.ts +++ b/mysql/lib/dialect/mysql2_driver_dialect.ts @@ -47,6 +47,8 @@ export class MySQL2DriverDialect implements DriverDialect { preparePoolClientProperties(props: Map, poolConfig: AwsPoolConfig | undefined): any { const finalPoolConfig: PoolOptions = {}; const finalClientProps = WrapperProperties.removeWrapperProperties(props); + this.setKeepAliveProperties(finalClientProps, props.get(WrapperProperties.KEEPALIVE_PROPERTIES.name)); + this.setConnectTimeout(finalClientProps, props.get(WrapperProperties.WRAPPER_CONNECT_TIMEOUT.name)); Object.assign(finalPoolConfig, Object.fromEntries(finalClientProps.entries())); finalPoolConfig.connectionLimit = poolConfig?.maxConnections; @@ -69,6 +71,9 @@ export class MySQL2DriverDialect implements DriverDialect { } setQueryTimeout(props: Map, sql?: any, wrapperQueryTimeout?: any) { + if (!sql) { + return; + } const timeout = wrapperQueryTimeout ?? props.get(WrapperProperties.WRAPPER_QUERY_TIMEOUT.name); if (timeout && !sql[MySQL2DriverDialect.QUERY_TIMEOUT_PROPERTY_NAME]) { sql[MySQL2DriverDialect.QUERY_TIMEOUT_PROPERTY_NAME] = Number(timeout); diff --git a/mysql/lib/dialect/mysql_database_dialect.ts b/mysql/lib/dialect/mysql_database_dialect.ts index 084ea6b21..f0fe20ad6 100644 --- a/mysql/lib/dialect/mysql_database_dialect.ts +++ b/mysql/lib/dialect/mysql_database_dialect.ts @@ -28,6 +28,7 @@ import { ErrorHandler } from "../../../common/lib/error_handler"; import { MySQLErrorHandler } from "../mysql_error_handler"; import { Messages } from "../../../common/lib/utils/messages"; import { HostRole } from "../../../common/lib/host_role"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class MySQLDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -108,8 +109,8 @@ export class MySQLDatabaseDialect implements DatabaseDialect { }); } - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), servicesContainer.hostListProviderService); } getErrorHandler(): ErrorHandler { diff --git a/mysql/lib/dialect/rds_multi_az_mysql_database_dialect.ts b/mysql/lib/dialect/rds_multi_az_mysql_database_dialect.ts index f070e8a1f..33ac21693 100644 --- a/mysql/lib/dialect/rds_multi_az_mysql_database_dialect.ts +++ b/mysql/lib/dialect/rds_multi_az_mysql_database_dialect.ts @@ -15,7 +15,6 @@ */ import { MySQLDatabaseDialect } from "./mysql_database_dialect"; -import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { HostRole } from "../../../common/lib/host_role"; @@ -24,10 +23,9 @@ import { AwsWrapperError } from "../../../common/lib/utils/errors"; import { TopologyAwareDatabaseDialect } from "../../../common/lib/database_dialect/topology_aware_database_dialect"; import { RdsHostListProvider } from "../../../common/lib/host_list_provider/rds_host_list_provider"; import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction"; -import { WrapperProperties } from "../../../common/lib/wrapper_property"; -import { PluginService } from "../../../common/lib/plugin_service"; -import { MonitoringRdsHostListProvider } from "../../../common/lib/host_list_provider/monitoring/monitoring_host_list_provider"; -import { TopologyQueryResult, TopologyUtils } from "../../../common/lib/host_list_provider/topology_utils"; +import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { AuroraTopologyUtils } from "../../../common/lib/host_list_provider/aurora_topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class RdsMultiAZClusterMySQLDatabaseDialect extends MySQLDatabaseDialect implements TopologyAwareDatabaseDialect { private static readonly TOPOLOGY_QUERY: string = "SELECT id, endpoint, port FROM mysql.rds_topology"; @@ -71,18 +69,9 @@ export class RdsMultiAZClusterMySQLDatabaseDialect extends MySQLDatabaseDialect .catch(() => false); } - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - const topologyUtils: TopologyUtils = new TopologyUtils(this, hostListProviderService.getHostInfoBuilder()); - if (WrapperProperties.PLUGINS.get(props).includes("failover2")) { - return new MonitoringRdsHostListProvider( - props, - originalUrl, - topologyUtils, - hostListProviderService, - (hostListProviderService) - ); - } - return new RdsHostListProvider(props, originalUrl, topologyUtils, hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + const topologyUtils = new AuroraTopologyUtils(this, servicesContainer.hostListProviderService.getHostInfoBuilder()); + return new RdsHostListProvider(props, originalUrl, topologyUtils, servicesContainer); } async queryForTopology(targetClient: ClientWrapper): Promise { diff --git a/mysql/lib/icp/mysql_internal_pool_client.ts b/mysql/lib/icp/mysql_internal_pool_client.ts index 45f8826f5..19dae953a 100644 --- a/mysql/lib/icp/mysql_internal_pool_client.ts +++ b/mysql/lib/icp/mysql_internal_pool_client.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/mysql/lib/mysql_client.ts b/mysql/lib/mysql_client.ts index 1641adf79..016d710fb 100644 --- a/mysql/lib/mysql_client.ts +++ b/mysql/lib/mysql_client.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/package.json b/package.json index bcf18e00c..ef2f1b827 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,7 @@ "format": "prettier . --write --config .prettierrc", "integration": "cross-env NODE_OPTIONS=\"$NODE_OPTIONS --experimental-vm-modules\" npx jest --config=jest.integration.config.json --runInBand --verbose", "debug-integration": "cross-env NODE_OPTIONS=\"$NODE_OPTIONS --experimental-vm-modules --inspect-brk=0.0.0.0:5005\" npx jest --config=jest.integration.config.json --runInBand", - "lint": "eslint --fix --ext .ts .", + "lint": "eslint --ext .ts .", "test": "cross-env NODE_OPTIONS=\"$NODE_OPTIONS --experimental-vm-modules\" npx jest --config=jest.unit.config.json", "fix-imports": "babel --extensions \".js\" dist -d dist", "build": "tsc", diff --git a/pg/lib/client.ts b/pg/lib/client.ts index 4517e1659..0593d1bf1 100644 --- a/pg/lib/client.ts +++ b/pg/lib/client.ts @@ -14,15 +14,7 @@ limitations under the License. */ -import { - QueryArrayConfig, - QueryArrayResult, - QueryConfig, - QueryConfigValues, - QueryResult, - QueryResultRow, - Submittable -} from "pg"; +import { QueryArrayConfig, QueryArrayResult, QueryConfig, QueryConfigValues, QueryResult, QueryResultRow, Submittable } from "pg"; import { AwsClient } from "../../common/lib/aws_client"; import { PgConnectionUrlParser } from "./pg_connection_url_parser"; import { DatabaseDialect, DatabaseType } from "../../common/lib/database_dialect/database_dialect"; @@ -46,15 +38,17 @@ import { ClientWrapper } from "../../common/lib/client_wrapper"; import { RdsMultiAZClusterPgDatabaseDialect } from "./dialect/rds_multi_az_pg_database_dialect"; import { TelemetryTraceLevel } from "../../common/lib/utils/telemetry/telemetry_trace_level"; import { NodePostgresDriverDialect } from "./dialect/node_postgres_driver_dialect"; -import { isDialectTopologyAware } from "../../common/lib/utils/utils"; +import { isDialectTopologyAware } from "../../common/lib/database_dialect/topology_aware_database_dialect"; import { PGClient, PGPoolClient } from "./pg_client"; import { DriverConnectionProvider } from "../../common/lib/driver_connection_provider"; +import { GlobalAuroraPgDatabaseDialect } from "./dialect/global_aurora_pg_database_dialect"; class BaseAwsPgClient extends AwsClient implements PGClient { private static readonly knownDialectsByCode: Map = new Map([ [DatabaseDialectCodes.PG, new PgDatabaseDialect()], [DatabaseDialectCodes.RDS_PG, new RdsPgDatabaseDialect()], [DatabaseDialectCodes.AURORA_PG, new AuroraPgDatabaseDialect()], + [DatabaseDialectCodes.GLOBAL_AURORA_PG, new GlobalAuroraPgDatabaseDialect()], [DatabaseDialectCodes.RDS_MULTI_AZ_PG, new RdsMultiAZClusterPgDatabaseDialect()] ]); @@ -90,7 +84,7 @@ class BaseAwsPgClient extends AwsClient implements PGClient { return result; } - isReadOnly(): boolean { + isReadOnly(): boolean | undefined { return this.pluginService.getSessionStateService().getReadOnly(); } @@ -128,7 +122,7 @@ class BaseAwsPgClient extends AwsClient implements PGClient { this.pluginService.getSessionStateService().setTransactionIsolation(level); } - getTransactionIsolation(): TransactionIsolationLevel { + getTransactionIsolation(): TransactionIsolationLevel | undefined { return this.pluginService.getSessionStateService().getTransactionIsolation(); } @@ -155,7 +149,7 @@ class BaseAwsPgClient extends AwsClient implements PGClient { return result; } - getSchema(): string { + getSchema(): string | undefined { return this.pluginService.getSessionStateService().getSchema(); } @@ -407,7 +401,7 @@ export class AwsPgPoolClient implements PGPoolClient { await awsPGPooledConnection.connect(); const res = await awsPGPooledConnection.query(queryTextOrConfig as any, values); await awsPGPooledConnection.end(); - return res; + return res as any; } catch (error: any) { if (!(error instanceof FailoverSuccessError)) { // Release pooled connection. diff --git a/pg/lib/dialect/aurora_pg_database_dialect.ts b/pg/lib/dialect/aurora_pg_database_dialect.ts index c717c2d66..baf8e14cd 100644 --- a/pg/lib/dialect/aurora_pg_database_dialect.ts +++ b/pg/lib/dialect/aurora_pg_database_dialect.ts @@ -15,7 +15,6 @@ */ import { PgDatabaseDialect } from "./pg_database_dialect"; -import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { RdsHostListProvider } from "../../../common/lib/host_list_provider/rds_host_list_provider"; import { TopologyAwareDatabaseDialect } from "../../../common/lib/database_dialect/topology_aware_database_dialect"; @@ -23,11 +22,10 @@ import { HostRole } from "../../../common/lib"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes"; import { LimitlessDatabaseDialect } from "../../../common/lib/database_dialect/limitless_database_dialect"; -import { WrapperProperties } from "../../../common/lib/wrapper_property"; -import { MonitoringRdsHostListProvider } from "../../../common/lib/host_list_provider/monitoring/monitoring_host_list_provider"; -import { PluginService } from "../../../common/lib/plugin_service"; import { BlueGreenDialect, BlueGreenResult } from "../../../common/lib/database_dialect/blue_green_dialect"; -import { TopologyQueryResult, TopologyUtils } from "../../../common/lib/host_list_provider/topology_utils"; +import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { AuroraTopologyUtils } from "../../../common/lib/host_list_provider/aurora_topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class AuroraPgDatabaseDialect extends PgDatabaseDialect implements TopologyAwareDatabaseDialect, LimitlessDatabaseDialect, BlueGreenDialect { private static readonly VERSION = process.env.npm_package_version; @@ -53,22 +51,13 @@ export class AuroraPgDatabaseDialect extends PgDatabaseDialect implements Topolo private static readonly TOPOLOGY_TABLE_EXIST_QUERY: string = "SELECT pg_catalog.'get_blue_green_fast_switchover_metadata'::regproc"; - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - const topologyUtils: TopologyUtils = new TopologyUtils(this, hostListProviderService.getHostInfoBuilder()); - if (WrapperProperties.PLUGINS.get(props).includes("failover2")) { - return new MonitoringRdsHostListProvider( - props, - originalUrl, - topologyUtils, - hostListProviderService, - (hostListProviderService) - ); - } - return new RdsHostListProvider(props, originalUrl, topologyUtils, hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + const topologyUtils = new AuroraTopologyUtils(this, servicesContainer.hostListProviderService.getHostInfoBuilder()); + return new RdsHostListProvider(props, originalUrl, topologyUtils, servicesContainer); } async queryForTopology(targetClient: ClientWrapper): Promise { - const res = await targetClient.query(AuroraPgDatabaseDialect.TOPOLOGY_QUERY); + const res = await targetClient.queryWithTimeout(AuroraPgDatabaseDialect.TOPOLOGY_QUERY); const results: TopologyQueryResult[] = []; const rows: any[] = res.rows; rows.forEach((row) => { @@ -149,7 +138,7 @@ export class AuroraPgDatabaseDialect extends PgDatabaseDialect implements Topolo } getDialectUpdateCandidates(): string[] { - return [DatabaseDialectCodes.RDS_MULTI_AZ_PG]; + return [DatabaseDialectCodes.GLOBAL_AURORA_PG, DatabaseDialectCodes.RDS_MULTI_AZ_PG]; } getLimitlessRoutersQuery(): string { diff --git a/pg/lib/dialect/global_aurora_pg_database_dialect.ts b/pg/lib/dialect/global_aurora_pg_database_dialect.ts index c2a864545..d452428ed 100644 --- a/pg/lib/dialect/global_aurora_pg_database_dialect.ts +++ b/pg/lib/dialect/global_aurora_pg_database_dialect.ts @@ -1,23 +1,27 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { AuroraPgDatabaseDialect } from "./aurora_pg_database_dialect"; import { GlobalAuroraTopologyDialect } from "../../../common/lib/database_dialect/topology_aware_database_dialect"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; +import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; +import { GlobalAuroraHostListProvider } from "../../../common/lib/host_list_provider/global_aurora_host_list_provider"; +import { GlobalTopologyUtils } from "../../../common/lib/host_list_provider/global_topology_utils"; export class GlobalAuroraPgDatabaseDialect extends AuroraPgDatabaseDialect implements GlobalAuroraTopologyDialect { private static readonly GLOBAL_STATUS_FUNC_EXISTS_QUERY = "select 'aurora_global_db_status'::regproc"; @@ -77,10 +81,17 @@ export class GlobalAuroraPgDatabaseDialect extends AuroraPgDatabaseDialect imple return []; } - // TODO: implement GetHostListProvider once GDBHostListProvider is implemented + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + return new GlobalAuroraHostListProvider( + props, + originalUrl, + new GlobalTopologyUtils(this, servicesContainer.pluginService.getHostInfoBuilder()), + servicesContainer + ); + } async queryForTopology(targetClient: ClientWrapper): Promise { - const res = await targetClient.query(GlobalAuroraPgDatabaseDialect.GLOBAL_TOPOLOGY_QUERY); + const res = await targetClient.queryWithTimeout(GlobalAuroraPgDatabaseDialect.GLOBAL_TOPOLOGY_QUERY); const hosts: TopologyQueryResult[] = []; const rows: any[] = res.rows; rows.forEach((row) => { diff --git a/pg/lib/dialect/node_postgres_driver_dialect.ts b/pg/lib/dialect/node_postgres_driver_dialect.ts index 7cf659392..40d1ba381 100644 --- a/pg/lib/dialect/node_postgres_driver_dialect.ts +++ b/pg/lib/dialect/node_postgres_driver_dialect.ts @@ -49,8 +49,12 @@ export class NodePostgresDriverDialect implements DriverDialect { preparePoolClientProperties(props: Map, poolConfig: AwsPoolConfig | undefined): any { const finalPoolConfig: pkgPg.PoolConfig = {}; const finalClientProps = WrapperProperties.removeWrapperProperties(props); + this.setKeepAliveProperties(finalClientProps, props.get(WrapperProperties.KEEPALIVE_PROPERTIES.name)); + this.setConnectTimeout(finalClientProps, props.get(WrapperProperties.WRAPPER_CONNECT_TIMEOUT.name)); + this.setQueryTimeout(finalClientProps, undefined, props.get(WrapperProperties.WRAPPER_QUERY_TIMEOUT.name)); Object.assign(finalPoolConfig, Object.fromEntries(finalClientProps.entries())); + finalPoolConfig.max = poolConfig?.maxConnections; finalPoolConfig.idleTimeoutMillis = poolConfig?.idleTimeoutMillis; finalPoolConfig.allowExitOnIdle = poolConfig?.allowExitOnIdle; diff --git a/pg/lib/dialect/pg_database_dialect.ts b/pg/lib/dialect/pg_database_dialect.ts index 3b03ef24a..0afa54f29 100644 --- a/pg/lib/dialect/pg_database_dialect.ts +++ b/pg/lib/dialect/pg_database_dialect.ts @@ -25,6 +25,7 @@ import { FailoverRestriction } from "../../../common/lib/plugins/failover/failov import { ErrorHandler } from "../../../common/lib/error_handler"; import { PgErrorHandler } from "../pg_error_handler"; import { Messages } from "../../../common/lib/utils/messages"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class PgDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -105,8 +106,8 @@ export class PgDatabaseDialect implements DatabaseDialect { }); } - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), servicesContainer.hostListProviderService); } getErrorHandler(): ErrorHandler { diff --git a/pg/lib/dialect/rds_multi_az_pg_database_dialect.ts b/pg/lib/dialect/rds_multi_az_pg_database_dialect.ts index d3e47f94c..5b4b79b2a 100644 --- a/pg/lib/dialect/rds_multi_az_pg_database_dialect.ts +++ b/pg/lib/dialect/rds_multi_az_pg_database_dialect.ts @@ -14,7 +14,6 @@ limitations under the License. */ -import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { AwsWrapperError, HostRole } from "../../../common/lib"; @@ -24,10 +23,9 @@ import { RdsHostListProvider } from "../../../common/lib/host_list_provider/rds_ import { PgDatabaseDialect } from "./pg_database_dialect"; import { ErrorHandler } from "../../../common/lib/error_handler"; import { MultiAzPgErrorHandler } from "../multi_az_pg_error_handler"; -import { WrapperProperties } from "../../../common/lib/wrapper_property"; -import { PluginService } from "../../../common/lib/plugin_service"; -import { MonitoringRdsHostListProvider } from "../../../common/lib/host_list_provider/monitoring/monitoring_host_list_provider"; -import { TopologyQueryResult, TopologyUtils } from "../../../common/lib/host_list_provider/topology_utils"; +import { TopologyQueryResult } from "../../../common/lib/host_list_provider/topology_utils"; +import { AuroraTopologyUtils } from "../../../common/lib/host_list_provider/aurora_topology_utils"; +import { FullServicesContainer } from "../../../common/lib/utils/full_services_container"; export class RdsMultiAZClusterPgDatabaseDialect extends PgDatabaseDialect implements TopologyAwareDatabaseDialect { constructor() { @@ -64,18 +62,9 @@ export class RdsMultiAZClusterPgDatabaseDialect extends PgDatabaseDialect implem } } - getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider { - const topologyUtils: TopologyUtils = new TopologyUtils(this, hostListProviderService.getHostInfoBuilder()); - if (WrapperProperties.PLUGINS.get(props).includes("failover2")) { - return new MonitoringRdsHostListProvider( - props, - originalUrl, - topologyUtils, - hostListProviderService, - (hostListProviderService) - ); - } - return new RdsHostListProvider(props, originalUrl, topologyUtils, hostListProviderService); + getHostListProvider(props: Map, originalUrl: string, servicesContainer: FullServicesContainer): HostListProvider { + const topologyUtils = new AuroraTopologyUtils(this, servicesContainer.hostListProviderService.getHostInfoBuilder()); + return new RdsHostListProvider(props, originalUrl, topologyUtils, servicesContainer); } async queryForTopology(targetClient: ClientWrapper): Promise { diff --git a/pg/lib/icp/pg_internal_pool_client.ts b/pg/lib/icp/pg_internal_pool_client.ts index ed0324db1..b01045de0 100644 --- a/pg/lib/icp/pg_internal_pool_client.ts +++ b/pg/lib/icp/pg_internal_pool_client.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/pg/lib/pg_client.ts b/pg/lib/pg_client.ts index 468eb1038..1105f1e47 100644 --- a/pg/lib/pg_client.ts +++ b/pg/lib/pg_client.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,15 +14,7 @@ limitations under the License. */ -import { - QueryArrayConfig, - QueryArrayResult, - QueryConfig, - QueryConfigValues, - QueryResult, - QueryResultRow, - Submittable -} from "pg"; +import { QueryArrayConfig, QueryArrayResult, QueryConfig, QueryConfigValues, QueryResult, QueryResultRow, Submittable } from "pg"; import { AwsPGPooledConnection } from "./client"; export interface PGClient { diff --git a/tests/integration/container/tests/aurora_failover.test.ts b/tests/integration/container/tests/aurora_failover.test.ts deleted file mode 100644 index 589fb368f..000000000 --- a/tests/integration/container/tests/aurora_failover.test.ts +++ /dev/null @@ -1,321 +0,0 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -import { TestEnvironment } from "./utils/test_environment"; -import { DriverHelper } from "./utils/driver_helper"; -import { AuroraTestUtility } from "./utils/aurora_test_utility"; -import { - FailoverSuccessError, - PluginManager, - TransactionIsolationLevel, - TransactionResolutionUnknownError -} from "../../../../index"; -import { DatabaseEngine } from "./utils/database_engine"; -import { QueryResult } from "pg"; -import { ProxyHelper } from "./utils/proxy_helper"; -import { logger } from "../../../../common/logutils"; -import { features, instanceCount } from "./config"; -import { TestEnvironmentFeatures } from "./utils/test_environment_features"; -import { RdsUtils } from "../../../../common/lib/utils/rds_utils"; - -const itIf = - features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && - !features.includes(TestEnvironmentFeatures.PERFORMANCE) && - !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) && - instanceCount >= 2 - ? it - : it.skip; -const itIfTwoInstance = instanceCount == 2 ? itIf : it.skip; -const itIfThreeInstanceAuroraCluster = - instanceCount == 3 && !features.includes(TestEnvironmentFeatures.RDS_MULTI_AZ_SUPPORTED) ? it : it.skip; - -let env: TestEnvironment; -let driver; -let client: any; -let secondaryClient: any; -let initClientFunc: (props: any) => any; - -let auroraTestUtility: AuroraTestUtility; - -async function initDefaultConfig(host: string, port: number, connectToProxy: boolean): Promise { - let config: any = { - user: env.databaseInfo.username, - host: host, - database: env.databaseInfo.defaultDbName, - password: env.databaseInfo.password, - port: port, - plugins: "failover", - failoverTimeoutMs: 250000, - enableTelemetry: true, - telemetryTracesBackend: "OTLP", - telemetryMetricsBackend: "OTLP" - }; - if (connectToProxy) { - config["clusterInstanceHostPattern"] = "?." + env.proxyDatabaseInfo.instanceEndpointSuffix; - } - config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); - return config; -} - -async function initConfigWithEFM2(host: string, port: number, connectToProxy: boolean): Promise { - const config: any = await initDefaultConfig(host, port, connectToProxy); - config["plugins"] = "failover,efm2"; - config["failoverTimeoutMs"] = 20000; - config["failureDetectionCount"] = 2; - config["failureDetectionInterval"] = 1000; - config["failureDetectionTime"] = 2000; - config["connectTimeout"] = 10000; - config["wrapperQueryTimeout"] = 20000; - config["monitoring_wrapperQueryTimeout"] = 3000; - config["monitoring_wrapperConnectTimeout"] = 3000; - return config; -} - -describe("aurora failover", () => { - beforeEach(async () => { - logger.info(`Test started: ${expect.getState().currentTestName}`); - env = await TestEnvironment.getCurrent(); - - auroraTestUtility = new AuroraTestUtility(env.region); - driver = DriverHelper.getDriverForDatabaseEngine(env.engine); - initClientFunc = DriverHelper.getClient(driver); - await ProxyHelper.enableAllConnectivity(); - await TestEnvironment.verifyClusterStatus(); - - client = null; - secondaryClient = null; - }, 1320000); - - afterEach(async () => { - if (client !== null) { - try { - await client.end(); - } catch (error) { - // pass - } - } - - if (secondaryClient !== null) { - try { - await secondaryClient.end(); - } catch (error) { - // pass - } - } - await PluginManager.releaseResources(); - logger.info(`Test finished: ${expect.getState().currentTestName}`); - }, 1320000); - - itIfThreeInstanceAuroraCluster( - "writer failover efm", - async () => { - // Connect to writer instance. - const writerConfig = await initDefaultConfig(env.proxyDatabaseInfo.writerInstanceEndpoint, env.proxyDatabaseInfo.instanceEndpointPort, true); - writerConfig["failoverMode"] = "reader-or-writer"; - - client = initClientFunc(writerConfig); - await client.connect(); - - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - const instances = env.databaseInfo.instances; - const readerInstance = instances[1].instanceId; - await ProxyHelper.disableAllConnectivity(env.engine); - - try { - await ProxyHelper.enableConnectivity(initialWriterId); - - // Sleep query activates monitoring connection after monitoring_wrapperQueryTimeout time is reached. - await auroraTestUtility.queryInstanceIdWithSleep(client); - - await ProxyHelper.enableConnectivity(readerInstance); - await ProxyHelper.disableConnectivity(env.engine, initialWriterId); - } catch (error) { - fail("The disable connectivity task was unexpectedly interrupted."); - } - // Failure occurs on connection invocation. - await expect(async () => { - await auroraTestUtility.queryInstanceId(client); - }).rejects.toThrow(FailoverSuccessError); - - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(false); - expect(currentConnectionId).not.toBe(initialWriterId); - }, - 1320000 - ); - - itIf( - "fails from writer to new writer on connection invocation", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - // Crash instance 1 and nominate a new writer - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(client); - }).rejects.toThrow(FailoverSuccessError); - - // Assert that we are connected to the new writer after failover happens - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).not.toBe(initialWriterId); - }, - 1320000 - ); - - itIf( - "writer fails within transaction", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); - await DriverHelper.executeQuery(env.engine, client, "CREATE TABLE test3_3 (id int not null primary key, test3_3_field varchar(255) not null)"); - - await DriverHelper.executeQuery(env.engine, client, "START TRANSACTION"); // start transaction - await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (1, 'test field string 1')"); - - // Crash instance 1 and nominate a new writer - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (2, 'test field string 2')"); - }).rejects.toThrow(TransactionResolutionUnknownError); - - // Attempt to query the instance id. - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - - // Assert that we are connected to the new writer after failover happens. - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - - const nextClusterWriterId = await auroraTestUtility.getClusterWriterInstanceId(); - expect(currentConnectionId).toBe(nextClusterWriterId); - expect(initialWriterId).not.toBe(nextClusterWriterId); - - // Assert that NO row has been inserted to the table. - const result = await DriverHelper.executeQuery(env.engine, client, "SELECT count(*) from test3_3"); - if (env.engine === DatabaseEngine.PG) { - expect((result as QueryResult).rows[0]["count"]).toBe("0"); - } else if (env.engine === DatabaseEngine.MYSQL) { - expect(JSON.parse(JSON.stringify(result))[0][0]["count(*)"]).toBe(0); - } - - await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); - }, - 2000000 - ); - - itIf( - "fails from writer and transfers session state", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toBe(true); - - await client.setReadOnly(true); - await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - - if (driver === DatabaseEngine.PG) { - await client.setSchema(env.databaseInfo.defaultDbName); - } else if (driver === DatabaseEngine.MYSQL) { - await client.setAutoCommit(false); - await client.setCatalog(env.databaseInfo.defaultDbName); - } - - // Failover cluster and nominate a new writer - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(client); - }).rejects.toThrow(FailoverSuccessError); - - // Assert that we are connected to the new writer after failover happens - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).not.toBe(initialWriterId); - expect(client.isReadOnly()).toBe(true); - expect(client.getTransactionIsolation()).toBe(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - if (driver === DatabaseEngine.PG) { - expect(client.getSchema()).toBe(env.databaseInfo.defaultDbName); - } else if (driver === DatabaseEngine.MYSQL) { - expect(client.getAutoCommit()).toBe(false); - expect(client.getCatalog()).toBe(env.databaseInfo.defaultDbName); - } - }, - 1320000 - ); - - itIfTwoInstance( - "fails from reader to writer", - async () => { - // Connect to writer instance - const writerConfig = await initDefaultConfig(env.proxyDatabaseInfo.writerInstanceEndpoint, env.proxyDatabaseInfo.instanceEndpointPort, true); - client = initClientFunc(writerConfig); - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - // Get a reader instance - let readerInstanceHost; - for (const host of env.proxyDatabaseInfo.instances) { - if (host.instanceId && host.instanceId !== initialWriterId) { - readerInstanceHost = host.host; - } - } - if (!readerInstanceHost) { - throw new Error("Could not find a reader instance"); - } - const readerConfig = await initDefaultConfig(readerInstanceHost, env.proxyDatabaseInfo.instanceEndpointPort, true); - - secondaryClient = initClientFunc(readerConfig); - await secondaryClient.connect(); - - // Crash the reader instance - const rdsUtils = new RdsUtils(); - const readerInstanceId = rdsUtils.getRdsInstanceId(readerInstanceHost); - if (readerInstanceId) { - await ProxyHelper.disableConnectivity(env.engine, readerInstanceId); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(secondaryClient); - }).rejects.toThrow(FailoverSuccessError); - - await ProxyHelper.enableConnectivity(readerInstanceId); - - // Assert that we are currently connected to the writer instance - const currentConnectionId = await auroraTestUtility.queryInstanceId(secondaryClient); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).toBe(initialWriterId); - } - }, - 1320000 - ); -}); diff --git a/tests/integration/container/tests/aurora_failover2.test.ts b/tests/integration/container/tests/aurora_failover2.test.ts deleted file mode 100644 index 62e3c740c..000000000 --- a/tests/integration/container/tests/aurora_failover2.test.ts +++ /dev/null @@ -1,264 +0,0 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -import { TestEnvironment } from "./utils/test_environment"; -import { DriverHelper } from "./utils/driver_helper"; -import { AuroraTestUtility } from "./utils/aurora_test_utility"; -import { - FailoverSuccessError, - PluginManager, - TransactionIsolationLevel, - TransactionResolutionUnknownError -} from "../../../../index"; -import { DatabaseEngine } from "./utils/database_engine"; -import { QueryResult } from "pg"; -import { ProxyHelper } from "./utils/proxy_helper"; -import { logger } from "../../../../common/logutils"; -import { features, instanceCount } from "./config"; -import { TestEnvironmentFeatures } from "./utils/test_environment_features"; -import { RdsUtils } from "../../../../common/lib/utils/rds_utils"; - -const itIf = - features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && - !features.includes(TestEnvironmentFeatures.PERFORMANCE) && - !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) && - instanceCount >= 2 - ? it - : it.skip; -const itIfTwoInstance = instanceCount == 2 ? itIf : it.skip; - -let env: TestEnvironment; -let driver; -let client: any; -let secondaryClient: any; -let initClientFunc: (props: any) => any; - -let auroraTestUtility: AuroraTestUtility; - -async function initDefaultConfig(host: string, port: number, connectToProxy: boolean): Promise { - let config: any = { - user: env.databaseInfo.username, - host: host, - database: env.databaseInfo.defaultDbName, - password: env.databaseInfo.password, - port: port, - plugins: "failover2", - failoverTimeoutMs: 250000, - enableTelemetry: true, - telemetryTracesBackend: "OTLP", - telemetryMetricsBackend: "OTLP" - }; - if (connectToProxy) { - config["clusterInstanceHostPattern"] = "?." + env.proxyDatabaseInfo.instanceEndpointSuffix; - } - config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); - return config; -} - -describe("aurora failover2", () => { - beforeEach(async () => { - logger.info(`Test started: ${expect.getState().currentTestName}`); - env = await TestEnvironment.getCurrent(); - - auroraTestUtility = new AuroraTestUtility(env.region); - driver = DriverHelper.getDriverForDatabaseEngine(env.engine); - initClientFunc = DriverHelper.getClient(driver); - await ProxyHelper.enableAllConnectivity(); - await TestEnvironment.verifyClusterStatus(); - - client = null; - secondaryClient = null; - }, 1320000); - - afterEach(async () => { - if (client !== null) { - try { - await client.end(); - } catch (error) { - // pass - } - } - - if (secondaryClient !== null) { - try { - await secondaryClient.end(); - } catch (error) { - // pass - } - } - await PluginManager.releaseResources(); - logger.info(`Test finished: ${expect.getState().currentTestName}`); - }, 1320000); - - itIf( - "fails from writer to new writer on connection invocation", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - // Crash instance 1 and nominate a new writer. - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(client); - }).rejects.toThrow(FailoverSuccessError); - - // Assert that we are connected to the new writer after failover happens. - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).not.toBe(initialWriterId); - }, - 1320000 - ); - - itIf( - "writer fails within transaction", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); - await DriverHelper.executeQuery(env.engine, client, "CREATE TABLE test3_3 (id int not null primary key, test3_3_field varchar(255) not null)"); - - await DriverHelper.executeQuery(env.engine, client, "START TRANSACTION"); // start transaction - await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (1, 'test field string 1')"); - - // Crash instance 1 and nominate a new writer. - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (2, 'test field string 2')"); - }).rejects.toThrow(TransactionResolutionUnknownError); - - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - // Assert that we are connected to the new writer after failover happens. - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - - const nextClusterWriterId = await auroraTestUtility.getClusterWriterInstanceId(); - expect(currentConnectionId).toBe(nextClusterWriterId); - expect(initialWriterId).not.toBe(nextClusterWriterId); - - // Assert that NO row has been inserted to the table. - const result = await DriverHelper.executeQuery(env.engine, client, "SELECT count(*) from test3_3"); - if (env.engine === DatabaseEngine.PG) { - expect((result as QueryResult).rows[0]["count"]).toBe("0"); - } else if (env.engine === DatabaseEngine.MYSQL) { - expect(JSON.parse(JSON.stringify(result))[0][0]["count(*)"]).toBe(0); - } - - await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); - }, - 2000000 - ); - - itIf( - "fails from writer and transfers session state", - async () => { - const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); - client = initClientFunc(config); - - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toBe(true); - - await client.setReadOnly(true); - await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - - if (driver === DatabaseEngine.PG) { - await client.setSchema(env.databaseInfo.defaultDbName); - } else if (driver === DatabaseEngine.MYSQL) { - await client.setAutoCommit(false); - await client.setCatalog(env.databaseInfo.defaultDbName); - } - - // Failover cluster and nominate a new writer. - await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(client); - }).rejects.toThrow(FailoverSuccessError); - - // Assert that we are connected to the new writer after failover happens. - const currentConnectionId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).not.toBe(initialWriterId); - expect(client.isReadOnly()).toBe(true); - expect(client.getTransactionIsolation()).toBe(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - if (driver === DatabaseEngine.PG) { - expect(client.getSchema()).toBe(env.databaseInfo.defaultDbName); - } else if (driver === DatabaseEngine.MYSQL) { - expect(client.getAutoCommit()).toBe(false); - expect(client.getCatalog()).toBe(env.databaseInfo.defaultDbName); - } - }, - 1320000 - ); - - itIfTwoInstance( - "fails from reader to writer", - async () => { - // Connect to writer instance. - const writerConfig = await initDefaultConfig(env.proxyDatabaseInfo.writerInstanceEndpoint, env.proxyDatabaseInfo.instanceEndpointPort, true); - client = initClientFunc(writerConfig); - await client.connect(); - const initialWriterId = await auroraTestUtility.queryInstanceId(client); - expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); - - // Get a reader instance. - let readerInstanceHost; - for (const host of env.proxyDatabaseInfo.instances) { - if (host.instanceId && host.instanceId !== initialWriterId) { - readerInstanceHost = host.host; - } - } - if (!readerInstanceHost) { - throw new Error("Could not find a reader instance"); - } - const readerConfig = await initDefaultConfig(readerInstanceHost, env.proxyDatabaseInfo.instanceEndpointPort, true); - - secondaryClient = initClientFunc(readerConfig); - await secondaryClient.connect(); - - // Crash the reader instance. - const rdsUtils = new RdsUtils(); - const readerInstanceId = rdsUtils.getRdsInstanceId(readerInstanceHost); - if (readerInstanceId) { - await ProxyHelper.disableConnectivity(env.engine, readerInstanceId); - - await expect(async () => { - await auroraTestUtility.queryInstanceId(secondaryClient); - }).rejects.toThrow(FailoverSuccessError); - - await ProxyHelper.enableConnectivity(readerInstanceId); - - // Assert that we are currently connected to the writer instance. - const currentConnectionId = await auroraTestUtility.queryInstanceId(secondaryClient); - expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); - expect(currentConnectionId).toBe(initialWriterId); - } - }, - 1320000 - ); -}); diff --git a/tests/integration/container/tests/autoscaling.test.ts b/tests/integration/container/tests/autoscaling.test.ts index f29223483..cd0d166dc 100644 --- a/tests/integration/container/tests/autoscaling.test.ts +++ b/tests/integration/container/tests/autoscaling.test.ts @@ -20,12 +20,7 @@ import { AuroraTestUtility } from "./utils/aurora_test_utility"; import { logger } from "../../../../common/logutils"; import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features, instanceCount } from "./config"; -import { - AwsPoolConfig, - FailoverSuccessError, - InternalPooledConnectionProvider, - PluginManager, -} from "../../../../index"; +import { AwsPoolConfig, FailoverSuccessError, InternalPooledConnectionProvider, PluginManager } from "../../../../index"; import { TestInstanceInfo } from "./utils/test_instance_info"; import { sleep } from "../../../../common/lib/utils/utils"; diff --git a/tests/integration/container/tests/connect_execute_time_plugin.test.ts b/tests/integration/container/tests/connect_execute_time_plugin.test.ts index 3ff336add..b8091d264 100644 --- a/tests/integration/container/tests/connect_execute_time_plugin.test.ts +++ b/tests/integration/container/tests/connect_execute_time_plugin.test.ts @@ -70,8 +70,7 @@ describe("aurora connect and execute time plugin", () => { await TestEnvironment.verifyAllInstancesHasRightState("available"); await TestEnvironment.verifyAllInstancesUp(); - RdsHostListProvider.clearAll(); - PluginServiceImpl.clearHostAvailabilityCache(); + await PluginManager.releaseResources(); }, 1320000); afterEach(async () => { diff --git a/tests/integration/container/tests/failover/aurora_failover.test.ts b/tests/integration/container/tests/failover/aurora_failover.test.ts new file mode 100644 index 000000000..7d91ee451 --- /dev/null +++ b/tests/integration/container/tests/failover/aurora_failover.test.ts @@ -0,0 +1,125 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { TestEnvironment } from "../utils/test_environment"; +import { DriverHelper } from "../utils/driver_helper"; +import { AuroraTestUtility } from "../utils/aurora_test_utility"; +import { FailoverSuccessError, PluginManager } from "../../../../../index"; +import { ProxyHelper } from "../utils/proxy_helper"; +import { logger } from "../../../../../common/logutils"; +import { features, instanceCount } from "../config"; +import { TestEnvironmentFeatures } from "../utils/test_environment_features"; +import { createFailoverTests } from "./failover_tests"; + +const itIfThreeInstanceAuroraCluster = instanceCount == 3 && !features.includes(TestEnvironmentFeatures.RDS_MULTI_AZ_SUPPORTED) ? it : it.skip; + +describe("aurora failover", createFailoverTests({ plugins: "failover" })); + +describe("aurora failover - efm specific", () => { + let env: TestEnvironment; + let client: any; + let initClientFunc: (props: any) => any; + let auroraTestUtility: AuroraTestUtility; + + async function initConfigWithEFM2(host: string, port: number, connectToProxy: boolean): Promise { + let config: any = { + user: env.databaseInfo.username, + host: host, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: port, + plugins: "failover,efm2", + failoverTimeoutMs: 20000, + failureDetectionCount: 2, + failureDetectionInterval: 1000, + failureDetectionTime: 2000, + connectTimeout: 10000, + wrapperQueryTimeout: 20000, + monitoring_wrapperQueryTimeout: 3000, + monitoring_wrapperConnectTimeout: 3000, + enableTelemetry: true, + telemetryTracesBackend: "OTLP", + telemetryMetricsBackend: "OTLP" + }; + if (connectToProxy) { + config["clusterInstanceHostPattern"] = "?." + env.proxyDatabaseInfo.instanceEndpointSuffix; + } + config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); + return config; + } + + beforeEach(async () => { + logger.info(`Test started: ${expect.getState().currentTestName}`); + env = await TestEnvironment.getCurrent(); + auroraTestUtility = new AuroraTestUtility(env.region); + const driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + initClientFunc = DriverHelper.getClient(driver); + await ProxyHelper.enableAllConnectivity(); + await TestEnvironment.verifyClusterStatus(); + client = null; + }, 1320000); + + afterEach(async () => { + if (client !== null) { + try { + await client.end(); + } catch (error) { + // pass + } + } + await PluginManager.releaseResources(); + logger.info(`Test finished: ${expect.getState().currentTestName}`); + }, 1320000); + + itIfThreeInstanceAuroraCluster( + "writer failover efm", + async () => { + // Connect to writer instance + const writerConfig = await initConfigWithEFM2(env.proxyDatabaseInfo.writerInstanceEndpoint, env.proxyDatabaseInfo.instanceEndpointPort, true); + writerConfig["failoverMode"] = "reader-or-writer"; + + client = initClientFunc(writerConfig); + await client.connect(); + + const initialWriterId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); + const instances = env.databaseInfo.instances; + const readerInstance = instances[1].instanceId; + await ProxyHelper.disableAllConnectivity(env.engine); + + try { + await ProxyHelper.enableConnectivity(initialWriterId); + + // Sleep query activates monitoring connection after monitoring_wrapperQueryTimeout time is reached + await auroraTestUtility.queryInstanceIdWithSleep(client); + + await ProxyHelper.enableConnectivity(readerInstance); + await ProxyHelper.disableConnectivity(env.engine, initialWriterId); + } catch (error) { + fail("The disable connectivity task was unexpectedly interrupted."); + } + // Failure occurs on connection invocation + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + + const currentConnectionId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(false); + expect(currentConnectionId).not.toBe(initialWriterId); + }, + 1320000 + ); +}); diff --git a/tests/integration/container/tests/failover/aurora_failover2.test.ts b/tests/integration/container/tests/failover/aurora_failover2.test.ts new file mode 100644 index 000000000..aab1564e0 --- /dev/null +++ b/tests/integration/container/tests/failover/aurora_failover2.test.ts @@ -0,0 +1,19 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { createFailoverTests } from "./failover_tests"; + +describe("aurora failover2", createFailoverTests({ plugins: "failover2" })); diff --git a/tests/integration/container/tests/failover/failover_tests.ts b/tests/integration/container/tests/failover/failover_tests.ts new file mode 100644 index 000000000..ce430464a --- /dev/null +++ b/tests/integration/container/tests/failover/failover_tests.ts @@ -0,0 +1,271 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { TestEnvironment } from "../utils/test_environment"; +import { DriverHelper } from "../utils/driver_helper"; +import { AuroraTestUtility } from "../utils/aurora_test_utility"; +import { FailoverSuccessError, PluginManager, TransactionIsolationLevel, TransactionResolutionUnknownError } from "../../../../../index"; +import { DatabaseEngine } from "../utils/database_engine"; +import { QueryResult } from "pg"; +import { ProxyHelper } from "../utils/proxy_helper"; +import { logger } from "../../../../../common/logutils"; +import { features, instanceCount } from "../config"; +import { TestEnvironmentFeatures } from "../utils/test_environment_features"; +import { RdsUtils } from "../../../../../common/lib/utils/rds_utils"; + +export interface FailoverTestOptions { + plugins: string; + getExtraConfig?: () => Record; +} + +export function createFailoverTests(options: FailoverTestOptions) { + const itIf = + features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && + !features.includes(TestEnvironmentFeatures.PERFORMANCE) && + !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) && + instanceCount >= 2 + ? it + : it.skip; + const itIfTwoInstance = instanceCount == 2 ? itIf : it.skip; + + return () => { + let env: TestEnvironment; + let driver: any; + let client: any; + let secondaryClient: any; + let initClientFunc: (props: any) => any; + let auroraTestUtility: AuroraTestUtility; + + async function initDefaultConfig(host: string, port: number, connectToProxy: boolean): Promise { + let config: any = { + user: env.databaseInfo.username, + host: host, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: port, + plugins: options.plugins, + failoverTimeoutMs: 250000, + enableTelemetry: true, + telemetryTracesBackend: "OTLP", + telemetryMetricsBackend: "OTLP", + ...options.getExtraConfig?.() + }; + if (connectToProxy) { + config["clusterInstanceHostPattern"] = "?." + env.proxyDatabaseInfo.instanceEndpointSuffix; + } + config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); + return config; + } + + beforeEach(async () => { + logger.info(`Test started: ${expect.getState().currentTestName}`); + env = await TestEnvironment.getCurrent(); + auroraTestUtility = new AuroraTestUtility(env.region); + driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + initClientFunc = DriverHelper.getClient(driver); + await ProxyHelper.enableAllConnectivity(); + await TestEnvironment.verifyClusterStatus(); + client = null; + secondaryClient = null; + }, 1320000); + + afterEach(async () => { + await ProxyHelper.enableAllConnectivity(); + + if (client !== null) { + try { + await client.end(); + } catch (error) { + // pass + } + } + if (secondaryClient !== null) { + try { + await secondaryClient.end(); + } catch (error) { + // pass + } + } + await PluginManager.releaseResources(); + logger.info(`Test finished: ${expect.getState().currentTestName}`); + }, 1320000); + + itIf( + "fails from writer to new writer on connection invocation", + async () => { + const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); + client = initClientFunc(config); + + await client.connect(); + + const initialWriterId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); + + // Crash instance 1 and nominate a new writer + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + + // Assert that we are connected to the new writer after failover happens + const currentConnectionId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); + expect(currentConnectionId).not.toBe(initialWriterId); + }, + 1320000 + ); + + itIf( + "writer fails within transaction", + async () => { + const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); + client = initClientFunc(config); + + await client.connect(); + const initialWriterId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); + + await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); + await DriverHelper.executeQuery( + env.engine, + client, + "CREATE TABLE test3_3 (id int not null primary key, test3_3_field varchar(255) not null)" + ); + + await DriverHelper.executeQuery(env.engine, client, "START TRANSACTION"); + await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (1, 'test field string 1')"); + + // Crash instance 1 and nominate a new writer + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); + + await expect(async () => { + await DriverHelper.executeQuery(env.engine, client, "INSERT INTO test3_3 VALUES (2, 'test field string 2')"); + }).rejects.toThrow(TransactionResolutionUnknownError); + + // Attempt to query the instance id + const currentConnectionId = await auroraTestUtility.queryInstanceId(client); + + // Assert that we are connected to the new writer after failover happens + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); + + const nextClusterWriterId = await auroraTestUtility.getClusterWriterInstanceId(); + expect(currentConnectionId).toBe(nextClusterWriterId); + expect(initialWriterId).not.toBe(nextClusterWriterId); + + // Assert that NO row has been inserted to the table + const result = await DriverHelper.executeQuery(env.engine, client, "SELECT count(*) from test3_3"); + if (env.engine === DatabaseEngine.PG) { + expect((result as QueryResult).rows[0]["count"]).toBe("0"); + } else if (env.engine === DatabaseEngine.MYSQL) { + expect(JSON.parse(JSON.stringify(result))[0][0]["count(*)"]).toBe(0); + } + + await DriverHelper.executeQuery(env.engine, client, "DROP TABLE IF EXISTS test3_3"); + }, + 2000000 + ); + + itIf( + "fails from writer and transfers session state", + async () => { + const config = await initDefaultConfig(env.databaseInfo.writerInstanceEndpoint, env.databaseInfo.instanceEndpointPort, false); + client = initClientFunc(config); + + await client.connect(); + const initialWriterId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toBe(true); + + await client.setReadOnly(true); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + + if (driver === DatabaseEngine.PG) { + await client.setSchema(env.databaseInfo.defaultDbName); + } else if (driver === DatabaseEngine.MYSQL) { + await client.setAutoCommit(false); + await client.setCatalog(env.databaseInfo.defaultDbName); + } + + // Failover cluster and nominate a new writer + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + + // Assert that we are connected to the new writer after failover happens + const currentConnectionId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); + expect(currentConnectionId).not.toBe(initialWriterId); + expect(client.isReadOnly()).toBe(true); + expect(client.getTransactionIsolation()).toBe(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + if (driver === DatabaseEngine.PG) { + expect(client.getSchema()).toBe(env.databaseInfo.defaultDbName); + } else if (driver === DatabaseEngine.MYSQL) { + expect(client.getAutoCommit()).toBe(false); + expect(client.getCatalog()).toBe(env.databaseInfo.defaultDbName); + } + }, + 1320000 + ); + + itIfTwoInstance( + "fails from reader to writer", + async () => { + // Connect to writer instance + const writerConfig = await initDefaultConfig(env.proxyDatabaseInfo.writerInstanceEndpoint, env.proxyDatabaseInfo.instanceEndpointPort, true); + client = initClientFunc(writerConfig); + await client.connect(); + const initialWriterId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); + + // Get a reader instance + let readerInstanceHost; + for (const host of env.proxyDatabaseInfo.instances) { + if (host.instanceId && host.instanceId !== initialWriterId) { + readerInstanceHost = host.host; + } + } + if (!readerInstanceHost) { + throw new Error("Could not find a reader instance"); + } + const readerConfig = await initDefaultConfig(readerInstanceHost, env.proxyDatabaseInfo.instanceEndpointPort, true); + + secondaryClient = initClientFunc(readerConfig); + await secondaryClient.connect(); + + // Crash the reader instance + const rdsUtils = new RdsUtils(); + const readerInstanceId = rdsUtils.getRdsInstanceId(readerInstanceHost); + if (readerInstanceId) { + await ProxyHelper.disableConnectivity(env.engine, readerInstanceId); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(secondaryClient); + }).rejects.toThrow(FailoverSuccessError); + + await ProxyHelper.enableConnectivity(readerInstanceId); + + // Assert that we are currently connected to the writer instance + const currentConnectionId = await auroraTestUtility.queryInstanceId(secondaryClient); + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); + expect(currentConnectionId).toBe(initialWriterId); + } + }, + 1320000 + ); + }; +} diff --git a/tests/integration/container/tests/failover/gdb_failover.test.ts b/tests/integration/container/tests/failover/gdb_failover.test.ts new file mode 100644 index 000000000..261d98c0b --- /dev/null +++ b/tests/integration/container/tests/failover/gdb_failover.test.ts @@ -0,0 +1,184 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { TestEnvironment } from "../utils/test_environment"; +import { DriverHelper } from "../utils/driver_helper"; +import { AuroraTestUtility } from "../utils/aurora_test_utility"; +import { FailoverSuccessError, PluginManager } from "../../../../../index"; +import { ProxyHelper } from "../utils/proxy_helper"; +import { logger } from "../../../../../common/logutils"; +import { features, instanceCount } from "../config"; +import { TestEnvironmentFeatures } from "../utils/test_environment_features"; +import { createFailoverTests } from "./failover_tests"; + +const itIf = + features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && + !features.includes(TestEnvironmentFeatures.PERFORMANCE) && + !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) && + instanceCount >= 2 + ? it + : it.skip; +const itIfNetworkOutages = features.includes(TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED) && instanceCount >= 2 ? itIf : it.skip; + +let env: TestEnvironment; +let driver: any; +let client: any; +let initClientFunc: (props: any) => any; + +let auroraTestUtility: AuroraTestUtility; + +async function initDefaultConfig(host: string, port: number, connectToProxy: boolean): Promise { + let config: any = { + user: env.databaseInfo.username, + host: host, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: port, + plugins: "gdbFailover", + failoverTimeoutMs: 250000, + activeHomeFailoverMode: "strict-writer", + inactiveHomeFailoverMode: "strict-writer", + enableTelemetry: true, + telemetryTracesBackend: "OTLP", + telemetryMetricsBackend: "OTLP" + }; + if (connectToProxy) { + config["clusterInstanceHostPattern"] = "?." + env.proxyDatabaseInfo.instanceEndpointSuffix; + } + config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); + return config; +} + +describe("gdb failover", () => { + // Inherit shared failover tests with GDB-specific configuration + // This mirrors the Java pattern where GdbFailoverTest extends FailoverTest + describe( + "failover tests", + createFailoverTests({ + plugins: "gdbFailover", + getExtraConfig: () => ({ + // These settings mimic failover/failover2 plugin logic when connecting to non-GDB Aurora or RDS DB clusters. + activeHomeFailoverMode: "strict-writer", + inactiveHomeFailoverMode: "strict-writer" + }) + }) + ); + + // GDB-specific tests (overrides from Java GdbFailoverTest) + describe("gdb-specific tests", () => { + beforeEach(async () => { + logger.info(`Test started: ${expect.getState().currentTestName}`); + env = await TestEnvironment.getCurrent(); + + auroraTestUtility = new AuroraTestUtility(env.region); + driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + initClientFunc = DriverHelper.getClient(driver); + await ProxyHelper.enableAllConnectivity(); + await TestEnvironment.verifyClusterStatus(); + + client = null; + await PluginManager.releaseResources(); + }, 1320000); + + afterEach(async () => { + await ProxyHelper.enableAllConnectivity(); + + if (client !== null) { + try { + await client.end(); + } catch (error) { + // pass + } + } + await PluginManager.releaseResources(); + logger.info(`Test finished: ${expect.getState().currentTestName}`); + }, 1320000); + + itIfNetworkOutages( + "reader failover with home-reader-or-writer mode", + async () => { + const initialWriterId = env.proxyDatabaseInfo.writerInstanceId; + const initialWriterHost = env.proxyDatabaseInfo.writerInstanceEndpoint; + const initialWriterPort = env.proxyDatabaseInfo.instanceEndpointPort; + + const config = await initDefaultConfig(initialWriterHost, initialWriterPort, true); + config["activeHomeFailoverMode"] = "home-reader-or-writer"; + config["inactiveHomeFailoverMode"] = "home-reader-or-writer"; + + client = initClientFunc(config); + await client.connect(); + + await ProxyHelper.disableConnectivity(env.engine, initialWriterId!); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + }, + 1320000 + ); + + itIfNetworkOutages( + "reader failover with strict-home-reader mode", + async () => { + const initialWriterId = env.proxyDatabaseInfo.writerInstanceId; + const initialWriterHost = env.proxyDatabaseInfo.writerInstanceEndpoint; + const initialWriterPort = env.proxyDatabaseInfo.instanceEndpointPort; + + const config = await initDefaultConfig(initialWriterHost, initialWriterPort, true); + config["activeHomeFailoverMode"] = "strict-home-reader"; + config["inactiveHomeFailoverMode"] = "strict-home-reader"; + + client = initClientFunc(config); + await client.connect(); + + await ProxyHelper.disableConnectivity(env.engine, initialWriterId!); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + + const currentConnectionId = await auroraTestUtility.queryInstanceId(client); + expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(false); + }, + 1320000 + ); + + itIfNetworkOutages( + "writer reelected with home-reader-or-writer mode", + async () => { + const initialWriterId = env.proxyDatabaseInfo.writerInstanceId; + const initialWriterHost = env.proxyDatabaseInfo.writerInstanceEndpoint; + const initialWriterPort = env.proxyDatabaseInfo.instanceEndpointPort; + + const config = await initDefaultConfig(initialWriterHost, initialWriterPort, true); + config["activeHomeFailoverMode"] = "home-reader-or-writer"; + config["inactiveHomeFailoverMode"] = "home-reader-or-writer"; + + client = initClientFunc(config); + await client.connect(); + + // Failover usually changes the writer instance, but we want to test re-election of the same writer, so we will + // simulate this by temporarily disabling connectivity to the writer. + await auroraTestUtility.simulateTemporaryFailure(initialWriterId!); + + await expect(async () => { + await auroraTestUtility.queryInstanceId(client); + }).rejects.toThrow(FailoverSuccessError); + }, + 1320000 + ); + }); +}); diff --git a/tests/integration/container/tests/fastest_response_strategy.test.ts b/tests/integration/container/tests/fastest_response_strategy.test.ts index ff44b9222..9e7cdb98c 100644 --- a/tests/integration/container/tests/fastest_response_strategy.test.ts +++ b/tests/integration/container/tests/fastest_response_strategy.test.ts @@ -89,8 +89,7 @@ describe("aurora fastest response strategy", () => { await TestEnvironment.verifyAllInstancesHasRightState("available"); await TestEnvironment.verifyAllInstancesUp(); - RdsHostListProvider.clearAll(); - PluginServiceImpl.clearHostAvailabilityCache(); + await PluginManager.releaseResources(); }, 1320000); afterEach(async () => { diff --git a/tests/integration/container/tests/initial_connection_strategy.test.ts b/tests/integration/container/tests/initial_connection_strategy.test.ts index c540c72ab..98ef8551f 100644 --- a/tests/integration/container/tests/initial_connection_strategy.test.ts +++ b/tests/integration/container/tests/initial_connection_strategy.test.ts @@ -73,8 +73,7 @@ describe("aurora initial connection strategy", () => { await TestEnvironment.verifyAllInstancesHasRightState("available"); await TestEnvironment.verifyAllInstancesUp(); - RdsHostListProvider.clearAll(); - PluginServiceImpl.clearHostAvailabilityCache(); + await PluginManager.releaseResources(); numReaders = env.databaseInfo.instances.length - 1; }, 1320000); diff --git a/tests/integration/container/tests/mysql_pool.test.ts b/tests/integration/container/tests/mysql_pool.test.ts index aa3108508..576c0e55e 100644 --- a/tests/integration/container/tests/mysql_pool.test.ts +++ b/tests/integration/container/tests/mysql_pool.test.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/tests/integration/container/tests/parameterized_queries.test.ts b/tests/integration/container/tests/parameterized_queries.test.ts index 3af976426..a8ceac145 100644 --- a/tests/integration/container/tests/parameterized_queries.test.ts +++ b/tests/integration/container/tests/parameterized_queries.test.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/tests/integration/container/tests/read_write_splitting.test.ts b/tests/integration/container/tests/read_write_splitting.test.ts index bb838105d..f3af23785 100644 --- a/tests/integration/container/tests/read_write_splitting.test.ts +++ b/tests/integration/container/tests/read_write_splitting.test.ts @@ -34,8 +34,6 @@ import { ProxyHelper } from "./utils/proxy_helper"; import { logger } from "../../../../common/logutils"; import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features, instanceCount } from "./config"; -import { RdsHostListProvider } from "../../../../common/lib/host_list_provider/rds_host_list_provider"; -import { PluginServiceImpl } from "../../../../common/lib/plugin_service"; const itIf = !features.includes(TestEnvironmentFeatures.PERFORMANCE) && @@ -64,6 +62,8 @@ async function initConfig(host: string, port: number, connectToProxy: boolean, p port: port, plugins: plugins, enableTelemetry: true, + wrapperQueryTimeout: 10000, + wrapperConnectTimeout: 3000, telemetryTracesBackend: "OTLP", telemetryMetricsBackend: "OTLP" }; @@ -107,8 +107,7 @@ describe("aurora read write splitting", () => { await TestEnvironment.verifyAllInstancesHasRightState("available"); await TestEnvironment.verifyAllInstancesUp(); - RdsHostListProvider.clearAll(); - PluginServiceImpl.clearHostAvailabilityCache(); + await PluginManager.releaseResources(); }, 1320000); afterEach(async () => { @@ -609,9 +608,9 @@ describe("aurora read write splitting", () => { connectionsSet.add(client); } } finally { - for (const connection of connectionsSet) { + connectionsSet.forEach(async (connection) => { await connection.end(); - } + }); } }, 1000000 @@ -663,9 +662,9 @@ describe("aurora read write splitting", () => { connectionsSet.add(client); } } finally { - for (const connection of connectionsSet) { + connectionsSet.forEach(async (connection) => { await connection.end(); - } + }); } }, 1000000 diff --git a/tests/integration/container/tests/read_write_splitting_performance.test.ts b/tests/integration/container/tests/read_write_splitting_performance.test.ts index 5cf6c95ce..8ce96bfb0 100644 --- a/tests/integration/container/tests/read_write_splitting_performance.test.ts +++ b/tests/integration/container/tests/read_write_splitting_performance.test.ts @@ -21,12 +21,7 @@ import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features, instanceCount } from "./config"; import { PerfStat } from "./utils/perf_stat"; import { PerfTestUtility } from "./utils/perf_util"; -import { - ConnectTimePlugin, - ExecuteTimePlugin, - InternalPooledConnectionProvider, - PluginManager -} from "../../../../index"; +import { ConnectTimePlugin, ExecuteTimePlugin, InternalPooledConnectionProvider, PluginManager } from "../../../../index"; import { TestDriver } from "./utils/test_driver"; const itIf = diff --git a/tests/integration/container/tests/session_state.test.ts b/tests/integration/container/tests/session_state.test.ts index 5f06d420a..015e8de9a 100644 --- a/tests/integration/container/tests/session_state.test.ts +++ b/tests/integration/container/tests/session_state.test.ts @@ -119,7 +119,7 @@ describe("session state", () => { expect(autoCommit[0][0].autocommit).toEqual(1); expect(transactionIsolation[0][0].level).toEqual("REPEATABLE-READ"); - await client.getPluginService().setCurrentClient(newClient.targetClient); + await client.pluginService.setCurrentClient(newClient.targetClient); expect(client.targetClient).not.toEqual(targetClient); expect(client.targetClient).toEqual(newTargetClient); @@ -154,7 +154,7 @@ describe("session state", () => { expect(schema.rows[0]["search_path"]).not.toEqual("testSessionState"); expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("read committed"); - await client.getPluginService().setCurrentClient(newClient.targetClient); + await client.pluginService.setCurrentClient(newClient.targetClient); expect(client.targetClient).not.toEqual(targetClient); expect(client.targetClient).toEqual(newTargetClient); diff --git a/tests/integration/container/tests/utils/aurora_test_utility.ts b/tests/integration/container/tests/utils/aurora_test_utility.ts index b0e0bad85..bdfb9c71e 100644 --- a/tests/integration/container/tests/utils/aurora_test_utility.ts +++ b/tests/integration/container/tests/utils/aurora_test_utility.ts @@ -42,6 +42,7 @@ import { TestInstanceInfo } from "./test_instance_info"; import { TestEnvironmentInfo } from "./test_environment_info"; import { DatabaseEngine } from "./database_engine"; import { DatabaseEngineDeployment } from "./database_engine_deployment"; +import { ProxyHelper } from "./proxy_helper"; const instanceClass: string = "db.r5.large"; @@ -492,4 +493,38 @@ export class AuroraTestUtility { logger.debug("switchoverBlueGreenDeployment request is sent."); } } + + async simulateTemporaryFailure(instanceName: string, delayMs: number = 0, failureDurationMs: number = 5000): Promise { + const env = await TestEnvironment.getCurrent(); + const deployment = env.deployment; + const clusterEndpoint = env.proxyDatabaseInfo.clusterEndpoint; + const clusterReadOnlyEndpoint = env.proxyDatabaseInfo.clusterReadOnlyEndpoint; + + (async () => { + try { + if (delayMs > 0) { + await sleep(delayMs); + } + + await ProxyHelper.disableConnectivity(env.engine, instanceName); + + if (deployment === DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER) { + await ProxyHelper.disableConnectivity(env.engine, clusterEndpoint); + await ProxyHelper.disableConnectivity(env.engine, clusterReadOnlyEndpoint); + } + + await sleep(failureDurationMs); + + await ProxyHelper.enableConnectivity(instanceName); + if (deployment === DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER) { + await ProxyHelper.enableConnectivity(clusterEndpoint); + await ProxyHelper.enableConnectivity(clusterReadOnlyEndpoint); + } + } catch (e: any) { + logger.error(`Error during simulateTemporaryFailure: ${e.message}`); + } + })(); + + await sleep(500); + } } diff --git a/tests/integration/container/tests/utils/test_environment.ts b/tests/integration/container/tests/utils/test_environment.ts index cc2d52e96..459229fa0 100644 --- a/tests/integration/container/tests/utils/test_environment.ts +++ b/tests/integration/container/tests/utils/test_environment.ts @@ -37,6 +37,7 @@ import { ATTR_SERVICE_NAME } from "@opentelemetry/semantic-conventions"; import { PeriodicExportingMetricReader } from "@opentelemetry/sdk-metrics"; import { OTLPMetricExporter } from "@opentelemetry/exporter-metrics-otlp-grpc"; import { logger } from "../../../../../common/logutils"; +import { RdsUtils } from "../../../../../common/lib/utils/rds_utils"; import pkgPg from "pg"; import { ConnectionOptions, createConnection } from "mysql2/promise"; import { readFileSync } from "fs"; @@ -238,6 +239,14 @@ export class TestEnvironment { await TestEnvironment.initProxies(env); } + // Helps to eliminate problem with proxied endpoints. + RdsUtils.setPrepareHostFunc((host: string) => { + if (host.endsWith(".proxied")) { + return host.substring(0, host.length - ".proxied".length); + } + return host; + }); + const contextManager = new AsyncHooksContextManager(); contextManager.enable(); context.setGlobalContextManager(contextManager); diff --git a/tests/plugin_benchmarks.ts b/tests/plugin_benchmarks.ts index d28f031a0..edcd56e76 100644 --- a/tests/plugin_benchmarks.ts +++ b/tests/plugin_benchmarks.ts @@ -25,8 +25,8 @@ import { SimpleHostAvailabilityStrategy } from "../common/lib/host_availability/ import { HostInfoBuilder } from "../common/lib"; import { PgClientWrapper } from "../common/lib/pg_client_wrapper"; import { NullTelemetryFactory } from "../common/lib/utils/telemetry/null_telemetry_factory"; -import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { ConnectionProviderManager } from "../common/lib/connection_provider_manager"; +import { FullServicesContainerImpl } from "../common/lib/utils/full_services_container"; const mockConnectionProvider = mock(); const mockPluginService = mock(PluginServiceImpl); @@ -47,8 +47,8 @@ when(mockPluginService.getCurrentClient()).thenReturn(mockClientWrapper.client); when(mockPluginService.getDriverDialect()).thenReturn(mockDialect); const connectionString = "my.domain.com"; -const pluginServiceManagerContainer = new PluginServiceManagerContainer(); -pluginServiceManagerContainer.pluginService = instance(mockPluginService); +const servicesContainer = mock(FullServicesContainerImpl); +when(servicesContainer.pluginService).thenReturn(instance(mockPluginService)); function getProps(plugins: string) { const props = new Map(); @@ -59,7 +59,7 @@ function getProps(plugins: string) { function getPluginManager(props: Map) { return new PluginManager( - pluginServiceManagerContainer, + servicesContainer, props, new ConnectionProviderManager(instance(mockConnectionProvider), null), new NullTelemetryFactory() diff --git a/tests/plugin_manager_benchmarks.ts b/tests/plugin_manager_benchmarks.ts index 1b84b5bcb..801a4f889 100644 --- a/tests/plugin_manager_benchmarks.ts +++ b/tests/plugin_manager_benchmarks.ts @@ -16,7 +16,6 @@ import { add, complete, configure, cycle, save, suite } from "benny"; import { ConnectionPlugin, ConnectionProvider, HostInfoBuilder, PluginManager } from "../common/lib"; -import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { instance, mock, when } from "ts-mockito"; import { SimpleHostAvailabilityStrategy } from "../common/lib/host_availability/simple_host_availability_strategy"; import { PluginServiceImpl } from "../common/lib/plugin_service"; @@ -32,6 +31,7 @@ import { ConnectionPluginFactory } from "../common/lib/plugin_factory"; import { DefaultPlugin } from "../common/lib/plugins/default_plugin"; import { AwsPGClient } from "../pg/lib"; import { ConfigurationProfileBuilder } from "../common/lib/profile/configuration_profile_builder"; +import { FullServicesContainerImpl } from "../common/lib/utils/full_services_container"; const mockConnectionProvider = mock(); const mockHostListProviderService = mock(); @@ -43,8 +43,8 @@ when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); when(mockPluginService.getDriverDialect()).thenReturn(new NodePostgresDriverDialect()); when(mockPluginService.getCurrentClient()).thenReturn(mockClient); -const pluginServiceManagerContainer = new PluginServiceManagerContainer(); -pluginServiceManagerContainer.pluginService = instance(mockPluginService); +const servicesContainer = mock(FullServicesContainerImpl); +when(servicesContainer.pluginService).thenReturn(instance(mockPluginService)); const propsWithNoPlugins = new Map(); const propsWithPlugins = new Map(); @@ -53,7 +53,7 @@ WrapperProperties.PLUGINS.set(propsWithNoPlugins, ""); function getPluginManagerWithPlugins() { return new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsWithPlugins, new ConnectionProviderManager(instance(mockConnectionProvider), null), new NullTelemetryFactory() @@ -62,7 +62,7 @@ function getPluginManagerWithPlugins() { function getPluginManagerWithNoPlugins() { return new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsWithNoPlugins, new ConnectionProviderManager(instance(mockConnectionProvider), null), new NullTelemetryFactory() diff --git a/tests/plugin_manager_telemetry_benchmarks.ts b/tests/plugin_manager_telemetry_benchmarks.ts index 8420c8eab..3a0f19bba 100644 --- a/tests/plugin_manager_telemetry_benchmarks.ts +++ b/tests/plugin_manager_telemetry_benchmarks.ts @@ -16,7 +16,6 @@ import { add, complete, configure, cycle, save, suite } from "benny"; import { ConnectionPlugin, ConnectionProvider, HostInfoBuilder, PluginManager } from "../common/lib"; -import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { instance, mock, when } from "ts-mockito"; import { SimpleHostAvailabilityStrategy } from "../common/lib/host_availability/simple_host_availability_strategy"; import { PluginService, PluginServiceImpl } from "../common/lib/plugin_service"; @@ -44,6 +43,7 @@ import { ConnectionPluginFactory } from "../common/lib/plugin_factory"; import { ConfigurationProfileBuilder } from "../common/lib/profile/configuration_profile_builder"; import { AwsPGClient } from "../pg/lib"; import { resourceFromAttributes } from "@opentelemetry/resources"; +import { FullServicesContainerImpl } from "../common/lib/utils/full_services_container"; const mockConnectionProvider = mock(); const mockHostListProviderService = mock(); @@ -55,8 +55,8 @@ when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); when(mockPluginService.getDriverDialect()).thenReturn(new NodePostgresDriverDialect()); when(mockPluginService.getCurrentClient()).thenReturn(mockClient); -const pluginServiceManagerContainer = new PluginServiceManagerContainer(); -pluginServiceManagerContainer.pluginService = instance(mockPluginService); +const servicesContainer = mock(FullServicesContainerImpl); +when(servicesContainer.pluginService).thenReturn(instance(mockPluginService)); const propsWithNoPlugins = new Map(); const propsWithPlugins = new Map(); @@ -80,7 +80,7 @@ async function createPlugins(numPlugins: number, pluginService: PluginService, c function getPluginManagerWithPlugins() { return new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsWithPlugins, new ConnectionProviderManager(instance(mockConnectionProvider), null), telemetryFactory @@ -89,7 +89,7 @@ function getPluginManagerWithPlugins() { function getPluginManagerWithNoPlugins() { return new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsWithNoPlugins, new ConnectionProviderManager(instance(mockConnectionProvider), null), telemetryFactory diff --git a/tests/plugin_telemetry_benchmarks.ts b/tests/plugin_telemetry_benchmarks.ts index 04d791e52..b6c391a58 100644 --- a/tests/plugin_telemetry_benchmarks.ts +++ b/tests/plugin_telemetry_benchmarks.ts @@ -17,7 +17,6 @@ import { anything, instance, mock, when } from "ts-mockito"; import { ConnectionProvider, HostInfoBuilder, PluginManager } from "../common/lib"; import { PluginServiceImpl } from "../common/lib/plugin_service"; -import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { WrapperProperties } from "../common/lib/wrapper_property"; import { add, complete, configure, cycle, save, suite } from "benny"; import { TestConnectionWrapper } from "./testplugin/test_connection_wrapper"; @@ -41,6 +40,7 @@ import { PgClientWrapper } from "../common/lib/pg_client_wrapper"; import { DriverDialect } from "../common/lib/driver_dialect/driver_dialect"; import { NodePostgresDriverDialect } from "../pg/lib/dialect/node_postgres_driver_dialect"; import { resourceFromAttributes } from "@opentelemetry/resources"; +import { FullServicesContainerImpl } from "../common/lib/utils/full_services_container"; const mockConnectionProvider = mock(); const mockPluginService = mock(PluginServiceImpl); @@ -60,8 +60,8 @@ when(mockPluginService.getCurrentClient()).thenReturn(mockClientWrapper.client); when(mockPluginService.getDriverDialect()).thenReturn(mockDialect); const connectionString = "my.domain.com"; -const pluginServiceManagerContainer = new PluginServiceManagerContainer(); -pluginServiceManagerContainer.pluginService = instance(mockPluginService); +const servicesContainer = mock(FullServicesContainerImpl); +when(servicesContainer.pluginService).thenReturn(instance(mockPluginService)); const propsExecute = new Map(); const propsReadWrite = new Map(); @@ -84,19 +84,19 @@ WrapperProperties.TELEMETRY_TRACES_BACKEND.set(propsReadWrite, "OTLP"); WrapperProperties.TELEMETRY_TRACES_BACKEND.set(props, "OTLP"); const pluginManagerExecute = new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsExecute, new ConnectionProviderManager(instance(mockConnectionProvider), null), telemetryFactory ); const pluginManagerReadWrite = new PluginManager( - pluginServiceManagerContainer, + servicesContainer, propsReadWrite, new ConnectionProviderManager(instance(mockConnectionProvider), null), telemetryFactory ); const pluginManager = new PluginManager( - pluginServiceManagerContainer, + servicesContainer, props, new ConnectionProviderManager(instance(mockConnectionProvider), null), new NullTelemetryFactory() diff --git a/tests/testplugin/benchmark_plugin_factory.ts b/tests/testplugin/benchmark_plugin_factory.ts index f1cd65e63..d4101cbaa 100644 --- a/tests/testplugin/benchmark_plugin_factory.ts +++ b/tests/testplugin/benchmark_plugin_factory.ts @@ -14,13 +14,14 @@ limitations under the License. */ -import { ConnectionPlugin, PluginService, AwsWrapperError } from "../../index"; +import { ConnectionPlugin, AwsWrapperError } from "../../index"; import { Messages } from "../../common/lib/utils/messages"; import { BenchmarkPlugin } from "./benchmark_plugin"; import { ConnectionPluginFactory } from "../../common/lib/plugin_factory"; +import { FullServicesContainer } from "../../common/lib/utils/full_services_container"; export class BenchmarkPluginFactory extends ConnectionPluginFactory { - async getInstance(pluginService: PluginService, properties: object): Promise { + async getInstance(servicesContainer: FullServicesContainer, properties: object): Promise { try { return new BenchmarkPlugin(); } catch (error: any) { diff --git a/tests/unit/aurora_connection_tracker.test.ts b/tests/unit/aurora_connection_tracker.test.ts index dfd7f3d87..d6aad0791 100644 --- a/tests/unit/aurora_connection_tracker.test.ts +++ b/tests/unit/aurora_connection_tracker.test.ts @@ -19,9 +19,7 @@ import { HostInfoBuilder } from "../../common/lib/host_info_builder"; import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; import { anything, instance, mock, reset, verify, when } from "ts-mockito"; -import { - AuroraConnectionTrackerPlugin -} from "../../common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin"; +import { AuroraConnectionTrackerPlugin } from "../../common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin"; import { OpenedConnectionTracker } from "../../common/lib/plugins/connection_tracker/opened_connection_tracker"; import { RdsUtils } from "../../common/lib/utils/rds_utils"; import { RdsUrlType } from "../../common/lib/utils/rds_url_type"; diff --git a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts index 47757b5e6..920057ec5 100644 --- a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts +++ b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts @@ -14,9 +14,7 @@ limitations under the License. */ -import { - AuroraInitialConnectionStrategyPlugin -} from "../../common/lib/plugins/aurora_initial_connection_strategy_plugin"; +import { AuroraInitialConnectionStrategyPlugin } from "../../common/lib/plugins/aurora_initial_connection_strategy_plugin"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; import { anything, instance, mock, reset, spy, verify, when } from "ts-mockito"; import { WrapperProperties } from "../../common/lib/wrapper_property"; @@ -111,7 +109,7 @@ describe("Aurora initial connection strategy plugin", () => { when(mockPluginService.connect(anything(), anything(), anything())).thenResolve(writerClient); expect(await plugin.connect(hostInfo, props, true, mockFunc)).toBe(writerClient); - verify(mockPluginService.forceRefreshHostList(writerClient)).never(); + verify(mockPluginService.forceRefreshHostList()).never(); }); it("test reader - not found", async () => { @@ -130,7 +128,7 @@ describe("Aurora initial connection strategy plugin", () => { when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(instance(mockReaderHostInfo)); expect(await plugin.connect(hostInfo, props, true, mockFunc)).toBe(readerClient); - verify(mockPluginService.forceRefreshHostList(readerClient)).never(); + verify(mockPluginService.forceRefreshHostList()).never(); }); it("test reader - resolves to writer", async () => { diff --git a/tests/unit/aws_secrets_manager_plugin.test.ts b/tests/unit/aws_secrets_manager_plugin.test.ts index dbae23c28..4aa1271e1 100644 --- a/tests/unit/aws_secrets_manager_plugin.test.ts +++ b/tests/unit/aws_secrets_manager_plugin.test.ts @@ -15,11 +15,7 @@ */ import { SecretsManagerClient, SecretsManagerServiceException } from "@aws-sdk/client-secrets-manager"; -import { - AwsSecretsManagerPlugin, - Secret, - SecretCacheKey -} from "../../common/lib/authentication/aws_secrets_manager_plugin"; +import { AwsSecretsManagerPlugin, Secret, SecretCacheKey } from "../../common/lib/authentication/aws_secrets_manager_plugin"; import { AwsClient } from "../../common/lib/aws_client"; import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; import { AwsWrapperError, HostInfo } from "../../common/lib"; diff --git a/tests/unit/batching_event_publisher.test.ts b/tests/unit/batching_event_publisher.test.ts new file mode 100644 index 000000000..bb079c5c4 --- /dev/null +++ b/tests/unit/batching_event_publisher.test.ts @@ -0,0 +1,141 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { BatchingEventPublisher } from "../../common/lib/utils/events/batching_event_publisher"; +import { DataAccessEvent } from "../../common/lib/utils/events/data_access_event"; +import { Event, EventSubscriber } from "../../common/lib/utils/events/event"; + +class TestableEventPublisher extends BatchingEventPublisher { + constructor() { + super(0); // Pass 0 to avoid starting the interval + } + + protected initPublishingInterval(_messageIntervalMs: number): void { + // Do nothing. + } + + get subscriberCount(): number { + return this.subscribersMap.size; + } + + get pendingEventCount(): number { + return this.pendingEvents.size; + } + + async triggerSendMessages(): Promise { + await this.sendMessages(); + } +} + +// A simple class to use as the dataClass in DataAccessEvent +class TestDataClass {} + +describe("BatchingEventPublisher", () => { + let publisher: TestableEventPublisher; + let mockSubscriber: EventSubscriber; + let processEventCalls: Event[]; + + beforeEach(() => { + publisher = new TestableEventPublisher(); + processEventCalls = []; + mockSubscriber = { + processEvent: async (event: Event) => { + processEventCalls.push(event); + } + }; + }); + + afterEach(() => { + publisher.releaseResources(); + }); + + it("should publish events to subscribers and deduplicate", async () => { + const eventSubscriptions = new Set([DataAccessEvent]); + + publisher.subscribe(mockSubscriber, eventSubscriptions); + publisher.subscribe(mockSubscriber, eventSubscriptions); + expect(publisher.subscriberCount).toBe(1); + + const event = new DataAccessEvent(TestDataClass, "key"); + publisher.publish(event); + publisher.publish(event); + + await publisher.triggerSendMessages(); + + expect(publisher.pendingEventCount).toBe(0); + + expect(processEventCalls.length).toBe(1); + expect(processEventCalls[0]).toBe(event); + + publisher.unsubscribe(mockSubscriber, eventSubscriptions); + publisher.publish(event); + await publisher.triggerSendMessages(); + + expect(publisher.pendingEventCount).toBe(0); + + expect(processEventCalls.length).toBe(1); + }); + + it("should deliver immediate events synchronously", () => { + const immediateEvent: Event = { + isImmediateDelivery: true + }; + + const eventSubscriptions = new Set([immediateEvent.constructor as new (...args: any[]) => Event]); + publisher.subscribe(mockSubscriber, eventSubscriptions); + + publisher.publish(immediateEvent); + + expect(processEventCalls.length).toBe(1); + expect(processEventCalls[0]).toBe(immediateEvent); + + expect(publisher.pendingEventCount).toBe(0); + }); + + it("should not deliver events to unsubscribed subscribers", async () => { + const eventSubscriptions = new Set([DataAccessEvent]); + + publisher.subscribe(mockSubscriber, eventSubscriptions); + publisher.unsubscribe(mockSubscriber, eventSubscriptions); + + const event = new DataAccessEvent(TestDataClass, "key"); + publisher.publish(event); + await publisher.triggerSendMessages(); + + expect(processEventCalls.length).toBe(0); + }); + + it("should handle multiple subscribers", async () => { + const processEventCalls2: Event[] = []; + const mockSubscriber2: EventSubscriber = { + processEvent: async (event: Event) => { + processEventCalls2.push(event); + } + }; + + const eventSubscriptions = new Set([DataAccessEvent]); + + publisher.subscribe(mockSubscriber, eventSubscriptions); + publisher.subscribe(mockSubscriber2, eventSubscriptions); + + const event = new DataAccessEvent(TestDataClass, "key"); + publisher.publish(event); + await publisher.triggerSendMessages(); + + expect(processEventCalls.length).toBe(1); + expect(processEventCalls2.length).toBe(1); + }); +}); diff --git a/tests/unit/connection_plugin_chain_builder.test.ts b/tests/unit/connection_plugin_chain_builder.test.ts index ee152a198..b9cb9c780 100644 --- a/tests/unit/connection_plugin_chain_builder.test.ts +++ b/tests/unit/connection_plugin_chain_builder.test.ts @@ -30,23 +30,33 @@ import { ConnectionProviderManager } from "../../common/lib/connection_provider_ import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; import { AbstractConnectionPlugin } from "../../common/lib/abstract_connection_plugin"; import { ConnectionPluginFactory } from "../../common/lib/plugin_factory"; +import { FullServicesContainer } from "../../common/lib/utils/full_services_container"; const mockPluginService: PluginServiceImpl = mock(PluginServiceImpl); const mockPluginServiceInstance: PluginService = instance(mockPluginService); const mockDefaultConnProvider: ConnectionProvider = mock(DriverConnectionProvider); const mockEffectiveConnProvider: ConnectionProvider = mock(DriverConnectionProvider); +const mockServicesContainer: FullServicesContainer = { + pluginService: mockPluginServiceInstance, + telemetryFactory: new NullTelemetryFactory() +} as unknown as FullServicesContainer; + describe("testConnectionPluginChainBuilder", () => { beforeAll(() => { when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); }); + afterEach(async () => { + await PluginManager.releaseResources(); + }); + it.each([["iam,staleDns,failover"], ["iam, staleDns, failover"]])("sort plugins", async (plugins) => { const props = new Map(); props.set(WrapperProperties.PLUGINS.name, plugins); const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, + mockServicesContainer, props, new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider), null @@ -65,7 +75,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.AUTO_SORT_PLUGIN_ORDER.name, false); const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, + mockServicesContainer, props, new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider), null @@ -84,7 +94,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.PLUGINS.name, "executeTime,connectTime,iam"); let result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, + mockServicesContainer, props, new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider), null @@ -99,7 +109,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.PLUGINS.name, "iam,executeTime,connectTime,failover"); result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, + mockServicesContainer, props, new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider), null @@ -120,7 +130,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.PLUGINS.name, "test"); const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, + mockServicesContainer, props, new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider), null @@ -133,8 +143,8 @@ describe("testConnectionPluginChainBuilder", () => { }); class TestPluginFactory extends ConnectionPluginFactory { - async getInstance(pluginService: PluginService, properties: Map): Promise { - return new TestPlugin(pluginService, properties); + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { + return new TestPlugin(servicesContainer.pluginService, properties); } } diff --git a/tests/unit/database_dialect.test.ts b/tests/unit/database_dialect.test.ts index 36b55a71f..b8ff34720 100644 --- a/tests/unit/database_dialect.test.ts +++ b/tests/unit/database_dialect.test.ts @@ -23,7 +23,6 @@ import { RdsPgDatabaseDialect } from "../../pg/lib/dialect/rds_pg_database_diale import { DatabaseDialect, DatabaseType } from "../../common/lib/database_dialect/database_dialect"; import { DatabaseDialectCodes } from "../../common/lib/database_dialect/database_dialect_codes"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; -import { PluginServiceManagerContainer } from "../../common/lib/plugin_service_manager_container"; import { AwsPGClient } from "../../pg/lib"; import { WrapperProperties } from "../../common/lib/wrapper_property"; import { HostInfoBuilder } from "../../common/lib/host_info_builder"; @@ -31,10 +30,18 @@ import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availabili import { ClientWrapper } from "../../common/lib/client_wrapper"; import { RdsMultiAZClusterMySQLDatabaseDialect } from "../../mysql/lib/dialect/rds_multi_az_mysql_database_dialect"; import { RdsMultiAZClusterPgDatabaseDialect } from "../../pg/lib/dialect/rds_multi_az_pg_database_dialect"; +import { GlobalAuroraMySQLDatabaseDialect } from "../../mysql/lib/dialect/global_aurora_mysql_database_dialect"; +import { GlobalAuroraPgDatabaseDialect } from "../../pg/lib/dialect/global_aurora_pg_database_dialect"; import { DatabaseDialectManager } from "../../common/lib/database_dialect/database_dialect_manager"; import { NodePostgresDriverDialect } from "../../pg/lib/dialect/node_postgres_driver_dialect"; import { mock } from "ts-mockito"; import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { StorageService } from "../../common/lib/utils/storage/storage_service"; +import { ConnectionProvider } from "../../common/lib"; +import { TelemetryFactory } from "../../common/lib/utils/telemetry/telemetry_factory"; +import { MonitorService } from "../../common/lib/utils/monitoring/monitor_service"; +import { FullServicesContainerImpl } from "../../common/lib/utils/full_services_container"; +import { EventPublisher } from "../../common/lib/utils/events/event"; const LOCALHOST = "localhost"; const RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; @@ -44,14 +51,16 @@ const mysqlDialects: Map = new Map([ [DatabaseDialectCodes.MYSQL, new MySQLDatabaseDialect()], [DatabaseDialectCodes.RDS_MYSQL, new RdsMySQLDatabaseDialect()], [DatabaseDialectCodes.AURORA_MYSQL, new AuroraMySQLDatabaseDialect()], - [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL, new RdsMultiAZClusterMySQLDatabaseDialect()] + [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL, new RdsMultiAZClusterMySQLDatabaseDialect()], + [DatabaseDialectCodes.GLOBAL_AURORA_MYSQL, new GlobalAuroraMySQLDatabaseDialect()] ]); const pgDialects: Map = new Map([ [DatabaseDialectCodes.PG, new PgDatabaseDialect()], [DatabaseDialectCodes.RDS_PG, new RdsPgDatabaseDialect()], [DatabaseDialectCodes.AURORA_PG, new AuroraPgDatabaseDialect()], - [DatabaseDialectCodes.RDS_MULTI_AZ_PG, new RdsMultiAZClusterPgDatabaseDialect()] + [DatabaseDialectCodes.RDS_MULTI_AZ_PG, new RdsMultiAZClusterPgDatabaseDialect()], + [DatabaseDialectCodes.GLOBAL_AURORA_PG, new GlobalAuroraPgDatabaseDialect()] ]); const MYSQL_QUERY = "SHOW VARIABLES LIKE 'version_comment'"; @@ -180,7 +189,13 @@ const expectedDialectMapping: Map = ne ] ]); -const pluginServiceManagerContainer = new PluginServiceManagerContainer(); +const fullServicesContainer = new FullServicesContainerImpl( + mock(), + mock(), + mock(), + mock(), + mock() +); const mockClient = new AwsPGClient({}); const mockDriverDialect = mock(NodePostgresDriverDialect); @@ -275,14 +290,9 @@ describe("test database dialects", () => { }).build(); const mockClientWrapper: ClientWrapper = new PgClientWrapper(mockTargetClient, currentHostInfo, new Map()); - const pluginService = new PluginServiceImpl( - pluginServiceManagerContainer, - mockClient, - databaseType, - expectedDialect!.dialects, - props, - mockDriverDialect - ); + const pluginService = new PluginServiceImpl(fullServicesContainer, mockClient, databaseType, expectedDialect!.dialects, props, mockDriverDialect); + fullServicesContainer.hostListProviderService = pluginService; + fullServicesContainer.pluginService = pluginService; await pluginService.updateDialect(mockClientWrapper); expect(pluginService.getDialect()).toBe(expectedDialectClass); }); diff --git a/tests/unit/exponential_backoff_host_availability_strategy.test.ts b/tests/unit/exponential_backoff_host_availability_strategy.test.ts index e10cda452..f0720d1f0 100644 --- a/tests/unit/exponential_backoff_host_availability_strategy.test.ts +++ b/tests/unit/exponential_backoff_host_availability_strategy.test.ts @@ -14,9 +14,7 @@ limitations under the License. */ -import { - ExponentialBackoffHostAvailabilityStrategy -} from "../../common/lib/host_availability/exponential_backoff_host_availability_strategy"; +import { ExponentialBackoffHostAvailabilityStrategy } from "../../common/lib/host_availability/exponential_backoff_host_availability_strategy"; import { HostAvailability, IllegalArgumentError } from "../../common/lib"; import { sleep } from "../../common/lib/utils/utils"; import { WrapperProperties } from "../../common/lib/wrapper_property"; diff --git a/tests/unit/failover2_plugin.test.ts b/tests/unit/failover2_plugin.test.ts index b36a2baf3..8db14ad22 100644 --- a/tests/unit/failover2_plugin.test.ts +++ b/tests/unit/failover2_plugin.test.ts @@ -40,6 +40,7 @@ import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; import { Failover2Plugin } from "../../common/lib/plugins/failover2/failover2_plugin"; +import { FullServicesContainer } from "../../common/lib/utils/full_services_container"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -57,7 +58,10 @@ const properties: Map = new Map(); let plugin: Failover2Plugin; function initializePlugin(mockPluginServiceInstance: PluginService): void { - plugin = new Failover2Plugin(mockPluginServiceInstance, properties, new RdsUtils()); + const mockContainer = { + pluginService: mockPluginServiceInstance + } as unknown as FullServicesContainer; + plugin = new Failover2Plugin(mockContainer, properties, new RdsUtils()); } describe("reader failover handler", () => { diff --git a/tests/unit/failover_plugin.test.ts b/tests/unit/failover_plugin.test.ts index 0cbbe34d1..627500902 100644 --- a/tests/unit/failover_plugin.test.ts +++ b/tests/unit/failover_plugin.test.ts @@ -17,7 +17,6 @@ import { AwsClient } from "../../common/lib/aws_client"; import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; import { HostInfoBuilder } from "../../common/lib/host_info_builder"; -import { RdsHostListProvider } from "../../common/lib/host_list_provider/rds_host_list_provider"; import { PluginService, PluginServiceImpl } from "../../common/lib/plugin_service"; import { FailoverMode } from "../../common/lib/plugins/failover/failover_mode"; import { FailoverPlugin } from "../../common/lib/plugins/failover/failover_plugin"; @@ -44,6 +43,7 @@ import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; import { HostChangeOptions } from "../../common/lib/host_change_options"; import { Messages } from "../../common/lib/utils/messages"; +import { RdsHostListProvider } from "../../common/lib/host_list_provider/rds_host_list_provider"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); diff --git a/tests/unit/iam_authentication_plugin.test.ts b/tests/unit/iam_authentication_plugin.test.ts index f2ee620fb..58d7ad7ba 100644 --- a/tests/unit/iam_authentication_plugin.test.ts +++ b/tests/unit/iam_authentication_plugin.test.ts @@ -110,6 +110,7 @@ describe("testIamAuth", () => { afterEach(() => { reset(spyIamAuthUtils); + PluginManager.releaseResources(); }); it("testPostgresConnectValidTokenInCache", async () => { diff --git a/tests/unit/internal_pool_connection_provider.test.ts b/tests/unit/internal_pool_connection_provider.test.ts index 5e274f82d..dc14622a9 100644 --- a/tests/unit/internal_pool_connection_provider.test.ts +++ b/tests/unit/internal_pool_connection_provider.test.ts @@ -15,14 +15,7 @@ */ import { AwsClient } from "../../common/lib/aws_client"; -import { - AwsPoolConfig, - HostInfo, - HostInfoBuilder, - HostRole, - InternalPooledConnectionProvider, - InternalPoolMapping -} from "../../common/lib"; +import { AwsPoolConfig, HostInfo, HostInfoBuilder, HostRole, InternalPooledConnectionProvider, InternalPoolMapping } from "../../common/lib"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; import { anything, instance, mock, reset, spy, when } from "ts-mockito"; import { HostListProviderService } from "../../common/lib/host_list_provider_service"; @@ -36,9 +29,7 @@ import { AwsMySQLClient } from "../../mysql/lib"; import { MySQLDatabaseDialect } from "../../mysql/lib/dialect/mysql_database_dialect"; import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; import { PoolClientWrapper } from "../../common/lib/pool_client_wrapper"; -import { - SlidingExpirationCacheWithCleanupTask -} from "../../common/lib/utils/sliding_expiration_cache_with_cleanup_task"; +import { SlidingExpirationCacheWithCleanupTask } from "../../common/lib/utils/sliding_expiration_cache_with_cleanup_task"; const user1 = "user1"; const user2 = "user2"; diff --git a/tests/unit/notification_pipeline.test.ts b/tests/unit/notification_pipeline.test.ts index cb1d315e1..0d4d95c5e 100644 --- a/tests/unit/notification_pipeline.test.ts +++ b/tests/unit/notification_pipeline.test.ts @@ -16,7 +16,6 @@ import { HostChangeOptions } from "../../common/lib/host_change_options"; import { OldConnectionSuggestionAction } from "../../common/lib/old_connection_suggestion_action"; -import { PluginServiceManagerContainer } from "../../common/lib/plugin_service_manager_container"; import { DefaultPlugin } from "../../common/lib/plugins/default_plugin"; import { instance, mock } from "ts-mockito"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; @@ -24,6 +23,7 @@ import { DriverConnectionProvider } from "../../common/lib/driver_connection_pro import { ConnectionProviderManager } from "../../common/lib/connection_provider_manager"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; import { PluginManager } from "../../common/lib"; +import { FullServicesContainer, FullServicesContainerImpl } from "../../common/lib/utils/full_services_container"; class TestPlugin extends DefaultPlugin { counter: number = 0; @@ -43,7 +43,7 @@ class TestPlugin extends DefaultPlugin { } } -const container: PluginServiceManagerContainer = new PluginServiceManagerContainer(); +const container: FullServicesContainer = mock(FullServicesContainerImpl); const props: Map = mock(Map); const hostListChanges: Map> = mock(Map>); const connectionChanges: Set = mock(Set); @@ -64,6 +64,10 @@ describe("notificationPipelineTest", () => { pluginManager["_plugins"] = [plugin]; }); + afterEach(async () => { + await PluginManager.releaseResources(); + }); + it("test_notifyConnectionChanged", async () => { const result: Set = await pluginManager.notifyConnectionChanged(connectionChanges, null); expect(plugin.counter).toBe(1); diff --git a/tests/unit/plugin_service.test.ts b/tests/unit/plugin_service.test.ts index 2d6b145bf..2b14f31cc 100644 --- a/tests/unit/plugin_service.test.ts +++ b/tests/unit/plugin_service.test.ts @@ -26,7 +26,7 @@ import { HostInfoBuilder } from "../../common/lib"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; import { DatabaseDialectCodes } from "../../common/lib/database_dialect/database_dialect_codes"; import { AwsClient } from "../../common/lib/aws_client"; -import { PluginServiceManagerContainer } from "../../common/lib/plugin_service_manager_container"; +import { FullServicesContainer, FullServicesContainerImpl } from "../../common/lib/utils/full_services_container"; import { AllowedAndBlockedHosts } from "../../common/lib/allowed_and_blocked_hosts"; import { DatabaseType } from "../../common/lib/database_dialect/database_dialect"; @@ -64,7 +64,7 @@ let pluginService: TestPluginService; describe("testCustomEndpoint", () => { beforeEach(() => { pluginService = new TestPluginService( - new PluginServiceManagerContainer(), + mock(FullServicesContainerImpl), mockAwsClient, DatabaseType.MYSQL, knownDialectsByCode, diff --git a/tests/unit/rds_host_list_provider.test.ts b/tests/unit/rds_host_list_provider.test.ts index 99a91d3df..6f675a8c8 100644 --- a/tests/unit/rds_host_list_provider.test.ts +++ b/tests/unit/rds_host_list_provider.test.ts @@ -14,11 +14,10 @@ limitations under the License. */ -import { RdsHostListProvider } from "../../common/lib/host_list_provider/rds_host_list_provider"; import { anything, instance, mock, reset, spy, verify, when } from "ts-mockito"; import { PluginServiceImpl } from "../../common/lib/plugin_service"; import { AwsClient } from "../../common/lib/aws_client"; -import { AwsWrapperError, HostInfo, HostInfoBuilder } from "../../common/lib"; +import { AwsWrapperError, HostInfo, HostInfoBuilder, PluginManager } from "../../common/lib"; import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; import { ConnectionUrlParser } from "../../common/lib/utils/connection_url_parser"; import { AwsPGClient } from "../../pg/lib"; @@ -30,13 +29,17 @@ import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; import { CoreServicesContainer } from "../../common/lib/utils/core_services_container"; import { StorageService } from "../../common/lib/utils/storage/storage_service"; import { Topology } from "../../common/lib/host_list_provider/topology"; +import { RdsHostListProvider } from "../../common/lib/host_list_provider/rds_host_list_provider"; import { TopologyQueryResult, TopologyUtils } from "../../common/lib/host_list_provider/topology_utils"; +import { FullServicesContainerImpl } from "../../common/lib/utils/full_services_container"; const mockClient: AwsClient = mock(AwsPGClient); const mockDialect: AuroraPgDatabaseDialect = mock(AuroraPgDatabaseDialect); +const mockServiceContainer: FullServicesContainerImpl = mock(FullServicesContainerImpl); const mockPluginService: PluginServiceImpl = mock(PluginServiceImpl); const connectionUrlParser: ConnectionUrlParser = new PgConnectionUrlParser(); const mockTopologyUtils: TopologyUtils = mock(TopologyUtils); +const storageService: StorageService = CoreServicesContainer.getInstance().storageService; const hosts: HostInfo[] = [ createHost({ @@ -56,20 +59,15 @@ const currentHostInfo = createHost({ }); const clientWrapper: ClientWrapper = new PgClientWrapper(undefined, currentHostInfo, new Map()); - const mockClientWrapper: ClientWrapper = mock(clientWrapper); -const storageService: StorageService = CoreServicesContainer.getInstance().getStorageService(); - -const defaultRefreshRateNano: number = 5 * 1_000_000_000; - function createHost(config: any): HostInfo { const info = new HostInfoBuilder(config); return info.build(); } function getRdsHostListProvider(originalHost: string): RdsHostListProvider { - const provider = new RdsHostListProvider(new Map(), originalHost, instance(mockTopologyUtils), instance(mockPluginService)); + const provider = new RdsHostListProvider(new Map(), originalHost, instance(mockTopologyUtils), instance(mockServiceContainer)); provider.init(); return provider; } @@ -84,11 +82,13 @@ describe("testRdsHostListProvider", () => { when(mockPluginService.getCurrentClient()).thenReturn(instance(mockClient)); when(mockClient.targetClient).thenReturn(mockClientWrapper); when(mockPluginService.getHostInfoBuilder()).thenReturn(new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() })); + when(mockServiceContainer.hostListProviderService).thenReturn(instance(mockPluginService)); + when(mockServiceContainer.pluginService).thenReturn(instance(mockPluginService)); + when(mockServiceContainer.storageService).thenReturn(storageService); }); - afterEach(() => { - RdsHostListProvider.clearAll(); - CoreServicesContainer.getInstance().getStorageService().clearAll(); + afterEach(async () => { + await PluginManager.releaseResources(); reset(mockDialect); reset(mockClientWrapper); @@ -103,7 +103,7 @@ describe("testRdsHostListProvider", () => { const expected: HostInfo[] = hosts; storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); - const result = await rdsHostListProvider.getTopology(mockClientWrapper, false); + const result = await rdsHostListProvider.getTopology(); expect(result.hosts.length).toEqual(2); expect(result.hosts).toEqual(expected); @@ -117,7 +117,6 @@ describe("testRdsHostListProvider", () => { when(mockPluginService.isClientValid(anything())).thenResolve(true); - storageService.set(rdsHostListProvider.clusterId, new Topology(hosts)); const newHosts: HostInfo[] = [ createHost({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy(), @@ -126,13 +125,12 @@ describe("testRdsHostListProvider", () => { ]; when(mockClient.isValid()).thenResolve(true); - when(spiedProvider.getCurrentTopology(mockClientWrapper, anything())).thenReturn(Promise.resolve(newHosts)); + when(mockPluginService.isDialectConfirmed()).thenReturn(true); + when((spiedProvider as any).forceRefreshMonitor(anything(), anything())).thenReturn(Promise.resolve(newHosts)); - const result = await rdsHostListProvider.getTopology(mockClientWrapper, true); + const result = await rdsHostListProvider.getTopology(); expect(result.hosts.length).toEqual(1); expect(result.hosts).toEqual(newHosts); - - verify(spiedProvider.getCurrentTopology(anything(), anything())).atMost(1); }); it("testGetTopology_noForceUpdate_queryReturnsEmptyHostList", async () => { @@ -145,7 +143,7 @@ describe("testRdsHostListProvider", () => { storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); when(spiedProvider.getCurrentTopology(mockClientWrapper, anything())).thenReturn(Promise.resolve([])); - const result = await rdsHostListProvider.getTopology(mockClientWrapper, false); + const result = await rdsHostListProvider.getTopology(); expect(result.hosts.length).toEqual(2); expect(result.hosts).toEqual(expected); verify(spiedProvider.getCurrentTopology(anything(), anything())).atMost(1); @@ -165,7 +163,7 @@ describe("testRdsHostListProvider", () => { when(spiedProvider.getCurrentTopology(mockClientWrapper, anything())).thenReturn(Promise.resolve([])); - const result = await rdsHostListProvider.getTopology(mockClientWrapper, true); + const result = await rdsHostListProvider.getTopology(); expect(result.hosts).toBeTruthy(); for (let i = 0; i < result.hosts.length; i++) { expect(result.hosts[i].equals(initialHosts[i])).toBeTruthy(); diff --git a/tests/unit/read_write_splitting.test.ts b/tests/unit/read_write_splitting.test.ts index bb291e50a..612435e19 100644 --- a/tests/unit/read_write_splitting.test.ts +++ b/tests/unit/read_write_splitting.test.ts @@ -252,7 +252,7 @@ describe("reader write splitting test", () => { when(mockPluginService.getCurrentHostInfo()).thenReturn(readerHost1); when(mockPluginService.acceptsStrategy(anything(), anything())).thenReturn(true); - when(mockHostListProviderService.isStaticHostListProvider()).thenReturn(false); + when(mockHostListProviderService.isDynamicHostListProvider()).thenReturn(true); when(mockHostListProviderService.getHostListProvider()).thenReturn(mockHostListProviderInstance); const target = new TestReadWriteSplitting( @@ -405,7 +405,7 @@ describe("reader write splitting test", () => { when(mockPluginService.getCurrentHostInfo()).thenReturn(writerHostUnknownRole); when(mockPluginService.acceptsStrategy(anything(), anything())).thenReturn(true); - when(mockHostListProviderService.isStaticHostListProvider()).thenReturn(false); + when(mockHostListProviderService.isDynamicHostListProvider()).thenReturn(true); const target = new TestReadWriteSplitting( mockPluginServiceInstance, diff --git a/tests/unit/sliding_expiration_cache.test.ts b/tests/unit/sliding_expiration_cache.test.ts index 181924f94..d9f2808b2 100644 --- a/tests/unit/sliding_expiration_cache.test.ts +++ b/tests/unit/sliding_expiration_cache.test.ts @@ -15,8 +15,8 @@ */ import { SlidingExpirationCache } from "../../common/lib/utils/sliding_expiration_cache"; -import { convertMsToNanos, convertNanosToMs, sleep } from "../../common/lib/utils/utils"; import { SlidingExpirationCacheWithCleanupTask } from "../../common/lib/utils/sliding_expiration_cache_with_cleanup_task"; +import { convertMsToNanos, convertNanosToMs, sleep } from "../../common/lib/utils/utils"; class DisposableItem { shouldDispose: boolean; @@ -127,7 +127,7 @@ describe("test_sliding_expiration_cache", () => { expect(item2.disposed).toEqual(true); }); - it("test async cleanup thread", async () => { + it("test async cleanup task", async () => { const cleanupIntervalNanos = BigInt(300_000_000); // .3 seconds const disposeMs = 1000; const target = new SlidingExpirationCacheWithCleanupTask( diff --git a/tests/unit/stale_dns_helper.test.ts b/tests/unit/stale_dns_helper.test.ts index 086b75b8c..017b4c651 100644 --- a/tests/unit/stale_dns_helper.test.ts +++ b/tests/unit/stale_dns_helper.test.ts @@ -34,14 +34,6 @@ const mockHostListProviderService = mock(); const props: Map = new Map(); const writerInstance = new HostInfo("writer-host.XYZ.us-west-2.rds.amazonaws.com", 1234, HostRole.WRITER); -const writerCluster = new HostInfo("my-cluster.cluster-XYZ.us-west-2.rds.amazonaws.com", 1234, HostRole.WRITER); -const writerClusterInvalidClusterInetAddress = new HostInfo("my-cluster.cluster-invalid.us-west-2.rds.amazonaws.com", 1234, HostRole.WRITER); -const readerA = new HostInfo("reader-a-host.XYZ.us-west-2.rds.amazonaws.com", 1234, HostRole.READER, HostAvailability.AVAILABLE); -const readerB = new HostInfo("reader-b-host.XYZ.us-west-2.rds.amazonaws.com", 1234, HostRole.READER, HostAvailability.AVAILABLE); - -const clusterHostList = [writerCluster, readerA, readerB]; -const readerHostList = [readerA, readerB]; -const instanceHostList = [writerInstance, readerA, readerB]; const mockInitialConn = mock(AwsClient); const mockHostInfo = mock(HostInfo); @@ -87,208 +79,6 @@ describe("test_stale_dns_helper", () => { expect(returnConn).toBe(mockInitialClientWrapper); }); - it("test_get_verified_connection_cluster_inet_address_none", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - - when(target.lookupResult(anything())).thenReturn(); - - const returnConn = await targetInstance.getVerifiedConnection( - writerClusterInvalidClusterInetAddress.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockInitialClientWrapper).toBe(returnConn); - expect(mockConnectFunc).toHaveBeenCalled(); - }); - - it("test_get_verified_connection__no_writer_hostinfo", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - when(mockPluginService.getHosts()).thenReturn(readerHostList); - when(mockPluginService.getAllHosts()).thenReturn(readerHostList); - - when(mockPluginService.getCurrentHostInfo()).thenReturn(readerA); - - const lookupAddress = { address: "2.2.2.2", family: 0 }; - when(target.lookupResult(anything())).thenResolve(lookupAddress); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockConnectFunc).toHaveBeenCalled(); - expect(readerA.role).toBe(HostRole.READER); - verify(mockPluginService.forceRefreshHostList(anything())).once(); - expect(mockInitialClientWrapper).toBe(returnConn); - }); - - it("test_get_verified_connection__writer_rds_cluster_dns_true", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - - when(mockPluginService.getHosts()).thenReturn(clusterHostList); - when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); - - const lookupAddress = { address: "5.5.5.5", family: 0 }; - when(target.lookupResult(anything())).thenResolve(lookupAddress); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockConnectFunc).toHaveBeenCalled(); - verify(mockPluginService.refreshHostList(anything())).once(); - expect(mockInitialClientWrapper).toBe(returnConn); - }); - - it("test_get_verified_connection__writer_host_address_none", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - when(mockPluginService.getHosts()).thenReturn(instanceHostList); - when(mockPluginService.getAllHosts()).thenReturn(instanceHostList); - - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - - const firstCall = { address: "5.5.5.5", family: 0 }; - const secondCall = { address: "", family: 0 }; - - when(target.lookupResult(anything())).thenResolve(firstCall, secondCall); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockConnectFunc).toHaveBeenCalled(); - expect(mockInitialClientWrapper).toBe(returnConn); - }); - - it("test_get_verified_connection__writer_host_info_none", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - when(mockPluginService.getHosts()).thenReturn(readerHostList); - when(mockPluginService.getAllHosts()).thenReturn(readerHostList); - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - - const firstCall = { address: "5.5.5.5", family: 0 }; - const secondCall = { address: "", family: 0 }; - - when(target.lookupResult(anything())).thenResolve(firstCall, secondCall); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockConnectFunc).toHaveBeenCalled(); - expect(mockInitialClientWrapper).toBe(returnConn); - verify(mockPluginService.connect(anything(), anything())).never(); - }); - - it("test_get_verified_connection__writer_host_address_equals_cluster_inet_address", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - when(mockPluginService.getHosts()).thenReturn(instanceHostList); - when(mockPluginService.getAllHosts()).thenReturn(instanceHostList); - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - - const firstCall = { address: "5.5.5.5", family: 0 }; - const secondCall = { address: "5.5.5.5", family: 0 }; - - when(target.lookupResult(anything())).thenResolve(firstCall, secondCall); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockConnectFunc).toHaveBeenCalled(); - expect(mockInitialClientWrapper).toBe(returnConn); - verify(mockPluginService.connect(anything(), anything())).never(); - }); - - it("test_get_verified_connection__writer_host_address_not_equals_cluster_inet_address", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - - when(mockPluginService.getHosts()).thenReturn(clusterHostList); - when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - targetInstance["writerHostInfo"] = writerCluster; - - const firstCall = { address: "5.5.5.5", family: 0 }; - const secondCall = { address: "8.8.8.8", family: 0 }; - - when(target.lookupResult(anything())).thenResolve(firstCall, secondCall); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - false, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - expect(mockInitialConn.targetClient).not.toBe(returnConn); - expect(mockConnectFunc).toHaveBeenCalled(); - verify(mockPluginService.connect(anything(), anything())).once(); - }); - - it("test_get_verified_connection__initial_connection_writer_host_address_not_equals_cluster_inet_address", async () => { - const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); - const targetInstance = instance(target); - - when(mockPluginService.getHosts()).thenReturn(clusterHostList); - when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); - const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - targetInstance["writerHostInfo"] = writerCluster; - when(mockHostListProviderService.getInitialConnectionHostInfo()).thenReturn(writerCluster); - - const firstCall = { address: "5.5.5.5", family: 0 }; - const secondCall = { address: "8.8.8.8", family: 0 }; - - when(target.lookupResult(anything())).thenResolve(firstCall, secondCall); - - const returnConn = await targetInstance.getVerifiedConnection( - writerCluster.host, - true, - mockHostListProviderServiceInstance, - props, - mockConnectFunc - ); - - verify(mockPluginService.connect(anything(), anything())).once(); - expect(targetInstance["writerHostInfo"]).toBe(mockHostListProviderServiceInstance.getInitialConnectionHostInfo()); - expect(mockInitialConn.targetClient).not.toBe(returnConn); - }); - it("test_notify_host_list_changed", () => { const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); const targetInstance = instance(target); @@ -301,6 +91,5 @@ describe("test_stale_dns_helper", () => { targetInstance.notifyHostListChanged(changes); expect(targetInstance["writerHostInfo"]).toBeNull(); - expect(targetInstance["writerHostAddress"]).toBe(""); }); }); diff --git a/tests/unit/storage_service.test.ts b/tests/unit/storage_service.test.ts index 25d32e487..a52d739c9 100644 --- a/tests/unit/storage_service.test.ts +++ b/tests/unit/storage_service.test.ts @@ -1,18 +1,18 @@ /* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ import { StorageService } from "../../common/lib/utils/storage/storage_service"; import { Topology } from "../../common/lib/host_list_provider/topology"; @@ -24,7 +24,7 @@ describe("test_storage_service", () => { let storageService: StorageService; beforeEach(() => { - storageService = CoreServicesContainer.getInstance().getStorageService(); + storageService = CoreServicesContainer.getInstance().storageService; }); afterEach(() => { diff --git a/tests/unit/topology_utils.test.ts b/tests/unit/topology_utils.test.ts index a757c8a6c..b3e927c0a 100644 --- a/tests/unit/topology_utils.test.ts +++ b/tests/unit/topology_utils.test.ts @@ -14,9 +14,10 @@ limitations under the License. */ -import { TopologyQueryResult, TopologyUtils } from "../../common/lib/host_list_provider/topology_utils"; +import { TopologyQueryResult } from "../../common/lib/host_list_provider/topology_utils"; +import { AuroraTopologyUtils } from "../../common/lib/host_list_provider/aurora_topology_utils"; import { anything, instance, mock, reset, when } from "ts-mockito"; -import { HostInfo, HostInfoBuilder } from "../../common/lib"; +import { HostInfo, HostInfoBuilder, PluginManager } from "../../common/lib"; import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; import { AuroraPgDatabaseDialect } from "../../pg/lib/dialect/aurora_pg_database_dialect"; import { ClientWrapper } from "../../common/lib/client_wrapper"; @@ -43,8 +44,8 @@ function createHost(config: any): HostInfo { return info.build(); } -function getTopologyUtils(): TopologyUtils { - return new TopologyUtils(instance(mockDialect), hostInfoBuilder); +function getTopologyUtils(): AuroraTopologyUtils { + return new AuroraTopologyUtils(instance(mockDialect), hostInfoBuilder); } describe("testTopologyUtils", () => { @@ -54,9 +55,13 @@ describe("testTopologyUtils", () => { reset(mockNonTopologyDialect); }); + afterEach(async () => { + await PluginManager.releaseResources(); + }); + it("testQueryForTopology_withNonTopologyAwareDialect_throwsError", async () => { const hostInfoBuilder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); - const topologyUtils = new TopologyUtils(instance(mockNonTopologyDialect) as any, hostInfoBuilder); + const topologyUtils = new AuroraTopologyUtils(instance(mockNonTopologyDialect) as any, hostInfoBuilder); const initialHost = createHost({ host: "initial-host", diff --git a/tests/unit/writer_failover_handler.test.ts b/tests/unit/writer_failover_handler.test.ts index 2d4dcf468..bb07d9d92 100644 --- a/tests/unit/writer_failover_handler.test.ts +++ b/tests/unit/writer_failover_handler.test.ts @@ -211,7 +211,7 @@ describe("writer failover handler", () => { expect(result.topology.length).toBe(4); expect(result.topology[0].host).toBe("new-writer-host"); - verify(mockPluginService.forceRefreshHostList(anything())).atLeast(1); + verify(mockPluginService.forceRefreshHostList()).atLeast(1); verify(mockPluginService.setAvailability(newWriterHost.allAliases, HostAvailability.AVAILABLE)).once(); clearTimeout(timeoutId); }, 10000);