diff --git a/drivers/mysql/internal/backfill.go b/drivers/mysql/internal/backfill.go index 366a94b74..19573230f 100644 --- a/drivers/mysql/internal/backfill.go +++ b/drivers/mysql/internal/backfill.go @@ -79,9 +79,14 @@ func (m *MySQL) backfill(pool *protocol.WriterPool, stream protocol.Stream) erro // Begin transaction with repeatable read isolation return jdbc.WithIsolation(backfillCtx, m.client, func(tx *sql.Tx) error { // Build query for the chunk - stmt := jdbc.MysqlChunkScanQuery(stream, pkColumn, chunk) + stmt := jdbc.MySQLChunkScanQuery(stream, pkColumn, chunk) setter := jdbc.NewReader(backfillCtx, stmt, 0, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - return tx.QueryContext(ctx, query, args...) + if chunk.Min != nil && chunk.Max != nil { + return tx.QueryContext(ctx, query, chunk.Min, chunk.Max) + } else if chunk.Min != nil { + return tx.QueryContext(ctx, query, chunk.Min) + } + return tx.QueryContext(ctx, query, chunk.Max) }) // Capture and process rows return setter.Capture(func(rows *sql.Rows) error { @@ -134,7 +139,7 @@ func (m *MySQL) splitChunks(stream protocol.Stream, chunks *types.Set[types.Chun } // Generate chunks based on range - query := jdbc.NextChunkEndQuery(stream, pkColumn, chunkSize) + query := jdbc.NextChunkEndQuery(stream, pkColumn) currentVal := minVal for { diff --git a/drivers/postgres/internal/backfill.go b/drivers/postgres/internal/backfill.go index a407e1cff..3dd56854b 100644 --- a/drivers/postgres/internal/backfill.go +++ b/drivers/postgres/internal/backfill.go @@ -19,8 +19,8 @@ import ( func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) error { backfillCtx := context.TODO() var approxRowCount int64 - approxRowCountQuery := jdbc.PostgresRowCountQuery(stream) - err := p.client.QueryRow(approxRowCountQuery).Scan(&approxRowCount) + approxRowCountQuery := jdbc.PostgresRowCountQuery() + err := p.client.QueryRow(approxRowCountQuery, stream.Name(), 
stream.Namespace()).Scan(&approxRowCount) if err != nil { return fmt.Errorf("failed to get approx row count: %s", err) } @@ -55,7 +55,12 @@ func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) e stmt := jdbc.PostgresChunkScanQuery(stream, splitColumn, chunk) setter := jdbc.NewReader(backfillCtx, stmt, p.config.BatchSize, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - return tx.Query(query, args...) + if chunk.Min != nil && chunk.Max != nil { + return tx.Query(query, chunk.Min, chunk.Max) + } else if chunk.Min != nil { + return tx.Query(query, chunk.Min) + } + return tx.Query(query, chunk.Max) }) batchStartTime := time.Now() waitChannel := make(chan error, 1) @@ -102,8 +107,8 @@ func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) e func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, error) { generateCTIDRanges := func(stream protocol.Stream) ([]types.Chunk, error) { var relPages uint32 - relPagesQuery := jdbc.PostgresRelPageCount(stream) - err := p.client.QueryRow(relPagesQuery).Scan(&relPages) + relPagesQuery := jdbc.PostgresRelPageCount() + err := p.client.QueryRow(relPagesQuery, stream.Name(), stream.Namespace()).Scan(&relPages) if err != nil { return nil, fmt.Errorf("failed to get relPages: %s", err) } @@ -193,8 +198,8 @@ func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, func (p *Postgres) nextChunkEnd(stream protocol.Stream, previousChunkEnd interface{}, splitColumn string) (interface{}, error) { var chunkEnd interface{} - nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn, previousChunkEnd, p.config.BatchSize) - err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd) + nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn) + err := p.client.QueryRow(nextChunkEnd, previousChunkEnd, p.config.BatchSize).Scan(&chunkEnd) if err != nil { return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", 
nextChunkEnd, err) } diff --git a/pkg/jdbc/jdbc.go b/pkg/jdbc/jdbc.go index 959d4460e..37aa5ab91 100644 --- a/pkg/jdbc/jdbc.go +++ b/pkg/jdbc/jdbc.go @@ -15,18 +15,29 @@ func MinMaxQuery(stream protocol.Stream, column string) string { } // NextChunkEndQuery returns the query to calculate the next chunk boundary -func NextChunkEndQuery(stream protocol.Stream, column string, chunkSize int) string { - return fmt.Sprintf(`SELECT MAX(%[1]s) FROM (SELECT %[1]s FROM %[2]s.%[3]s WHERE %[1]s > ? ORDER BY %[1]s LIMIT %[4]d) AS subquery`, column, stream.Namespace(), stream.Name(), chunkSize) +// Args, in placeholder order: first ? is the filter value (exclusive lower bound), second ? is the batch size (LIMIT) +func NextChunkEndQuery(stream protocol.Stream, column string) string { + return fmt.Sprintf(`SELECT MAX(%[1]s) FROM (SELECT %[1]s FROM %[2]s.%[3]s WHERE %[1]s > ? ORDER BY %[1]s LIMIT ?) AS subquery`, column, stream.Namespace(), stream.Name()) } -// buildChunkCondition builds the condition for a chunk -func buildChunkCondition(filterColumn string, chunk types.Chunk) string { +// PostgresBuildChunkCondition builds the WHERE condition for a chunk using $1/$2 placeholders; bind chunk.Min then chunk.Max when both are set, otherwise the single bound as $1 +func PostgresBuildChunkCondition(filterColumn string, chunk types.Chunk) string { if chunk.Min != nil && chunk.Max != nil { - return fmt.Sprintf("%s >= %v AND %s <= %v", filterColumn, chunk.Min, filterColumn, chunk.Max) + return fmt.Sprintf("%s >= $1 AND %s <= $2", filterColumn, filterColumn) } else if chunk.Min != nil { - return fmt.Sprintf("%s >= %v", filterColumn, chunk.Min) + return fmt.Sprintf("%s >= $1", filterColumn) } - return fmt.Sprintf("%s <= %v", filterColumn, chunk.Max) + return fmt.Sprintf("%s <= $1", filterColumn) +} + +// MySQLBuildChunkCondition builds the WHERE condition for a chunk using ? placeholders; bind chunk.Min then chunk.Max when both are set, otherwise the single bound +func MySQLBuildChunkCondition(filterColumn string, chunk types.Chunk) string { + if chunk.Min != nil && chunk.Max != nil { + return fmt.Sprintf("%s >= ? 
AND %s <= ?", filterColumn, filterColumn) + } else if chunk.Min != nil { + return fmt.Sprintf("%s >= ?", filterColumn) + } + return fmt.Sprintf("%s <= ?", filterColumn) } // PostgreSQL-Specific Queries @@ -43,13 +54,19 @@ func PostgresWithState(stream protocol.Stream) string { } // PostgresRowCountQuery returns the query to fetch the estimated row count in PostgreSQL -func PostgresRowCountQuery(stream protocol.Stream) string { - return fmt.Sprintf(`SELECT reltuples::bigint AS approx_row_count FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = '%s' AND n.nspname = '%s';`, stream.Name(), stream.Namespace()) +// args to be passed: +// $1: stream name, +// $2: stream namespace +func PostgresRowCountQuery() string { + return `SELECT reltuples::bigint AS approx_row_count FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = $1 AND n.nspname = $2;` } // PostgresRelPageCount returns the query to fetch relation page count in PostgreSQL -func PostgresRelPageCount(stream protocol.Stream) string { - return fmt.Sprintf(`SELECT relpages FROM pg_class WHERE relname = '%s' AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '%s')`, stream.Name(), stream.Namespace()) +// args to be passed: +// $1: stream name, +// $2: stream namespace +func PostgresRelPageCount() string { + return `SELECT relpages FROM pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $2);` } // PostgresWalLSNQuery returns the query to fetch the current WAL LSN in PostgreSQL @@ -58,30 +75,41 @@ func PostgresWalLSNQuery() string { } // PostgresNextChunkEndQuery generates a SQL query to fetch the maximum value of a specified column -func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string, filterValue interface{}, batchSize int) string { - return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM "%s"."%s" WHERE %s > %v ORDER BY %s ASC LIMIT %d) AS T`, filterColumn, filterColumn, 
stream.Namespace(), stream.Name(), filterColumn, filterValue, filterColumn, batchSize) +// args to be passed: +// $1: filter value, +// $2: batch size +func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string) string { + return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM "%s"."%s" WHERE %s > $1 ORDER BY %s ASC LIMIT $2) AS T`, filterColumn, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterColumn) } // PostgresMinQuery returns the query to fetch the minimum value of a column in PostgreSQL -func PostgresMinQuery(stream protocol.Stream, filterColumn string, filterValue interface{}) string { - return fmt.Sprintf(`SELECT MIN(%s) FROM "%s"."%s" WHERE %s > %v`, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterValue) +// args to be passed: +// $1: filter value (exclusive lower bound) +func PostgresMinQuery(stream protocol.Stream, filterColumn string) string { + return fmt.Sprintf(`SELECT MIN(%s) FROM "%s"."%s" WHERE %s > $1`, filterColumn, stream.Namespace(), stream.Name(), filterColumn) } -// PostgresBuildSplitScanQuery builds a chunk scan query for PostgreSQL +// PostgresChunkScanQuery builds a chunk scan query for PostgreSQL +// args to be passed: +// $1, $2: chunk.Min then chunk.Max when both are set; otherwise $1 is the single bound in use func PostgresChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string { - condition := buildChunkCondition(filterColumn, chunk) + condition := PostgresBuildChunkCondition(filterColumn, chunk) return fmt.Sprintf(`SELECT * FROM "%s"."%s" WHERE %s`, stream.Namespace(), stream.Name(), condition) } // MySQL-Specific Queries -// MySQLWithoutState builds a chunk scan query for MySql -func MysqlChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string { - condition := buildChunkCondition(filterColumn, chunk) +// MySQLChunkScanQuery builds a chunk scan query for MySQL +// args to be passed: +// ?: chunk.Min then chunk.Max when both are set; otherwise the single bound in use +func MySQLChunkScanQuery(stream protocol.Stream, filterColumn string, 
chunk types.Chunk) string { + condition := MySQLBuildChunkCondition(filterColumn, chunk) return fmt.Sprintf("SELECT * FROM `%s`.`%s` WHERE %s", stream.Namespace(), stream.Name(), condition) } // MySQLDiscoverTablesQuery returns the query to discover tables in a MySQL database +// Args: +// ?: schema name (string) func MySQLDiscoverTablesQuery() string { return ` SELECT @@ -96,6 +124,9 @@ func MySQLDiscoverTablesQuery() string { } // MySQLTableSchemaQuery returns the query to fetch schema information for a table in MySQL +// Args: +// ?: schema name (string) +// ?: table name (string) func MySQLTableSchemaQuery() string { return ` SELECT @@ -114,6 +145,8 @@ func MySQLTableSchemaQuery() string { } // MySQLPrimaryKeyQuery returns the query to fetch the primary key column of a table in MySQL +// Args: +// ?: table name (string) func MySQLPrimaryKeyQuery() string { return ` SELECT COLUMN_NAME @@ -126,6 +159,8 @@ func MySQLPrimaryKeyQuery() string { } // MySQLTableRowsQuery returns the query to fetch the estimated row count of a table in MySQL +// Args: +// ?: table name (string) func MySQLTableRowsQuery() string { return ` SELECT TABLE_ROWS @@ -141,6 +176,9 @@ func MySQLMasterStatusQuery() string { } // MySQLTableColumnsQuery returns the query to fetch column names of a table in MySQL +// Args: +// ?: schema name (string) +// ?: table name (string) func MySQLTableColumnsQuery() string { return ` SELECT COLUMN_NAME @@ -149,6 +187,7 @@ func MySQLTableColumnsQuery() string { ORDER BY ORDINAL_POSITION ` } + func WithIsolation(ctx context.Context, client *sql.DB, fn func(tx *sql.Tx) error) error { tx, err := client.BeginTx(ctx, &sql.TxOptions{ Isolation: sql.LevelRepeatableRead,