diff --git a/.github/codecov.yml b/.github/codecov.yml index 6f56dbbb6d..cf315e6cf9 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -16,13 +16,8 @@ component_management: - component_id: spanner-templates name: spanner-templates paths: - - "v1/src/main/java/com/google/cloud/teleport/spanner/**" - - "v1/src/main/java/com/google/cloud/teleport/templates/SpannerToText.java" - - "v1/src/main/java/com/google/cloud/teleport/templates/common/SpannerConverters.java" - "v2/datastream-to-spanner/**" - "v2/spanner-common/**" - - "v2/spanner-change-streams-to-sharded-file-sink/**" - - "v2/gcs-to-sourcedb/**" - "v2/spanner-migrations-sdk/**" - "v2/spanner-custom-shard/**" - "v2/sourcedb-to-spanner/**" @@ -30,10 +25,9 @@ component_management: - "v2/gcs-spanner-dv/**" statuses: - type: project - informational: true + target: 80% - type: patch target: 80% - informational: true - component_id: spanner-import-export name: spanner-import-export paths: @@ -64,3 +58,14 @@ component_management: paths: - "v2/gcs-spanner-dv/**" - "v2/spanner-common/**" + +# Flags are used to identify reports from different workflows. +# In a monorepo with path-based triggers (like this one), not all workflows run on every commit. +# Enabling 'carryforward' allows Codecov to use the coverage report from a previous commit +# for workflows that were skipped in the current commit, preventing status checks from hanging +# and providing a complete picture of the codebase coverage. +flags: + spanner: + carryforward: true + java: + carryforward: true diff --git a/.github/workflows/bigtable-pr.yml b/.github/workflows/bigtable-pr.yml index 599a54654f..3cd4b695aa 100644 --- a/.github/workflows/bigtable-pr.yml +++ b/.github/workflows/bigtable-pr.yml @@ -122,6 +122,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: java - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env java_integration_smoke_tests_templates: diff --git a/.github/workflows/datastream-pr.yml b/.github/workflows/datastream-pr.yml index 3cea3f84a7..0a33156f15 100644 --- a/.github/workflows/datastream-pr.yml +++ b/.github/workflows/datastream-pr.yml @@ -125,6 +125,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: java - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env java_integration_smoke_tests_templates: diff --git a/.github/workflows/java-pr.yml b/.github/workflows/java-pr.yml index 963e0876d8..351a690139 100644 --- a/.github/workflows/java-pr.yml +++ b/.github/workflows/java-pr.yml @@ -129,6 +129,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: java - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env if: always() diff --git a/.github/workflows/kafka-pr.yml b/.github/workflows/kafka-pr.yml index b2562482a1..d6868d124d 100644 --- a/.github/workflows/kafka-pr.yml +++ b/.github/workflows/kafka-pr.yml @@ -125,6 +125,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: java - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env java_integration_smoke_tests_templates: diff --git a/.github/workflows/spanner-pr.yml b/.github/workflows/spanner-pr.yml index 190e205518..670dd62732 100644 --- a/.github/workflows/spanner-pr.yml +++ b/.github/workflows/spanner-pr.yml @@ -132,12 +132,16 @@ jobs: **/surefire-reports/*.html **/surefire-reports/html/** retention-days: 1 + # The 'spanner' flag identifies reports from this workflow. + # Combined with 'carryforward: true' in codecov.yml, it allows Codecov to reuse + # reports from previous commits when this workflow is skipped by path-based triggers. - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v6.0.1 with: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: spanner - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env java_integration_smoke_tests_templates: diff --git a/.github/workflows/yaml-pr.yml b/.github/workflows/yaml-pr.yml index 159ad25fc1..750406e47c 100644 --- a/.github/workflows/yaml-pr.yml +++ b/.github/workflows/yaml-pr.yml @@ -151,6 +151,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: GoogleCloudPlatform/DataflowTemplates files: 'target/site/jacoco-aggregate/jacoco.xml' + flags: java - name: Cleanup Java Environment uses: ./.github/actions/cleanup-java-env if: always() diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/ColumnTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/ColumnTest.java new file mode 100644 index 0000000000..2ad0ddf62c --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/ColumnTest.java @@ -0,0 +1,270 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.ddl; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.teleport.v2.spanner.type.Type; +import org.junit.Test; + +public class ColumnTest { + + @Test + public void testTypeString_GoogleStandardSQL() { + assertEquals("BOOL", Column.builder().name("col").type(Type.bool()).autoBuild().typeString()); + assertEquals("INT64", Column.builder().name("col").type(Type.int64()).autoBuild().typeString()); + assertEquals( + "FLOAT64", Column.builder().name("col").type(Type.float64()).autoBuild().typeString()); + assertEquals( + "FLOAT32", Column.builder().name("col").type(Type.float32()).autoBuild().typeString()); + assertEquals( + "STRING(MAX)", + Column.builder().name("col").type(Type.string()).size(-1).autoBuild().typeString()); + assertEquals( + "STRING(10)", + Column.builder().name("col").type(Type.string()).size(10).autoBuild().typeString()); + assertEquals( + "BYTES(MAX)", + Column.builder().name("col").type(Type.bytes()).size(-1).autoBuild().typeString()); + assertEquals( + "BYTES(20)", + Column.builder().name("col").type(Type.bytes()).size(20).autoBuild().typeString()); + assertEquals("DATE", Column.builder().name("col").type(Type.date()).autoBuild().typeString()); + assertEquals( + "TIMESTAMP", Column.builder().name("col").type(Type.timestamp()).autoBuild().typeString()); + assertEquals( + "NUMERIC", Column.builder().name("col").type(Type.numeric()).autoBuild().typeString()); + assertEquals("JSON", Column.builder().name("col").type(Type.json()).autoBuild().typeString()); + assertEquals( + "TOKENLIST", Column.builder().name("col").type(Type.tokenlist()).autoBuild().typeString()); + assertEquals( + "ARRAY", + Column.builder().name("col").type(Type.array(Type.int64())).autoBuild().typeString()); + } + + @Test + public void testTypeString_PostgreSQL() { + assertEquals( + "boolean", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgBool()) + .autoBuild() + .typeString()); + assertEquals( + "bigint", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgInt8()) + .autoBuild() + .typeString()); + assertEquals( + "double precision", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgFloat8()) + .autoBuild() + .typeString()); + assertEquals( + "real", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgFloat4()) + .autoBuild() + .typeString()); + assertEquals( + "character varying", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgVarchar()) + .size(-1) + .autoBuild() + .typeString()); + assertEquals( + "character varying(10)", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgVarchar()) + .size(10) + .autoBuild() + .typeString()); + assertEquals( + "text", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgText()) + .autoBuild() + .typeString()); + assertEquals( + "bytea", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgBytea()) + .autoBuild() + .typeString()); + assertEquals( + "date", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgDate()) + .autoBuild() + .typeString()); + assertEquals( + "timestamp with time zone", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgTimestamptz()) + .autoBuild() + .typeString()); + assertEquals( + "numeric", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgNumeric()) + .autoBuild() + .typeString()); + assertEquals( + "jsonb", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgJsonb()) + .autoBuild() + .typeString()); + assertEquals( + "spanner.commit_timestamp", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgCommitTimestamp()) + .autoBuild() + .typeString()); + assertEquals( + "bigint[]", + Column.builder(Dialect.POSTGRESQL) + .name("col") + .type(Type.pgArray(Type.pgInt8())) + .autoBuild() + .typeString()); + } + + @Test + public void testParseSpannerType_GoogleStandardSQL() { + assertEquals(Type.bool(), Column.builder().name("col").parseType("BOOL").autoBuild().type()); + assertEquals(Type.int64(), Column.builder().name("col").parseType("INT64").autoBuild().type()); + assertEquals( + Type.float64(), Column.builder().name("col").parseType("FLOAT64").autoBuild().type()); + assertEquals( + Type.float32(), Column.builder().name("col").parseType("FLOAT32").autoBuild().type()); + assertEquals( + Type.string(), Column.builder().name("col").parseType("STRING(MAX)").autoBuild().type()); + assertEquals( + Type.bytes(), Column.builder().name("col").parseType("BYTES(MAX)").autoBuild().type()); + assertEquals(Type.date(), Column.builder().name("col").parseType("DATE").autoBuild().type()); + assertEquals( + Type.timestamp(), Column.builder().name("col").parseType("TIMESTAMP").autoBuild().type()); + assertEquals( + Type.numeric(), Column.builder().name("col").parseType("NUMERIC").autoBuild().type()); + assertEquals(Type.json(), Column.builder().name("col").parseType("JSON").autoBuild().type()); + assertEquals( + Type.tokenlist(), Column.builder().name("col").parseType("TOKENLIST").autoBuild().type()); + assertEquals( + Type.array(Type.int64()), + Column.builder().name("col").parseType("ARRAY").autoBuild().type()); + } + + @Test + public void testParseSpannerType_PostgreSQL() { + assertEquals( + Type.pgBool(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("boolean").autoBuild().type()); + assertEquals( + Type.pgInt8(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("bigint").autoBuild().type()); + assertEquals( + Type.pgFloat8(), + Column.builder(Dialect.POSTGRESQL) + .name("col") + .parseType("double precision") + .autoBuild() + .type()); + assertEquals( + Type.pgFloat4(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("real").autoBuild().type()); + assertEquals( + Type.pgText(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("text").autoBuild().type()); + assertEquals( + Type.pgVarchar(), + Column.builder(Dialect.POSTGRESQL) + .name("col") + .parseType("character varying") + .autoBuild() + .type()); + assertEquals( + Type.pgBytea(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("bytea").autoBuild().type()); + assertEquals( + Type.pgTimestamptz(), + Column.builder(Dialect.POSTGRESQL) + .name("col") + .parseType("timestamp with time zone") + .autoBuild() + .type()); + assertEquals( + Type.pgNumeric(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("numeric").autoBuild().type()); + assertEquals( + Type.pgJsonb(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("jsonb").autoBuild().type()); + assertEquals( + Type.pgDate(), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("date").autoBuild().type()); + assertEquals( + Type.pgCommitTimestamp(), + Column.builder(Dialect.POSTGRESQL) + .name("col") + .parseType("spanner.commit_timestamp") + .autoBuild() + .type()); + assertEquals( + Type.pgArray(Type.pgInt8()), + Column.builder(Dialect.POSTGRESQL).name("col").parseType("bigint[]").autoBuild().type()); + } + + @Test + public void testPrettyPrint() { + Column c = Column.builder().name("col1").type(Type.bool()).notNull(true).autoBuild(); + assertEquals("`col1` BOOL NOT NULL", c.prettyPrint()); + + Column c2 = Column.builder().name("col2").type(Type.string()).size(10).autoBuild(); + assertEquals("`col2` STRING(10)", c2.prettyPrint()); + + Column c3 = + Column.builder().name("col3").type(Type.int64()).generatedAs("1+1").stored().autoBuild(); + assertEquals("`col3` INT64 AS (1+1) STORED", c3.prettyPrint()); + + Column c4 = + Column.builder(Dialect.POSTGRESQL) + .name("col4") + .type(Type.pgInt8()) + .generatedAs("1+1") + .stored() + .autoBuild(); + assertEquals( + "\"col4\" bigint GENERATED ALWAYS AS (1+1) STORED", + c4.prettyPrint()); + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/DdlTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/DdlTest.java index 32298f6985..94306ab28d 100644 --- a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/DdlTest.java +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/DdlTest.java @@ -927,11 +927,72 @@ public void testDdlEquals() { Ddl ddl1 = Ddl.builder(Dialect.GOOGLE_STANDARD_SQL).build(); Ddl ddl2 = Ddl.builder(Dialect.POSTGRESQL).build(); assertFalse(ddl1.equals(ddl2)); + Ddl.Builder ddl1Builder = Ddl.builder().createTable("Users").column("id").int64().endColumn().endTable(); - ddl1Builder.createTable("Users"); - ddl1 = ddl1Builder.build(); - assertFalse(ddl1.equals(ddl2)); + Ddl ddl3 = ddl1Builder.build(); + + // Test same object + assertTrue(ddl3.equals(ddl3)); + + // Test null + assertFalse(ddl3.equals(null)); + + // Test different class + assertFalse(ddl3.equals("string")); + + // Test equal objects + Ddl ddl4 = + Ddl.builder().createTable("Users").column("id").int64().endColumn().endTable().build(); + assertTrue(ddl3.equals(ddl4)); + + // Test different tables + Ddl ddl5 = + Ddl.builder() + .createTable("DifferentTable") + .column("id") + .int64() + .endColumn() + .endTable() + .build(); + assertFalse(ddl3.equals(ddl5)); + + // Test different parents (interleaving) + Ddl ddl6 = + Ddl.builder() + .createTable("Users") + .column("id") + .int64() + .endColumn() + .primaryKey() + .asc("id") + .end() + .endTable() + .createTable("Account") + .column("id") + .int64() + .endColumn() + .interleavingParent("Users") + .interleaveType("IN PARENT") + .endTable() + .build(); + Ddl ddl7 = + Ddl.builder() + .createTable("Users") + .column("id") + .int64() + .endColumn() + .primaryKey() + .asc("id") + .end() + .endTable() + .createTable("Account") + .column("id") + .int64() + .endColumn() + .endTable() // No interleave! + .build(); + assertFalse(ddl6.equals(ddl7)); } @Rule public ExpectedException thrown = ExpectedException.none(); diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/InformationSchemaScannerTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/InformationSchemaScannerTest.java index fa92ce41da..5b6aea80fe 100644 --- a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/InformationSchemaScannerTest.java +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/ddl/InformationSchemaScannerTest.java @@ -16,6 +16,8 @@ package com.google.cloud.teleport.v2.spanner.ddl; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -37,12 +39,12 @@ void mockGSQLColumnOptions(ReadContext context) { + " ORDER BY t.table_name, t.column_name"); ResultSet listColumnOptionsResultSet = mock(ResultSet.class); when(context.executeQuery(listColumnOptions)).thenReturn(listColumnOptionsResultSet); - when(listColumnOptionsResultSet.next()).thenReturn(true, false); - when(listColumnOptionsResultSet.getString(0)).thenReturn("singer"); - when(listColumnOptionsResultSet.getString(1)).thenReturn("singerName"); - when(listColumnOptionsResultSet.getString(2)).thenReturn("option1"); - when(listColumnOptionsResultSet.getString(3)).thenReturn("STRING"); - when(listColumnOptionsResultSet.getString(4)).thenReturn("SomeName"); + when(listColumnOptionsResultSet.next()).thenReturn(true, true, false); + when(listColumnOptionsResultSet.getString(0)).thenReturn("singer", "singer"); + when(listColumnOptionsResultSet.getString(1)).thenReturn("singerName", "singerId"); + when(listColumnOptionsResultSet.getString(2)).thenReturn("option1", "option2"); + when(listColumnOptionsResultSet.getString(3)).thenReturn("STRING", "INT64"); + when(listColumnOptionsResultSet.getString(4)).thenReturn("SomeName", "123"); } void mockGSQLIndex(ReadContext context) { @@ -234,11 +236,12 @@ void mockPgSQLIndexColumns(ReadContext context) { + "ORDER BY t.table_name, t.index_name, t.ordinal_position"); ResultSet listIndexColumnsResultSet = mock(ResultSet.class); when(context.executeQuery(listIndexColumns)).thenReturn(listIndexColumnsResultSet); - when(listIndexColumnsResultSet.next()).thenReturn(true, false); - when(listIndexColumnsResultSet.getString(0)).thenReturn("singer"); - when(listIndexColumnsResultSet.getString(1)).thenReturn("singerName"); - when(listIndexColumnsResultSet.isNull(2)).thenReturn(true); - when(listIndexColumnsResultSet.getString(3)).thenReturn("index1"); + when(listIndexColumnsResultSet.next()).thenReturn(true, true, true, false); + when(listIndexColumnsResultSet.getString(0)).thenReturn("singer", "singer", "singer"); + when(listIndexColumnsResultSet.getString(1)).thenReturn("singerName", "singerId", "age"); + when(listIndexColumnsResultSet.isNull(2)).thenReturn(true, false, false); + when(listIndexColumnsResultSet.getString(2)).thenReturn("ASC", "DESC"); + when(listIndexColumnsResultSet.getString(3)).thenReturn("index1", "index1", "index1"); } void mockPgSQLForeignKey(ReadContext context) { @@ -332,18 +335,27 @@ void mockPgSQLListColumns(ReadContext context) { ResultSet listColumnsResultSet = mock(ResultSet.class); when(context.executeQuery(listColumns)).thenReturn(listColumnsResultSet); - when(listColumnsResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + when(listColumnsResultSet.next()) + .thenReturn(true, true, true, true, true, true, true, true, false); when(listColumnsResultSet.getString(0)) - .thenReturn("singer", "singer", "album", "album", "album", "album", "album"); + .thenReturn("singer", "singer", "singer", "album", "album", "album", "album", "album"); when(listColumnsResultSet.getString(1)) .thenReturn( - "singerId", "singerName", "singerId", "albumId", "albumName", "rating", "ratings"); + "singerId", + "singerName", + "age", + "singerId", + "albumId", + "albumName", + "rating", + "ratings"); when(listColumnsResultSet.getString(3)) .thenReturn( "bigint", "character varying(50)", "bigint", "bigint", + "bigint", "character varying(50)", "real", "real[]"); @@ -369,7 +381,7 @@ public void testScanGSQLDdl() { Ddl ddl = informationSchemaScanner.scan(); String expectedDdl = "CREATE TABLE `singer` (\n" - + "\t`singerId` INT64 NOT NULL,\n" + + "\t`singerId` INT64 NOT NULL OPTIONS (option2=123),\n" + "\t`singerName` STRING(50) NOT NULL OPTIONS (option1=\"SomeName\"),\n" + ") PRIMARY KEY (`singerId` ASC)\n" + "CREATE INDEX `PRIMARY_KEY` ON `singer`()\n" @@ -413,9 +425,10 @@ public void testScanPgSQLDdl() { "CREATE TABLE \"singer\" (\n" + "\t\"singerId\" bigint NOT NULL,\n" + "\t\"singerName\" character varying(50) NOT NULL OPTIONS (option1='SomeName'),\n" + + "\t\"age\" bigint NOT NULL,\n" + "\tPRIMARY KEY ()\n" + ")\n" - + "CREATE UNIQUE INDEX \"index1\" ON \"singer\"() INCLUDE (\"singerName\")\n" + + "CREATE UNIQUE INDEX \"index1\" ON \"singer\"(\"singerId\" ASC, \"age\" DESC) INCLUDE (\"singerName\")\n" + "\n" + "\n" + "CREATE TABLE \"album\" (\n" @@ -442,4 +455,193 @@ public void testWithInvalidDialect() { new InformationSchemaScanner(context, Dialect.fromName("xyz")); Ddl ddl = informationSchemaScanner.scan(); } + + @Test(expected = IllegalStateException.class) + public void testListTables_InvalidInterleave() { + ReadContext context = mock(ReadContext.class); + Statement listTables = + Statement.of( + "SELECT t.table_name, t.parent_table_name, t.on_delete_action, t.interleave_type" + + " FROM information_schema.tables AS t" + + " WHERE t.table_catalog = '' AND t.table_schema = ''" + + " AND t.table_type='BASE TABLE'"); + ResultSet listTablesResultSet = mock(ResultSet.class); + when(context.executeQuery(listTables)).thenReturn(listTablesResultSet); + when(listTablesResultSet.next()).thenReturn(true, false); + when(listTablesResultSet.getString(0)).thenReturn("album"); + when(listTablesResultSet.getString(1)).thenReturn("singer"); + when(listTablesResultSet.isNull(2)).thenReturn(true); // onDeleteAction is null + when(listTablesResultSet.getString(3)).thenReturn("IN PARENT"); + + InformationSchemaScanner scanner = + new InformationSchemaScanner(context, Dialect.GOOGLE_STANDARD_SQL); + scanner.scan(); + } + + @Test(expected = IllegalStateException.class) + public void testListTables_UnsupportedOnDelete() { + ReadContext context = mock(ReadContext.class); + Statement listTables = + Statement.of( + "SELECT t.table_name, t.parent_table_name, t.on_delete_action, t.interleave_type" + + " FROM information_schema.tables AS t" + + " WHERE t.table_catalog = '' AND t.table_schema = ''" + + " AND t.table_type='BASE TABLE'"); + ResultSet listTablesResultSet = mock(ResultSet.class); + when(context.executeQuery(listTables)).thenReturn(listTablesResultSet); + when(listTablesResultSet.next()).thenReturn(true, false); + when(listTablesResultSet.getString(0)).thenReturn("album"); + when(listTablesResultSet.getString(1)).thenReturn("singer"); + when(listTablesResultSet.getString(2)).thenReturn("RESTRICT"); // Unsupported! + when(listTablesResultSet.getString(3)).thenReturn("IN PARENT"); + + InformationSchemaScanner scanner = + new InformationSchemaScanner(context, Dialect.GOOGLE_STANDARD_SQL); + scanner.scan(); + } + + @Test + public void testSkipNonExistentTable_AllQueries() { + ReadContext context = mock(ReadContext.class); + + // Setup default answer for other queries to avoid NPE + ResultSet emptyResultSet = mock(ResultSet.class); + when(emptyResultSet.next()).thenReturn(false); + when(context.executeQuery(any(Statement.class))).thenReturn(emptyResultSet); + + // Mock listTables to return only "singer" + Statement listTables = + Statement.of( + "SELECT t.table_name, t.parent_table_name, t.on_delete_action, t.interleave_type" + + " FROM information_schema.tables AS t" + + " WHERE t.table_catalog = '' AND t.table_schema = ''" + + " AND t.table_type='BASE TABLE'"); + ResultSet listTablesResultSet = mock(ResultSet.class); + when(context.executeQuery(listTables)).thenReturn(listTablesResultSet); + when(listTablesResultSet.next()).thenReturn(true, false); + when(listTablesResultSet.getString(0)).thenReturn("singer"); + when(listTablesResultSet.isNull(1)).thenReturn(true); + when(listTablesResultSet.isNull(2)).thenReturn(true); + when(listTablesResultSet.isNull(3)).thenReturn(true); + + // Mock listColumns to return column for "missing_table" + Statement listColumns = + Statement.of( + "SELECT c.table_name, c.column_name," + + " c.ordinal_position, c.spanner_type, c.is_nullable," + + " c.is_generated, c.generation_expression, c.is_stored" + + " FROM information_schema.columns as c" + + " WHERE c.table_catalog = '' AND c.table_schema = '' " + + " AND c.spanner_state = 'COMMITTED' " + + " ORDER BY c.table_name, c.ordinal_position"); + ResultSet listColumnsResultSet = mock(ResultSet.class); + when(context.executeQuery(listColumns)).thenReturn(listColumnsResultSet); + when(listColumnsResultSet.next()).thenReturn(true, false); + when(listColumnsResultSet.getString(0)).thenReturn("missing_table"); + when(listColumnsResultSet.getString(1)).thenReturn("col1"); + when(listColumnsResultSet.getString(3)).thenReturn("INT64"); + when(listColumnsResultSet.getString(4)).thenReturn("NO"); + when(listColumnsResultSet.getString(5)).thenReturn("NO"); + when(listColumnsResultSet.isNull(6)).thenReturn(true); + when(listColumnsResultSet.isNull(7)).thenReturn(true); + + // Mock listIndexes to return index for "missing_table" + Statement listIndexes = + Statement.of( + "SELECT t.table_name, t.index_name, t.parent_table_name, t.is_unique," + + " t.is_null_filtered" + + " FROM information_schema.indexes AS t" + + " WHERE t.table_catalog = '' AND t.table_schema = '' AND" + + " t.index_type='INDEX' AND t.spanner_is_managed = FALSE" + + " ORDER BY t.table_name, t.index_name"); + ResultSet listIndexesResultSet = mock(ResultSet.class); + when(context.executeQuery(listIndexes)).thenReturn(listIndexesResultSet); + when(listIndexesResultSet.next()).thenReturn(true, false); + when(listIndexesResultSet.getString(0)).thenReturn("missing_table"); + when(listIndexesResultSet.getString(1)).thenReturn("index1"); + when(listIndexesResultSet.isNull(2)).thenReturn(true); + when(listIndexesResultSet.getBoolean(3)).thenReturn(false); + when(listIndexesResultSet.getBoolean(4)).thenReturn(false); + + // Mock listColumnOptions to return option for "missing_table" + Statement listColumnOptions = + Statement.of( + "SELECT t.table_name, t.column_name, t.option_name, t.option_type," + + " t.option_value" + + " FROM information_schema.column_options AS t" + + " WHERE t.table_catalog = '' AND t.table_schema = ''" + + " ORDER BY t.table_name, t.column_name"); + ResultSet listColumnOptionsResultSet = mock(ResultSet.class); + when(context.executeQuery(listColumnOptions)).thenReturn(listColumnOptionsResultSet); + when(listColumnOptionsResultSet.next()).thenReturn(true, false); + when(listColumnOptionsResultSet.getString(0)).thenReturn("missing_table"); + when(listColumnOptionsResultSet.getString(1)).thenReturn("col1"); + when(listColumnOptionsResultSet.getString(2)).thenReturn("option1"); + when(listColumnOptionsResultSet.getString(3)).thenReturn("STRING"); + when(listColumnOptionsResultSet.getString(4)).thenReturn("value1"); + + // Mock listForeignKeys to return index for "missing_table" + Statement listForeignKeys = + Statement.of( + "SELECT rc.constraint_name," + + " kcu1.table_name," + + " kcu1.column_name," + + " kcu2.table_name," + + " kcu2.column_name" + + " FROM information_schema.referential_constraints as rc" + + " INNER JOIN information_schema.key_column_usage as kcu1" + + " ON kcu1.constraint_catalog = rc.constraint_catalog" + + " AND kcu1.constraint_schema = rc.constraint_schema" + + " AND kcu1.constraint_name = rc.constraint_name" + + " INNER JOIN information_schema.key_column_usage as kcu2" + + " ON kcu2.constraint_catalog = rc.unique_constraint_catalog" + + " AND kcu2.constraint_schema = rc.unique_constraint_schema" + + " AND kcu2.constraint_name = rc.unique_constraint_name" + + " AND kcu2.ordinal_position = kcu1.position_in_unique_constraint" + + " WHERE rc.constraint_catalog = ''" + + " AND rc.constraint_schema = ''" + + " AND kcu1.constraint_catalog = ''" + + " AND kcu1.constraint_schema = ''" + + " AND kcu2.constraint_catalog = ''" + + " AND kcu2.constraint_schema = ''" + + " ORDER BY rc.constraint_name, kcu1.ordinal_position;"); + ResultSet listForeignKeysResultSet = mock(ResultSet.class); + when(context.executeQuery(listForeignKeys)).thenReturn(listForeignKeysResultSet); + when(listForeignKeysResultSet.next()).thenReturn(true, false); + when(listForeignKeysResultSet.getString(0)).thenReturn("fk1"); + when(listForeignKeysResultSet.getString(1)).thenReturn("missing_table"); + when(listForeignKeysResultSet.getString(2)).thenReturn("col1"); + when(listForeignKeysResultSet.getString(3)).thenReturn("singer"); + when(listForeignKeysResultSet.getString(4)).thenReturn("singerId"); + + // Mock listCheckConstraints to return option for "missing_table" + Statement listCheckConstraints = + Statement.of( + "SELECT ctu.TABLE_NAME," + + " cc.CONSTRAINT_NAME," + + " cc.CHECK_CLAUSE" + + " FROM INFORMATION_SCHEMA.CONSTRAINT_TABLE_USAGE as ctu" + + " INNER JOIN @{JOIN_METHOD=HASH_JOIN} INFORMATION_SCHEMA.CHECK_CONSTRAINTS as cc" + + " ON ctu.constraint_catalog = cc.constraint_catalog" + + " AND ctu.constraint_schema = cc.constraint_schema" + + " AND ctu.CONSTRAINT_NAME = cc.CONSTRAINT_NAME" + + " WHERE NOT STARTS_WITH(cc.CONSTRAINT_NAME, 'CK_IS_NOT_NULL_')" + + " AND ctu.table_catalog = ''" + + " AND ctu.table_schema = ''" + + " AND ctu.constraint_catalog = ''" + + " AND ctu.constraint_schema = ''" + + " AND cc.SPANNER_STATE = 'COMMITTED';"); + ResultSet listCheckConstraintsResultSet = mock(ResultSet.class); + when(context.executeQuery(listCheckConstraints)).thenReturn(listCheckConstraintsResultSet); + when(listCheckConstraintsResultSet.next()).thenReturn(true, false); + when(listCheckConstraintsResultSet.getString(0)).thenReturn("missing_table"); + when(listCheckConstraintsResultSet.getString(1)).thenReturn("check1"); + when(listCheckConstraintsResultSet.getString(2)).thenReturn("col1!=NULL"); + + InformationSchemaScanner scanner = + new InformationSchemaScanner(context, Dialect.GOOGLE_STANDARD_SQL); + Ddl ddl = scanner.scan(); + + assertTrue(ddl.table("missing_table") == null); + } } diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/convertors/ChangeEventSessionConvertorTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/convertors/ChangeEventSessionConvertorTest.java index 4ee8c6c894..79e065970f 100644 --- a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/convertors/ChangeEventSessionConvertorTest.java +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/convertors/ChangeEventSessionConvertorTest.java @@ -16,6 +16,7 @@ package com.google.cloud.teleport.v2.spanner.migrations.convertors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -80,12 +81,15 @@ public void setUp() throws IOException { private void mockDbClient() throws IOException { databaseClient = mock(DatabaseClient.class); queryReadContext = mock(ReadContext.class); - queryResultSet = mock(ResultSet.class); + ResultSet rs1 = mock(ResultSet.class); + ResultSet rs2 = mock(ResultSet.class); when(databaseClient.singleUse()).thenReturn(queryReadContext); - when(queryReadContext.executeQuery(any(Statement.class))).thenReturn(queryResultSet); - when(queryResultSet.next()).thenReturn(true, false); // only return one row - when(queryResultSet.getJson(any(String.class))) + when(queryReadContext.executeQuery(any(Statement.class))).thenReturn(rs1, rs2); + when(rs1.next()).thenReturn(true, false); + when(rs1.getJson(any(String.class))) .thenReturn("{\"a\": 1.3542, \"b\": {\"c\": 48.198136676310106}}"); + when(rs2.next()).thenReturn(true, false); + when(rs2.getJson(any(String.class))).thenReturn("{\"pg\": 1}"); } @Test @@ -147,6 +151,7 @@ public void transformChangeEventDataTest() throws Exception { JSONObject changeEvent = new JSONObject(); changeEvent.put("first_name", "A"); changeEvent.put("last_name", "{\"a\": 1.3542, \"b\": {\"c\": 48.19813667631011}}"); + changeEvent.put("pg_json_col", "{\"pg\": 1.0}"); changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "Users"); JsonNode ce = parseChangeEvent(changeEvent.toString()); @@ -156,6 +161,7 @@ public void transformChangeEventDataTest() throws Exception { changeEvent = new JSONObject(); changeEvent.put("first_name", "A"); changeEvent.put("last_name", "{\"a\": 1.3542, \"b\": {\"c\": 48.198136676310106}}"); + changeEvent.put("pg_json_col", "{\"pg\": 1}"); changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "Users"); JsonNode expectedEvent = parseChangeEvent(changeEvent.toString()); @@ -173,6 +179,9 @@ static Ddl getTestDdl() { .column("last_name") .json() .endColumn() + .column("pg_json_col") + .pgJsonb() + .endColumn() .endTable() .build(); return ddl; @@ -713,4 +722,121 @@ public void testPopulateShardIdReturnsEmptyWhenValueAndMetadataMissing() { assertEquals("", actualEvent.get("migration_shard_id").asText()); assertEquals("migration_shard_id", actualEvent.get(Constants.SHARD_ID_COLUMN_NAME).asText()); } + + @Test + public void testGetShardId_MissingSchemaKey_ShardingContext() { + Schema schema = getShardedSchemaObject(); + ShardingContext shardingContext = getShardingContext(); + ChangeEventSessionConvertor changeEventSessionConvertor = + new ChangeEventSessionConvertor( + schema, null, new TransformationContext(), shardingContext, "mysql", false); + + JSONObject changeEvent = new JSONObject(); + changeEvent.put("name", "A"); + changeEvent.put(Constants.EVENT_STREAM_NAME, "stream1"); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "people"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + String shardId = changeEventSessionConvertor.getShardId(ce); + assertEquals("", shardId); + } + + @Test + public void testGetShardId_MissingSchemaKey_TransformationContext() { + Schema schema = getShardedSchemaObject(); + TransformationContext transformationContext = getTransformationContext(); + ChangeEventSessionConvertor changeEventSessionConvertor = + new ChangeEventSessionConvertor( + schema, null, transformationContext, new ShardingContext(), "mysql", false); + + JSONObject changeEvent = new JSONObject(); + changeEvent.put("name", "A"); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "people"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + String shardId = changeEventSessionConvertor.getShardId(ce); + assertEquals("", shardId); + } + + @Test(expected = Exception.class) + public void testTransformChangeEventData_TableNotFound() throws Exception { + ChangeEventSessionConvertor convertor = + new ChangeEventSessionConvertor(null, null, null, null, "", true); + JSONObject changeEvent = new JSONObject(); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "MissingTable"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + convertor.transformChangeEventData(ce, databaseClient, Ddl.builder().build()); + } + + @Test + public void testTransformChangeEventData_NullJsonStr() throws Exception { + ChangeEventSessionConvertor convertor = + new ChangeEventSessionConvertor(null, null, null, null, "", true); + JSONObject changeEvent = new JSONObject(); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "Users"); + // Do NOT put "last_name" column! + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + Ddl ddl = getTestDdl(); // Users table has "last_name" as JSON + JsonNode actual = convertor.transformChangeEventData(ce, databaseClient, ddl); + assertEquals(ce, actual); // No change + } + + @Test + public void testTransformViaSessionFile_NullShardingMap() { + Schema schema = getShardedSchemaObject(); + ShardingContext shardingContext = new ShardingContext(null); + ChangeEventSessionConvertor changeEventSessionConvertor = + new ChangeEventSessionConvertor( + schema, null, new TransformationContext(), shardingContext, "mysql", false); + + JSONObject changeEvent = new JSONObject(); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "cart"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + JsonNode actual = changeEventSessionConvertor.transformChangeEventViaSessionFile(ce); + assertFalse(actual.has(Constants.SHARD_ID_COLUMN_NAME)); + } + + @Test + public void testPopulateShardId_NullShardIdColumn() { + Schema schema = getShardedSchemaObject(); + ChangeEventSessionConvertor changeEventSessionConvertor = + new ChangeEventSessionConvertor( + schema, null, new TransformationContext(), getShardingContext(), "mysql", false); + + JSONObject changeEvent = new JSONObject(); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "cart"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + JsonNode actual = changeEventSessionConvertor.transformChangeEventViaSessionFile(ce); + assertFalse(actual.has(Constants.SHARD_ID_COLUMN_NAME)); + } + + @Test + public void testPopulateShardId_MissingShardIdColDef() { + Schema schema = getShardedSchemaObject(); + SpannerTable table = schema.getSpSchema().get("t2"); + Map colDefs = new HashMap<>(table.getColDefs()); + colDefs.remove("c6"); + SpannerTable malformedTable = + new SpannerTable( + table.getName(), + table.getColIds(), + colDefs, + table.getPrimaryKeys(), + table.getShardIdColumn()); + schema.getSpSchema().put("t2", malformedTable); + + ChangeEventSessionConvertor changeEventSessionConvertor = + new ChangeEventSessionConvertor( + schema, null, new TransformationContext(), getShardingContext(), "mysql", false); + + JSONObject changeEvent = new JSONObject(); + changeEvent.put(Constants.EVENT_TABLE_NAME_KEY, "people"); + JsonNode ce = parseChangeEvent(changeEvent.toString()); + + JsonNode actual = changeEventSessionConvertor.transformChangeEventViaSessionFile(ce); + assertFalse(actual.has(Constants.SHARD_ID_COLUMN_NAME)); + } } diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SourceTableTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SourceTableTest.java new file mode 100644 index 0000000000..93e3cb7194 --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SourceTableTest.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.schema; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.junit.Test; + +public class SourceTableTest { + + @Test + public void testConstructorAndGetters() { + String name = "test_table"; + String schema = "test_schema"; + String[] colIds = new String[] {"c1", "c2"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SourceColumnDefinition("col1", null)); + colDefs.put("c2", new SourceColumnDefinition("col2", null)); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SourceTable table = new SourceTable(name, schema, colIds, colDefs, primaryKeys); + + assertEquals(name, table.getName()); + assertEquals(schema, table.getSchema()); + assertEquals(colIds, table.getColIds()); + assertEquals(colDefs, table.getColDefs()); + assertEquals(primaryKeys, table.getPrimaryKeys()); + } + + @Test + public void testConstructor_NullArrays() { + SourceTable table = new SourceTable("name", "schema", null, null, null); + assertNotNull(table.getColIds()); + assertEquals(0, table.getColIds().length); + assertNotNull(table.getColDefs()); + assertTrue(table.getColDefs().isEmpty()); + assertTrue(table.getPrimaryKeys() == null); + } + + @Test + public void testGetPrimaryKeySet() { + String[] colIds = new String[] {"c1", "c2"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SourceColumnDefinition("col1", null)); + colDefs.put("c2", new SourceColumnDefinition("col2", null)); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SourceTable table = new SourceTable("name", "schema", colIds, colDefs, primaryKeys); + Set pkSet = table.getPrimaryKeySet(); + + assertEquals(1, pkSet.size()); + assertTrue(pkSet.contains("col1")); + } + + @Test + public void testGetPrimaryKeySet_NullPrimaryKeys() { + SourceTable table = new SourceTable("name", "schema", null, null, null); + Set pkSet = table.getPrimaryKeySet(); + assertTrue(pkSet.isEmpty()); + } + + @Test + public void testToString() { + String[] colIds = new String[] {"c1"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SourceColumnDefinition("col1", null)); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SourceTable table = new SourceTable("name", "schema", colIds, colDefs, primaryKeys); + String str = table.toString(); + + assertTrue(str.contains("name")); + assertTrue(str.contains("schema")); + assertTrue(str.contains("c1")); + } + + @Test + public void testEqualsAndHashCode() { + String[] colIds = new String[] {"c1"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SourceColumnDefinition("col1", null)); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SourceTable table1 = new SourceTable("name", "schema", colIds, colDefs, primaryKeys); + SourceTable table2 = new SourceTable("name", "schema", colIds, colDefs, primaryKeys); + SourceTable table3 = new SourceTable("different", "schema", colIds, colDefs, primaryKeys); + + assertTrue(table1.equals(table1)); + assertTrue(table1.equals(table2)); + assertFalse(table1.equals(table3)); + assertFalse(table1.equals(null)); + assertFalse(table1.equals("string")); + + assertEquals(table1.hashCode(), table2.hashCode()); + assertTrue(table1.hashCode() != table3.hashCode()); + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SpannerTableTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SpannerTableTest.java new file mode 100644 index 0000000000..d5849acc45 --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/schema/SpannerTableTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.schema; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.junit.Test; + +public class SpannerTableTest { + + @Test + public void testConstructorAndGetters() { + String name = "test_table"; + String[] colIds = new String[] {"c1", "c2"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SpannerColumnDefinition("col1", new SpannerColumnType("STRING", false))); + colDefs.put("c2", new SpannerColumnDefinition("col2", new SpannerColumnType("INT64", false))); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + String shardIdColumn = "c2"; + + SpannerTable table = new SpannerTable(name, colIds, colDefs, primaryKeys, shardIdColumn); + + assertEquals(name, table.getName()); + assertEquals(colIds, table.getColIds()); + assertEquals(colDefs, table.getColDefs()); + assertEquals(primaryKeys, table.getPrimaryKeys()); + assertEquals(shardIdColumn, table.getShardIdColumn()); + } + + @Test + public void testConstructor_NullArrays() { + SpannerTable table = new SpannerTable("name", null, null, null, null); + assertNotNull(table.getColIds()); + assertEquals(0, table.getColIds().length); + assertNotNull(table.getColDefs()); + assertTrue(table.getColDefs().isEmpty()); + assertNotNull(table.getPrimaryKeys()); + assertEquals(0, table.getPrimaryKeys().length); + assertTrue(table.getShardIdColumn() == null); + } + + @Test + public void testGetPrimaryKeySet() { + String[] colIds = new String[] {"c1", "c2"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SpannerColumnDefinition("col1", new SpannerColumnType("STRING", false))); + colDefs.put("c2", new SpannerColumnDefinition("col2", new SpannerColumnType("INT64", false))); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SpannerTable table = new SpannerTable("name", colIds, colDefs, primaryKeys, null); + Set pkSet = table.getPrimaryKeySet(); + + assertEquals(1, pkSet.size()); + assertTrue(pkSet.contains("col1")); + } + + @Test + public void testGetPrimaryKeySet_NullPrimaryKeys() { + SpannerTable table = new SpannerTable("name", null, null, null, null); + Set pkSet = table.getPrimaryKeySet(); + assertTrue(pkSet.isEmpty()); + } + + @Test + public void testToString() { + String[] colIds = new String[] {"c1"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SpannerColumnDefinition("col1", new SpannerColumnType("STRING", false))); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SpannerTable table = new SpannerTable("name", colIds, colDefs, primaryKeys, "c1"); + String str = table.toString(); + + assertTrue(str.contains("name")); + assertTrue(str.contains("c1")); + assertTrue(str.contains("shardIdColumn")); + } + + @Test + public void testEqualsAndHashCode() { + String[] colIds = new String[] {"c1"}; + Map colDefs = new HashMap<>(); + colDefs.put("c1", new SpannerColumnDefinition("col1", new SpannerColumnType("STRING", false))); + ColumnPK[] primaryKeys = new ColumnPK[] {new ColumnPK("c1", 1)}; + + SpannerTable table1 = new SpannerTable("name", colIds, colDefs, primaryKeys, "c1"); + SpannerTable table2 = new SpannerTable("name", colIds, colDefs, primaryKeys, "c1"); + SpannerTable table3 = new SpannerTable("different", colIds, colDefs, primaryKeys, "c1"); + SpannerTable table4 = new SpannerTable("name", colIds, colDefs, primaryKeys, "different"); + + assertTrue(table1.equals(table1)); + assertTrue(table1.equals(table2)); + assertFalse(table1.equals(table3)); + assertFalse(table1.equals(table4)); + assertFalse(table1.equals(null)); + assertFalse(table1.equals("string")); + + assertEquals(table1.hashCode(), table2.hashCode()); + assertTrue(table1.hashCode() != table3.hashCode()); + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/testutils/failureinjectiontesting/UserTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/testutils/failureinjectiontesting/UserTest.java new file mode 100644 index 0000000000..f6bc835899 --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/testutils/failureinjectiontesting/UserTest.java @@ -0,0 +1,209 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.testutils.failureinjectiontesting; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.Struct; +import com.google.common.collect.ImmutableList; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.beam.it.gcp.spanner.SpannerResourceManager; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +public class UserTest { + + @Test + public void testEqualsAndHashCode() { + User u1 = new User(); + u1.id = 1; + u1.firstName = "John"; + u1.lastName = "Doe"; + u1.age = 30; + u1.status = true; + u1.col1 = 100L; + u1.col2 = 200L; + + User u2 = new User(); + u2.id = 1; + u2.firstName = "John"; + u2.lastName = "Doe"; + u2.age = 30; + u2.status = true; + u2.col1 = 100L; + u2.col2 = 200L; + + User u3 = new User(); + u3.id = 2; + + assertTrue(u1.equals(u1)); + assertTrue(u1.equals(u2)); + assertFalse(u1.equals(u3)); + assertFalse(u1.equals(null)); + assertFalse(u1.equals("string")); + + assertEquals(u1.hashCode(), u2.hashCode()); + + User u4 = new User(); + u4.id = 1; + u4.age = 31; // Different age + assertFalse(u1.equals(u4)); + + User u5 = new User(); + u5.id = 1; + u5.status = false; // Different status + assertFalse(u1.equals(u5)); + + User u6 = new User(); + u6.id = 1; + u6.col1 = 101L; // Different col1 + assertFalse(u1.equals(u6)); + + User u7 = new User(); + u7.id = 1; + u7.col2 = 201L; // Different col2 + assertFalse(u1.equals(u7)); + + User u8 = new User(); + u8.id = 1; + u8.firstName = "Jane"; // Different firstName + assertFalse(u1.equals(u8)); + + User u9 = new User(); + u9.id = 1; + u9.lastName = "Smith"; // Different lastName + assertFalse(u1.equals(u9)); + } + + @Test + public void testToString() { + User u = new User(); + u.id = 1; + u.firstName = "John"; + String str = u.toString(); + assertTrue(str.contains("User")); + assertTrue(str.contains("John")); + } + + @Test + public void testUpdate_JDBC() throws SQLException { + User u = new User(); + u.id = 1; + u.firstName = "John"; + u.lastName = "Doe"; + u.age = 30; + u.status = true; + + Connection conn = mock(Connection.class); + PreparedStatement ps = mock(PreparedStatement.class); + when(conn.prepareStatement(anyString())).thenReturn(ps); + + Random random = mock(Random.class); + // Test all columns in switch + for (int i = 0; i < User.UPDATABLE_COLUMNS.size(); i++) { + when(random.nextInt(User.UPDATABLE_COLUMNS.size())).thenReturn(i); + u.update(conn, random); + } + // Verify that executeUpdate was called for each + verify(ps, times(User.UPDATABLE_COLUMNS.size())).executeUpdate(); + } + + @Test + public void testInsert_Spanner() { + User u = new User(); + u.id = 1; + u.firstName = "John"; + + DatabaseClient client = mock(DatabaseClient.class); + u.insert(client); + + ArgumentCaptor captor = ArgumentCaptor.forClass(List.class); + verify(client).write(captor.capture()); + assertEquals(1, captor.getValue().size()); + } + + @Test + public void testFetchAll_Spanner() { + SpannerResourceManager manager = mock(SpannerResourceManager.class); + Struct struct = mock(Struct.class); + when(struct.getLong(User.ID)).thenReturn(1L); + when(struct.isNull(User.FIRST_NAME)).thenReturn(false); + when(struct.getString(User.FIRST_NAME)).thenReturn("John"); + when(struct.isNull(User.LAST_NAME)).thenReturn(true); + when(struct.isNull(User.AGE)).thenReturn(false); + when(struct.getLong(User.AGE)).thenReturn(30L); + when(struct.isNull(User.STATUS)).thenReturn(false); + when(struct.getBoolean(User.STATUS)).thenReturn(true); + when(struct.isNull(User.COL1)).thenReturn(true); + when(struct.isNull(User.COL2)).thenReturn(true); + + when(manager.runQuery(anyString())).thenReturn(ImmutableList.of(struct)); + + Map users = User.fetchAll(manager); + assertEquals(1, users.size()); + User u = users.get(1); + assertNotNull(u); + assertEquals("John", u.firstName); + assertEquals(null, u.lastName); + assertEquals(30, u.age); + assertTrue(u.status); + assertEquals(0L, u.col1); + assertEquals(0L, u.col2); + } + + @Test + public void testFetchAll_Spanner_NullValues() { + SpannerResourceManager manager = mock(SpannerResourceManager.class); + Struct struct = mock(Struct.class); + when(struct.getLong(User.ID)).thenReturn(1L); + when(struct.isNull(User.FIRST_NAME)).thenReturn(true); + when(struct.isNull(User.LAST_NAME)).thenReturn(false); + when(struct.getString(User.LAST_NAME)).thenReturn("Doe"); + when(struct.isNull(User.AGE)).thenReturn(true); + when(struct.isNull(User.STATUS)).thenReturn(true); + when(struct.isNull(User.COL1)).thenReturn(false); + when(struct.getLong(User.COL1)).thenReturn(100L); + when(struct.isNull(User.COL2)).thenReturn(false); + when(struct.getLong(User.COL2)).thenReturn(200L); + + when(manager.runQuery(anyString())).thenReturn(ImmutableList.of(struct)); + + Map users = User.fetchAll(manager); + assertEquals(1, users.size()); + User u = users.get(1); + assertNotNull(u); + assertEquals(null, u.firstName); + assertEquals("Doe", u.lastName); + assertEquals(0, u.age); + assertFalse(u.status); + assertEquals(100L, u.col1); + assertEquals(200L, u.col2); + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/type/TypeTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/type/TypeTest.java index fe49ca5088..5df70b38a7 100644 --- a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/type/TypeTest.java +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/type/TypeTest.java @@ -16,6 +16,8 @@ package com.google.cloud.teleport.v2.spanner.type; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import org.junit.Test; @@ -62,4 +64,84 @@ public void testToString() { assertEquals("PG_TEXT[]", Type.pgArray(Type.pgText()).toString()); assertEquals("PG_COMMIT_TIMESTAMP", Type.pgCommitTimestamp().toString()); } + + @Test + public void testEqualsAndHashCode() { + Type t1 = Type.int64(); + Type t2 = Type.int64(); + Type t3 = Type.string(); + + assertTrue(t1.equals(t1)); + assertTrue(t1.equals(t2)); + assertFalse(t1.equals(t3)); + assertFalse(t1.equals(null)); + assertFalse(t1.equals("string")); + + assertEquals(t1.hashCode(), t2.hashCode()); + + // Test with array types + Type a1 = Type.array(Type.int64()); + Type a2 = Type.array(Type.int64()); + Type a3 = Type.array(Type.string()); + assertTrue(a1.equals(a2)); + assertFalse(a1.equals(a3)); + + // Test with struct types + Type s1 = Type.struct(Type.StructField.of("f1", Type.int64())); + Type s2 = Type.struct(Type.StructField.of("f1", Type.int64())); + Type s3 = Type.struct(Type.StructField.of("f2", Type.int64())); + assertTrue(s1.equals(s2)); + assertFalse(s1.equals(s3)); + } + + @Test + public void testStructFieldEqualsAndHashCode() { + Type.StructField f1 = Type.StructField.of("f1", Type.int64()); + Type.StructField f2 = Type.StructField.of("f1", Type.int64()); + Type.StructField f3 = Type.StructField.of("f2", Type.int64()); + Type.StructField f4 = Type.StructField.of("f1", Type.string()); + + assertTrue(f1.equals(f1)); + assertTrue(f1.equals(f2)); + assertFalse(f1.equals(f3)); + assertFalse(f1.equals(f4)); + assertFalse(f1.equals(null)); + assertFalse(f1.equals("string")); + + assertEquals(f1.hashCode(), f2.hashCode()); + } + + @Test + public void testGetArrayElementType() { + Type a1 = Type.array(Type.int64()); + assertEquals(Type.int64(), a1.getArrayElementType()); + + Type a2 = Type.pgArray(Type.pgInt8()); + assertEquals(Type.pgInt8(), a2.getArrayElementType()); + } + + @Test(expected = IllegalStateException.class) + public void testGetArrayElementType_Exception() { + Type.int64().getArrayElementType(); + } + + @Test + public void testPgArray_AllTypes() { + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgBool()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgInt8()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgFloat4()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgFloat8()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgNumeric()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgJsonb()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgVarchar()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgText()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgBytea()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgTimestamptz()).getCode()); + assertEquals(Type.Code.PG_ARRAY, Type.pgArray(Type.pgDate()).getCode()); + } + + @Test(expected = IllegalArgumentException.class) + public void testPgArray_UnknownType() { + Type.pgArray(Type.int64()); // int64 is not a PG type in this context! + } } diff --git a/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationForDLQITTest.java b/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationForDLQITTest.java new file mode 100644 index 0000000000..c3d9a95cdf --- /dev/null +++ b/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationForDLQITTest.java @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.custom; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; + +public class CustomTransformationForDLQITTest { + + @Test + public void testInit() { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=semi-fixed"); + // Verified via behavior in other tests + } + + @Test + public void testToSpannerRow_AllDataTypes_BadMode_CrashingId() { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=bad"); + + Map requestRow = new HashMap<>(); + requestRow.put("id", 999L); + MigrationTransformationRequest request = + new MigrationTransformationRequest("AllDataTypes", requestRow, "shard1", "INSERT"); + + try { + transformer.toSpannerRow(request); + fail("Expected InvalidTransformationException"); + } catch (InvalidTransformationException e) { + assertEquals("Simulated failure for id 999", e.getMessage()); + } + } + + @Test + public void testToSpannerRow_AllDataTypes_BadMode_NonCrashingId() throws Exception { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=bad"); + + Map requestRow = new HashMap<>(); + requestRow.put("id", 888L); // 888 doesn't crash in toSpannerRow even in bad mode! + MigrationTransformationRequest request = + new MigrationTransformationRequest("AllDataTypes", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertFalse(response.isEventFiltered()); + } + + @Test + public void testToSpannerRow_Orders() throws Exception { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("OrderSource", "WEB"); + MigrationTransformationRequest request = + new MigrationTransformationRequest("Orders", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertEquals("'WEB_v1'", response.getResponseRow().get("LegacyOrderSystem")); + } + + @Test + public void testToSourceRow_AllDataTypes_BadMode_CrashingId() { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=bad"); + + Map requestRow = new HashMap<>(); + requestRow.put("id", 888L); + MigrationTransformationRequest request = + new MigrationTransformationRequest("AllDataTypes", requestRow, "shard1", "INSERT"); + + try { + transformer.toSourceRow(request); + fail("Expected InvalidTransformationException"); + } catch (InvalidTransformationException e) { + assertEquals("Simulated failure for id 888", e.getMessage()); + } + } + + @Test + public void testToSourceRow_AllDataTypes_SemiFixedMode_CrashingId() { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=semi-fixed"); + + Map requestRow = new HashMap<>(); + requestRow.put("id", 888L); + MigrationTransformationRequest request = + new MigrationTransformationRequest("AllDataTypes", requestRow, "shard1", "INSERT"); + + try { + transformer.toSourceRow(request); + fail("Expected InvalidTransformationException"); + } catch (InvalidTransformationException e) { + assertEquals("Simulated failure for id 888", e.getMessage()); + } + } + + @Test + public void testToSourceRow_AllDataTypes_SemiFixedMode_NonCrashingId() throws Exception { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + transformer.init("mode=semi-fixed"); + + Map requestRow = new HashMap<>(); + requestRow.put("id", 999L); // 999 doesn't crash in semi-fixed mode in toSourceRow! + MigrationTransformationRequest request = + new MigrationTransformationRequest("AllDataTypes", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSourceRow(request); + assertNotNull(response); + } + + @Test + public void testTransformFailedSpannerMutation() throws Exception { + CustomTransformationForDLQIT transformer = new CustomTransformationForDLQIT(); + MigrationTransformationResponse response = transformer.transformFailedSpannerMutation(null); + assertNotNull(response); + assertTrue(response.getResponseRow().isEmpty()); + } +} diff --git a/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationWithShardForBulkITTest.java b/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationWithShardForBulkITTest.java new file mode 100644 index 0000000000..b924e4097e --- /dev/null +++ b/v2/spanner-custom-shard/src/test/java/com/custom/CustomTransformationWithShardForBulkITTest.java @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.custom; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; + +public class CustomTransformationWithShardForBulkITTest { + + @Test + public void testToSpannerRow_Customers() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("first_name", "John"); + requestRow.put("last_name", "Doe"); + requestRow.put("id", 1L); + MigrationTransformationRequest request = + new MigrationTransformationRequest("Customers", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertEquals("John Doe", response.getResponseRow().get("full_name")); + assertEquals("shard1_1", response.getResponseRow().get("migration_shard_id")); + assertFalse(response.isEventFiltered()); + } + + @Test + public void testToSpannerRow_AllDatatypeTransformation_FilterEvent() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("varchar_column", "example1"); + MigrationTransformationRequest request = + new MigrationTransformationRequest( + "AllDatatypeTransformation", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertTrue(response.isEventFiltered()); + assertTrue(response.getResponseRow() == null); + } + + @Test + public void testToSpannerRow_AllDatatypeTransformation_UpdateInsert() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("varchar_column", "example2"); + MigrationTransformationRequest request = + new MigrationTransformationRequest( + "AllDatatypeTransformation", requestRow, "shard1", "UPDATE-INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertFalse(response.isEventFiltered()); + assertTrue(response.getResponseRow() == null); + } + + @Test + public void testToSpannerRow_AllDatatypeTransformation_Normal() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("varchar_column", "example2"); + requestRow.put("tinyint_column", 1L); + requestRow.put("text_column", "text"); + requestRow.put("int_column", 2L); + requestRow.put("bigint_column", 3L); + requestRow.put("float_column", 4.0); + requestRow.put("double_column", 5.0); + requestRow.put("decimal_column", "6.0"); + requestRow.put("bool_column", 0L); + requestRow.put("enum_column", "0"); + requestRow.put("blob_column", "blob"); + requestRow.put("binary_column", "bin"); + requestRow.put("bit_column", 12L); + requestRow.put("year_column", 2020L); + requestRow.put("date_column", "2020-01-01"); + requestRow.put("datetime_column", "2020-01-01T12:00:00Z"); + requestRow.put("timestamp_column", "2020-01-01T12:00:00Z"); + requestRow.put("time_column", "12:00:00"); + + MigrationTransformationRequest request = + new MigrationTransformationRequest( + "AllDatatypeTransformation", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertEquals(2L, response.getResponseRow().get("tinyint_column")); + assertEquals("text append", response.getResponseRow().get("text_column")); + assertEquals(3L, response.getResponseRow().get("int_column")); + assertEquals(4L, response.getResponseRow().get("bigint_column")); + assertEquals(5.0, response.getResponseRow().get("float_column")); + assertEquals(6.0, response.getResponseRow().get("double_column")); + assertEquals("7.0", response.getResponseRow().get("decimal_column")); + assertEquals(1, response.getResponseRow().get("bool_column")); + assertEquals("1", response.getResponseRow().get("enum_column")); + assertEquals("576f726d64", response.getResponseRow().get("blob_column")); + assertEquals("2020-01-02", response.getResponseRow().get("date_column")); + assertEquals("2020-01-01T11:59:59Z", response.getResponseRow().get("datetime_column")); + assertEquals("2020-01-01T11:59:59Z", response.getResponseRow().get("timestamp_column")); + assertEquals("13:00:00", response.getResponseRow().get("time_column")); + } + + @Test + public void testToSpannerRow_AllDatatypeTransformation_OptionalColumns() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + Map requestRow = new HashMap<>(); + requestRow.put("varchar_column", "example2"); + requestRow.put("varbinary_column", "val"); + requestRow.put("char_column", "val"); + requestRow.put("longblob_column", "val"); + requestRow.put("longtext_column", "val"); + requestRow.put("mediumblob_column", "val"); + requestRow.put("mediumint_column", 1L); + requestRow.put("mediumtext_column", "val"); + requestRow.put("set_column", "val"); + requestRow.put("smallint_column", 1L); + requestRow.put("tinyblob_column", "val"); + requestRow.put("tinytext_column", "val"); + requestRow.put("json_column", "val"); + + requestRow.put("tinyint_column", 1L); + requestRow.put("text_column", "text"); + requestRow.put("int_column", 2L); + requestRow.put("bigint_column", 3L); + requestRow.put("float_column", 4.0); + requestRow.put("double_column", 5.0); + requestRow.put("decimal_column", "6.0"); + requestRow.put("year_column", 2020L); + requestRow.put("date_column", "2020-01-01"); + requestRow.put("datetime_column", "2020-01-01T12:00:00Z"); + requestRow.put("timestamp_column", "2020-01-01T12:00:00Z"); + requestRow.put("time_column", "12:00:00"); + + MigrationTransformationRequest request = + new MigrationTransformationRequest( + "AllDatatypeTransformation", requestRow, "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertEquals( + "0102030405060708090A0B0C0D0E0F1011121314", + response.getResponseRow().get("varbinary_column")); + assertEquals("newchar", response.getResponseRow().get("char_column")); + assertEquals("576f726d64", response.getResponseRow().get("longblob_column")); + assertEquals("val append", response.getResponseRow().get("longtext_column")); + assertEquals("576f726d64", response.getResponseRow().get("mediumblob_column")); + assertEquals(2L, response.getResponseRow().get("mediumint_column")); + assertEquals("val append", response.getResponseRow().get("mediumtext_column")); + assertEquals("v3", response.getResponseRow().get("set_column")); + assertEquals(2L, response.getResponseRow().get("smallint_column")); + assertEquals("576f726d64", response.getResponseRow().get("tinyblob_column")); + assertEquals("val append", response.getResponseRow().get("tinytext_column")); + assertEquals("{\"k1\": \"v1\", \"k2\": \"v2\"}", response.getResponseRow().get("json_column")); + } + + @Test + public void testToSpannerRow_UnknownTable() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + + MigrationTransformationRequest request = + new MigrationTransformationRequest("Unknown", new HashMap<>(), "shard1", "INSERT"); + + MigrationTransformationResponse response = transformer.toSpannerRow(request); + assertNotNull(response); + assertTrue(response.getResponseRow() == null); + assertFalse(response.isEventFiltered()); + } + + @Test + public void testToSourceRow() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + MigrationTransformationResponse response = transformer.toSourceRow(null); + assertNotNull(response); + assertTrue(response.getResponseRow() == null); + assertFalse(response.isEventFiltered()); + } + + @Test + public void testTransformFailedSpannerMutation() throws Exception { + CustomTransformationWithShardForBulkIT transformer = + new CustomTransformationWithShardForBulkIT(); + MigrationTransformationResponse response = transformer.transformFailedSpannerMutation(null); + assertNotNull(response); + assertTrue(response.getResponseRow() == null); + assertFalse(response.isEventFiltered()); + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGenerator.java index 7ea205a53d..e8a61c072a 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGenerator.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGenerator.java @@ -269,7 +269,7 @@ private static DMLGeneratorResponse getUpsertStatementCQL( *

If no primary key column values are provided, an empty WHERE clause is generated. An * exception may be thrown if any value type does not match the expected type. */ - private static DMLGeneratorResponse getDeleteStatementCQL( + static DMLGeneratorResponse getDeleteStatementCQL( String tableName, java.sql.Timestamp timestamp, Map> allColumnNamesAndValues) { @@ -284,16 +284,19 @@ private static DMLGeneratorResponse getDeleteStatementCQL( List> values = new ArrayList<>(allColumnNamesAndValues.values()); + String preparedStatement; if (timestamp != null) { PreparedStatementValueObject timestampObj = PreparedStatementValueObject.create("USING_TIMESTAMP", timestamp.getTime()); values.add(0, timestampObj); + preparedStatement = + String.format( + "DELETE FROM %s USING TIMESTAMP ? WHERE %s", escapedTableName, deleteConditions); + } else { + preparedStatement = + String.format("DELETE FROM %s WHERE %s", escapedTableName, deleteConditions); } - String preparedStatement = - String.format( - "DELETE FROM %s USING TIMESTAMP ? WHERE %s", escapedTableName, deleteConditions); - return new PreparedStatementGeneratedResponse(preparedStatement, values); } @@ -315,7 +318,7 @@ private static DMLGeneratorResponse getDeleteStatementCQL( * `newValuesJson` and retrieves the appropriate value. 4. Skips columns that do not exist in * any of the JSON objects or are marked as null. */ - private static Map> getColumnValues( + static Map> getColumnValues( ISchemaMapper schemaMapper, Table spannerTable, SourceTable sourceTable, @@ -390,7 +393,7 @@ private static Map> getColumnValues( * `newValuesJson` and retrieves the appropriate value. 4. Returns null if any required * primary key column is missing in the JSON objects. */ - private static Map> getPkColumnValues( + static Map> getPkColumnValues( ISchemaMapper schemaMapper, Table spannerTable, SourceTable sourceTable, diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java index d4ee9352b1..cf530ea18e 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java @@ -21,6 +21,7 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; import com.google.cloud.teleport.v2.spanner.type.Type; import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; @@ -69,8 +70,7 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ throw new InvalidDMLGenerationException( "Could not find source table name for spanner table: " + spannerTableName, e); } - com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable sourceTable = - sourceSchema.table(sourceTableName); + SourceTable sourceTable = sourceSchema.table(sourceTableName); if (sourceTable == null) { throw new InvalidDMLGenerationException( String.format( @@ -179,7 +179,7 @@ private static DMLGeneratorResponse getDeleteStatement( private static DMLGeneratorResponse generateUpsertStatement( Table spannerTable, - com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable sourceTable, + SourceTable sourceTable, DMLGeneratorRequest dmlGeneratorRequest, Map pkcolumnNameValues) { Map columnNameValues = @@ -196,7 +196,8 @@ private static DMLGeneratorResponse generateUpsertStatement( return getUpsertStatement(sourceTable.name(), columnNameValues); } - private static String getMappedColumnValue( + @VisibleForTesting + static String getMappedColumnValue( Column spannerColDef, SourceColumn sourceColDef, JSONObject valuesJson, @@ -258,7 +259,8 @@ protected static String convertBase64ToHex(String base64EncodedString) { return rawHex.isEmpty() ? "x''" : "x'" + rawHex + "'"; } - private static String getColumnValueByType( + @VisibleForTesting + static String getColumnValueByType( String columnType, String colValue, String sourceDbTimezoneOffset, String spannerColType) { String response = ""; String cleanedNullBytes = ""; diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGenerator.java index 08b033b0a9..1e4edf7bc5 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGenerator.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGenerator.java @@ -21,6 +21,7 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; import com.google.cloud.teleport.v2.spanner.type.Type; import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; @@ -71,8 +72,7 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ throw new InvalidDMLGenerationException( "Could not find source table name for spanner table: " + spannerTableName, e); } - com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable sourceTable = - sourceSchema.table(sourceTableName); + SourceTable sourceTable = sourceSchema.table(sourceTableName); if (sourceTable == null) { throw new InvalidDMLGenerationException( String.format( @@ -190,7 +190,7 @@ private static DMLGeneratorResponse getDeleteStatement( private static DMLGeneratorResponse generateUpsertStatement( Table spannerTable, - com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable sourceTable, + SourceTable sourceTable, DMLGeneratorRequest dmlGeneratorRequest, Map pkcolumnNameValues) { Map columnNameValues = @@ -311,7 +311,7 @@ private static String escapeString(String input) { return cleanedNullBytes; } - private static String getQuotedEscapedString(String input, String spannerColType) { + static String getQuotedEscapedString(String input, String spannerColType) { if ("BYTES".equals(spannerColType) || "PG_BYTEA".equals(spannerColType)) { return input; } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/ChangeStreamErrorRecordTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/ChangeStreamErrorRecordTest.java new file mode 100644 index 0000000000..544bec9a31 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/ChangeStreamErrorRecordTest.java @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.changestream; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class ChangeStreamErrorRecordTest { + + @Test + public void testEquals() { + ChangeStreamErrorRecord record1 = new ChangeStreamErrorRecord("rec1", "err1"); + ChangeStreamErrorRecord record2 = new ChangeStreamErrorRecord("rec1", "err1"); + ChangeStreamErrorRecord record3 = new ChangeStreamErrorRecord("rec2", "err1"); + ChangeStreamErrorRecord record4 = new ChangeStreamErrorRecord("rec1", "err2"); + + assertTrue(record1.equals(record1)); // Same instance + assertTrue(record1.equals(record2)); // Equal + assertFalse(record1.equals(record3)); // Different record + assertFalse(record1.equals(record4)); // Different error + assertFalse(record1.equals(null)); // Null + assertFalse(record1.equals("string")); // Different class + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/TrimmedShardedDataChangeRecordTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/TrimmedShardedDataChangeRecordTest.java new file mode 100644 index 0000000000..a0e34b787e --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/changestream/TrimmedShardedDataChangeRecordTest.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.changestream; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.Timestamp; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ModType; +import org.junit.Test; + +public class TrimmedShardedDataChangeRecordTest { + + @Test + public void testEquals() { + Timestamp ts1 = Timestamp.ofTimeSecondsAndNanos(1000, 0); + Timestamp ts2 = Timestamp.ofTimeSecondsAndNanos(2000, 0); + Mod mod1 = new Mod("{\"id\": 1}", "{}", "{}"); + Mod mod2 = new Mod("{\"id\": 2}", "{}", "{}"); + + TrimmedShardedDataChangeRecord record1 = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod1, ModType.INSERT, 1, "tag1"); + record1.setShard("shard1"); + + TrimmedShardedDataChangeRecord record2 = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod1, ModType.INSERT, 1, "tag1"); + record2.setShard("shard1"); + + assertTrue(record1.equals(record1)); // Same instance + assertTrue(record1.equals(record2)); // Equal + + // Test differences in each field + TrimmedShardedDataChangeRecord diffTs = + new TrimmedShardedDataChangeRecord( + ts2, "txn1", "seq1", "table1", mod1, ModType.INSERT, 1, "tag1"); + diffTs.setShard("shard1"); + assertFalse(record1.equals(diffTs)); + + TrimmedShardedDataChangeRecord diffTxn = + new TrimmedShardedDataChangeRecord( + ts1, "txn2", "seq1", "table1", mod1, ModType.INSERT, 1, "tag1"); + diffTxn.setShard("shard1"); + assertFalse(record1.equals(diffTxn)); + + TrimmedShardedDataChangeRecord diffSeq = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq2", "table1", mod1, ModType.INSERT, 1, "tag1"); + diffSeq.setShard("shard1"); + assertFalse(record1.equals(diffSeq)); + + TrimmedShardedDataChangeRecord diffTable = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table2", mod1, ModType.INSERT, 1, "tag1"); + diffTable.setShard("shard1"); + assertFalse(record1.equals(diffTable)); + + TrimmedShardedDataChangeRecord diffMod = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod2, ModType.INSERT, 1, "tag1"); + diffMod.setShard("shard1"); + assertFalse(record1.equals(diffMod)); + + TrimmedShardedDataChangeRecord diffModType = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod1, ModType.DELETE, 1, "tag1"); + diffModType.setShard("shard1"); + assertFalse(record1.equals(diffModType)); + + TrimmedShardedDataChangeRecord diffCount = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod1, ModType.INSERT, 2, "tag1"); + diffCount.setShard("shard1"); + assertFalse(record1.equals(diffCount)); + + TrimmedShardedDataChangeRecord diffTag = + new TrimmedShardedDataChangeRecord( + ts1, "txn1", "seq1", "table1", mod1, ModType.INSERT, 1, "tag2"); + diffTag.setShard("shard1"); + assertFalse(record1.equals(diffTag)); + + TrimmedShardedDataChangeRecord diffShard = new TrimmedShardedDataChangeRecord(record1); + diffShard.setShard("shard2"); + assertFalse(record1.equals(diffShard)); + + assertFalse(record1.equals(null)); // Null + assertFalse(record1.equals("string")); // Different class + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java index f088c694c9..28910c906d 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/CassandraConnectionHelperTest.java @@ -16,6 +16,7 @@ package com.google.cloud.teleport.v2.templates.dbutils.connection; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -200,4 +201,10 @@ public void testInit_ShouldNotInitializeConnectionPool() { assertThrows(IllegalArgumentException.class, () -> connectionHelper.init(request)); assertEquals("The options map must contain a profile named default", exception.getMessage()); } + + @Test + public void testIsConnectionPoolInitialized_NullPool() { + connectionHelper.setConnectionPoolMap(null); + assertFalse(connectionHelper.isConnectionPoolInitialized()); + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGeneratorTest.java index 0c3920a302..bcf635fae5 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGeneratorTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraDMLGeneratorTest.java @@ -16,29 +16,36 @@ package com.google.cloud.teleport.v2.templates.dbutils.dml; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.Timestamp; import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.Table; import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.migrations.schema.SessionBasedMapper; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; import com.google.cloud.teleport.v2.templates.models.PreparedStatementGeneratedResponse; import com.google.cloud.teleport.v2.templates.models.PreparedStatementValueObject; import com.google.cloud.teleport.v2.templates.utils.SchemaUtils; +import com.google.common.collect.ImmutableList; import java.nio.ByteBuffer; import java.time.Instant; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import org.json.JSONObject; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) @@ -872,4 +879,137 @@ public void testSourcePkNull() { .setSourceSchema(sourceSchema) .build())); } + + @Test + public void testGetDeleteStatementCQL_NullTimestamp() { + Map> pkValues = + Map.of("id", PreparedStatementValueObject.create("int", 1)); + DMLGeneratorResponse response = + CassandraDMLGenerator.getDeleteStatementCQL("my_table", null, pkValues); + + assertEquals("DELETE FROM \"my_table\" WHERE \"id\" = ?", response.getDmlStatement()); + } + + @Test + public void testGetColumnValues_MissingSpannerColumnMapping() throws Exception { + ISchemaMapper mockSchemaMapper = Mockito.mock(ISchemaMapper.class); + Table spannerTable = Mockito.mock(Table.class); + SourceTable sourceTable = Mockito.mock(SourceTable.class); + SourceColumn sourceCol = Mockito.mock(SourceColumn.class); + + Mockito.when(sourceTable.primaryKeyColumns()).thenReturn(ImmutableList.of()); + Mockito.when(sourceTable.columns()).thenReturn(ImmutableList.of(sourceCol)); + Mockito.when(sourceCol.name()).thenReturn("col1"); + Mockito.when(mockSchemaMapper.getSpannerColumnName("", "my_table", "col1")) + .thenThrow(new NoSuchElementException()); + Mockito.when(sourceTable.name()).thenReturn("my_table"); + + Map> result = + CassandraDMLGenerator.getColumnValues( + mockSchemaMapper, + spannerTable, + sourceTable, + new JSONObject(), + new JSONObject(), + "+00:00", + null); + + assertTrue(result.isEmpty()); + } + + @Test + public void testGetPkColumnValues_MissingSourceColumnDefinition() { + ISchemaMapper schemaMapper = Mockito.mock(ISchemaMapper.class); + Table spannerTable = Mockito.mock(Table.class); + SourceTable sourceTable = Mockito.mock(SourceTable.class); + + Mockito.when(sourceTable.primaryKeyColumns()).thenReturn(ImmutableList.of("id")); + Mockito.when(sourceTable.column("id")).thenReturn(null); // Missing! + + Map> result = + CassandraDMLGenerator.getPkColumnValues( + schemaMapper, + spannerTable, + sourceTable, + new JSONObject(), + new JSONObject(), + "+00:00", + null); + + assertNull(result); + } + + @Test + public void testGetDMLStatement_NullSpannerDdl() { + CassandraDMLGenerator generator = new CassandraDMLGenerator(); + assertThrows( + InvalidDMLGenerationException.class, + () -> + generator.getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", "tableName", new JSONObject(), new JSONObject(), "+00:00") + .setSchemaMapper(Mockito.mock(ISchemaMapper.class)) + .setSourceSchema(Mockito.mock(SourceSchema.class)) + .setDdl(null) + .build())); + } + + @Test + public void testGetDMLStatement_NullSourceSchema() { + CassandraDMLGenerator generator = new CassandraDMLGenerator(); + assertThrows( + InvalidDMLGenerationException.class, + () -> + generator.getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", "tableName", new JSONObject(), new JSONObject(), "+00:00") + .setSchemaMapper(Mockito.mock(ISchemaMapper.class)) + .setDdl(Mockito.mock(Ddl.class)) + .setSourceSchema(null) + .build())); + } + + @Test + public void testGetDMLStatement_NullSpannerTable() { + CassandraDMLGenerator generator = new CassandraDMLGenerator(); + Ddl ddl = Mockito.mock(Ddl.class); + Mockito.when(ddl.table("tableName")).thenReturn(null); + + assertThrows( + InvalidDMLGenerationException.class, + () -> + generator.getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", "tableName", new JSONObject(), new JSONObject(), "+00:00") + .setSchemaMapper(Mockito.mock(ISchemaMapper.class)) + .setDdl(ddl) + .setSourceSchema(Mockito.mock(SourceSchema.class)) + .build())); + } + + @Test + public void testGetDMLStatement_NullSourceTable() { + CassandraDMLGenerator generator = new CassandraDMLGenerator(); + Ddl ddl = Mockito.mock(Ddl.class); + Table table = Mockito.mock(Table.class); + Mockito.when(ddl.table("tableName")).thenReturn(table); + + SourceSchema sourceSchema = Mockito.mock(SourceSchema.class); + Mockito.when(sourceSchema.table(Mockito.anyString())).thenReturn(null); + + ISchemaMapper schemaMapper = Mockito.mock(ISchemaMapper.class); + Mockito.when(schemaMapper.getSourceTableName(Mockito.anyString(), Mockito.anyString())) + .thenReturn("src_table"); + + assertThrows( + InvalidDMLGenerationException.class, + () -> + generator.getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", "tableName", new JSONObject(), new JSONObject(), "+00:00") + .setSchemaMapper(schemaMapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build())); + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraTypeHandlerTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraTypeHandlerTest.java index 0c53f61b6a..4eba0bd541 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraTypeHandlerTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/CassandraTypeHandlerTest.java @@ -23,6 +23,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.data.CqlDuration; import com.google.cloud.spanner.Dialect; @@ -31,6 +33,7 @@ import com.google.cloud.teleport.v2.spanner.ddl.Table; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType; +import com.google.cloud.teleport.v2.spanner.type.Type; import com.google.cloud.teleport.v2.templates.models.PreparedStatementValueObject; import com.google.common.net.InetAddresses; import java.math.BigDecimal; @@ -792,6 +795,48 @@ public void testConvertToCassandraTimestampWithISOInstant() { assertEquals(expectedValue, castResult); } + @Test + public void testCastToExpectedType_SmallInt() { + Object result = CassandraTypeHandler.castToExpectedType("smallint", "123"); + assertEquals((short) 123, result); + } + + @Test + public void testCastToExpectedType_TinyInt() { + Object result = CassandraTypeHandler.castToExpectedType("tinyint", "12"); + assertEquals((byte) 12, result); + } + + @Test + public void testCastToExpectedType_Float() { + Object result = CassandraTypeHandler.castToExpectedType("float", "3.14"); + assertEquals(3.14f, result); + } + + @Test + public void testCastToExpectedType_Boolean_Zero() { + Object result = CassandraTypeHandler.castToExpectedType("boolean", "0"); + assertEquals(false, result); + } + + @Test + public void testCastToExpectedType_EmptyList() { + Object result = CassandraTypeHandler.castToExpectedType("list", "[]"); + assertEquals(java.util.Collections.emptyList(), result); + } + + @Test + public void testCastToExpectedType_EmptySet() { + Object result = CassandraTypeHandler.castToExpectedType("set", "[]"); + assertEquals(java.util.Collections.emptySet(), result); + } + + @Test + public void testCastToExpectedType_EmptyMap() { + Object result = CassandraTypeHandler.castToExpectedType("map", "{}"); + assertEquals(java.util.Collections.emptyMap(), result); + } + @Test public void testConvertToCassandraTimestampWithISODateTime() { String timestamp = "2025-01-15T10:15:30"; @@ -1645,4 +1690,103 @@ public void testCastToExpectedTypeForEmptyMap() { assertTrue(result instanceof Map); assertEquals(Collections.emptyMap(), result); } + + @Test + public void testGetColumnValueByType_NullSpannerColumn() { + SourceColumn sourceCol = mock(SourceColumn.class); + when(sourceCol.type()).thenReturn("text"); + + assertThrows( + IllegalArgumentException.class, + () -> CassandraTypeHandler.getColumnValueByType(null, sourceCol, new JSONObject(), "")); + } + + @Test + public void testGetColumnValueByType_NullSourceColumn() { + Column spannerCol = mock(Column.class); + Type mockType = mock(Type.class); + when(spannerCol.type()).thenReturn(mockType); + + assertThrows( + IllegalArgumentException.class, + () -> CassandraTypeHandler.getColumnValueByType(spannerCol, null, new JSONObject(), "")); + } + + @Test + public void testCastToExpectedType_EmptyUuid() { + assertThrows( + IllegalArgumentException.class, () -> CassandraTypeHandler.castToExpectedType("uuid", "")); + } + + @Test + public void testCastToExpectedType_EmptyTimestamp() { + assertThrows( + IllegalArgumentException.class, + () -> CassandraTypeHandler.castToExpectedType("timestamp", "")); + } + + @Test + public void testParseBlobType_ByteBuffer() { + java.nio.ByteBuffer buffer = java.nio.ByteBuffer.wrap(new byte[] {1, 2, 3}); + Object result = CassandraTypeHandler.castToExpectedType("blob", buffer); + assertEquals(buffer, result); + } + + @Test + public void testParseBlobType_ByteArray() { + byte[] bytes = new byte[] {1, 2, 3}; + Object result = CassandraTypeHandler.castToExpectedType("blob", bytes); + assertTrue(result instanceof ByteBuffer); + assertEquals(ByteBuffer.wrap(bytes), result); + } + + @Test + public void testHandleCassandraTimestampType_Empty() { + assertThrows( + IllegalArgumentException.class, + () -> CassandraTypeHandler.castToExpectedType("timestamp", "")); + } + + @Test + public void testHandleCassandraUuidType_Null() { + Object result = CassandraTypeHandler.castToExpectedType("uuid", null); + assertNull(result); + } + + @Test + public void testConvertToCassandraTimestamp_Spaces() { + assertThrows( + IllegalArgumentException.class, + () -> CassandraTypeHandler.castToExpectedType("timestamp", " ")); + } + + @Test + public void testConvertToCassandraTimestamp_TimeString() { + Object result = CassandraTypeHandler.castToExpectedType("timestamp", "12:30:00"); + assertTrue(result instanceof Instant); + } + + @Test + public void testParseAndCastToCassandraType_List_JSONArray() { + JSONArray array = new JSONArray("[\"apple\"]"); + Object result = CassandraTypeHandler.castToExpectedType("list", array); + assertNotNull(result); + assertTrue(result instanceof List); + } + + @Test + public void testParseAndCastToCassandraType_Set_JSONArray() { + JSONArray array = new JSONArray("[\"apple\"]"); + Object result = CassandraTypeHandler.castToExpectedType("set", array); + assertNotNull(result); + assertTrue(result instanceof Set); + } + + @Test + public void testParseAndCastToCassandraType_Map_JSONObject() { + JSONObject obj = new JSONObject("{\"name\": \"John\"}"); + Object result = CassandraTypeHandler.castToExpectedType("map", obj); + assertNotNull(result); + assertTrue(result instanceof Map); + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/DMLGeneratorUtilsTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/DMLGeneratorUtilsTest.java new file mode 100644 index 0000000000..bde38bcbe9 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/DMLGeneratorUtilsTest.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.dml; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; + +public class DMLGeneratorUtilsTest { + + @Test + public void testConvertBase64ToRawHex_Null() { + assertNull(DMLGeneratorUtils.convertBase64ToRawHex(null)); + } + + @Test + public void testConvertBase64ToRawHex_Empty() { + assertEquals("", DMLGeneratorUtils.convertBase64ToRawHex("")); + } + + @Test + public void testConvertBase64ToRawHex_Invalid() { + assertThrows( + IllegalArgumentException.class, + () -> DMLGeneratorUtils.convertBase64ToRawHex("invalid base64!")); + } + + @Test + public void testConvertBase64ToRawHex_Valid() { + // "Hello" in base64 is "SGVsbG8=" + // "Hello" in hex is "48656c6c6f" + assertEquals("48656c6c6f", DMLGeneratorUtils.convertBase64ToRawHex("SGVsbG8=")); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java index 117b362eae..2a11b5ab4f 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java @@ -19,13 +19,19 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import com.google.cloud.teleport.v2.spanner.ddl.Column; import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.Table; import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.migrations.schema.SessionBasedMapper; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; +import com.google.cloud.teleport.v2.spanner.type.Type; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; @@ -47,6 +53,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class MySQLDMLGeneratorTest { @@ -1455,20 +1462,17 @@ public void testGeneratedPrimaryKeyDML() { .primaryKeyColumns(ImmutableList.of("SingerId")) .columns( ImmutableList.of( - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("SingerId") .type("bigint") .isGenerated(true) .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("FirstName") .type("varchar") .isPrimaryKey(true) .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("LastName") .type("varchar") .build())) @@ -1479,54 +1483,25 @@ public void testGeneratedPrimaryKeyDML() { .tables(ImmutableMap.of("GeneratedPKTable", sourceTable)) .build(); - ISchemaMapper mockSchemaMapper = org.mockito.Mockito.mock(ISchemaMapper.class); - org.mockito.Mockito.when( - mockSchemaMapper.getSpannerColumns( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"))) + ISchemaMapper mockSchemaMapper = Mockito.mock(ISchemaMapper.class); + Mockito.when(mockSchemaMapper.getSpannerColumns(any(), eq("GeneratedPKTable"))) .thenReturn(ImmutableList.of("SingerId", "FirstName", "LastName")); - org.mockito.Mockito.when( - mockSchemaMapper.getSpannerColumnName( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq("LastName"))) + Mockito.when( + mockSchemaMapper.getSpannerColumnName(any(), eq("GeneratedPKTable"), eq("LastName"))) .thenReturn("LastName"); - org.mockito.Mockito.when( - mockSchemaMapper.getSourceTableName( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"))) + Mockito.when(mockSchemaMapper.getSourceTableName(any(), eq("GeneratedPKTable"))) .thenReturn("GeneratedPKTable"); - org.mockito.Mockito.when( - mockSchemaMapper.isGeneratedColumn( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq("FirstName"))) + Mockito.when(mockSchemaMapper.isGeneratedColumn(any(), eq("GeneratedPKTable"), eq("FirstName"))) .thenReturn(true); - org.mockito.Mockito.when( - mockSchemaMapper.isGeneratedColumn( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq("SingerId"))) + Mockito.when(mockSchemaMapper.isGeneratedColumn(any(), eq("GeneratedPKTable"), eq("SingerId"))) .thenReturn(false); - org.mockito.Mockito.when( - mockSchemaMapper.isGeneratedColumn( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq("LastName"))) + Mockito.when(mockSchemaMapper.isGeneratedColumn(any(), eq("GeneratedPKTable"), eq("LastName"))) .thenReturn(false); for (String col : new String[] {"SingerId", "FirstName", "LastName"}) { - org.mockito.Mockito.when( - mockSchemaMapper.colExistsAtSource( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq(col))) + Mockito.when(mockSchemaMapper.colExistsAtSource(any(), eq("GeneratedPKTable"), eq(col))) .thenReturn(true); - org.mockito.Mockito.when( - mockSchemaMapper.getSourceColumnName( - org.mockito.ArgumentMatchers.any(), - org.mockito.ArgumentMatchers.eq("GeneratedPKTable"), - org.mockito.ArgumentMatchers.eq(col))) + Mockito.when(mockSchemaMapper.getSourceColumnName(any(), eq("GeneratedPKTable"), eq(col))) .thenReturn(col); } @@ -1557,4 +1532,141 @@ public void testGeneratedPrimaryKeyDML() { assertEquals(0, countInSQL(sql, "FirstName")); assertEquals(0, countInSQL(sql, "SingerId")); } + + @Test + public void testBitTypeDML() throws Exception { + Column spannerCol = Mockito.mock(Column.class); + Mockito.when(spannerCol.name()).thenReturn("bit_column"); + Mockito.when(spannerCol.type()).thenReturn(Type.bytes()); + + SourceColumn sourceCol = Mockito.mock(SourceColumn.class); + Mockito.when(sourceCol.name()).thenReturn("bit_column"); + Mockito.when(sourceCol.type()).thenReturn("bit"); + + JSONObject json = new JSONObject("{\"bit_column\":\"SGVsbG8=\"}"); // "Hello" in base64 + + String res = MySQLDMLGenerator.getMappedColumnValue(spannerCol, sourceCol, json, "+00:00"); + assertEquals("x'48656c6c6f'", res); // "Hello" in hex is 48656c6c6f + } + + @Test + public void testGetDMLStatement_NullRequest() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(null)); + } + + @Test + public void testGetDMLStatement_NullSchemaMapper() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(null); + Mockito.when(request.getSpannerDdl()).thenReturn(Mockito.mock(Ddl.class)); + Mockito.when(request.getSourceSchema()).thenReturn(Mockito.mock(SourceSchema.class)); + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_NullSpannerDdl() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(Mockito.mock(ISchemaMapper.class)); + Mockito.when(request.getSpannerDdl()).thenReturn(null); + Mockito.when(request.getSourceSchema()).thenReturn(Mockito.mock(SourceSchema.class)); + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_NullSourceSchema() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(Mockito.mock(ISchemaMapper.class)); + Mockito.when(request.getSpannerDdl()).thenReturn(Mockito.mock(Ddl.class)); + Mockito.when(request.getSourceSchema()).thenReturn(null); + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_SourceTableNotFound() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = Mockito.mock(ISchemaMapper.class); + SourceSchema sourceSchema = Mockito.mock(SourceSchema.class); + Ddl ddl = Mockito.mock(Ddl.class); + Table spannerTable = Mockito.mock(Table.class); + + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(schemaMapper); + Mockito.when(request.getSourceSchema()).thenReturn(sourceSchema); + Mockito.when(request.getSpannerDdl()).thenReturn(ddl); + Mockito.when(ddl.table("Singers")).thenReturn(spannerTable); + + try { + Mockito.when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + Mockito.when(sourceSchema.table("Singers")).thenReturn(null); // Not found! + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_NoPrimaryKeys() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = Mockito.mock(ISchemaMapper.class); + SourceSchema sourceSchema = Mockito.mock(SourceSchema.class); + SourceTable sourceTable = Mockito.mock(SourceTable.class); + Ddl ddl = Mockito.mock(Ddl.class); + Table spannerTable = Mockito.mock(Table.class); + + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(schemaMapper); + Mockito.when(request.getSourceSchema()).thenReturn(sourceSchema); + Mockito.when(request.getSpannerDdl()).thenReturn(ddl); + Mockito.when(ddl.table("Singers")).thenReturn(spannerTable); + + try { + Mockito.when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + Mockito.when(sourceSchema.table("Singers")).thenReturn(sourceTable); + Mockito.when(sourceTable.primaryKeyColumns()).thenReturn(ImmutableList.of()); + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_NullPrimaryKeys() { + MySQLDMLGenerator generator = new MySQLDMLGenerator(); + DMLGeneratorRequest request = Mockito.mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = Mockito.mock(ISchemaMapper.class); + SourceSchema sourceSchema = Mockito.mock(SourceSchema.class); + SourceTable sourceTable = Mockito.mock(SourceTable.class); + Ddl ddl = Mockito.mock(Ddl.class); + Table spannerTable = Mockito.mock(Table.class); + + Mockito.when(request.getSpannerTableName()).thenReturn("Singers"); + Mockito.when(request.getSchemaMapper()).thenReturn(schemaMapper); + Mockito.when(request.getSourceSchema()).thenReturn(sourceSchema); + Mockito.when(request.getSpannerDdl()).thenReturn(ddl); + Mockito.when(ddl.table("Singers")).thenReturn(spannerTable); + + try { + Mockito.when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + Mockito.when(sourceSchema.table("Singers")).thenReturn(sourceTable); + Mockito.when(sourceTable.primaryKeyColumns()).thenReturn(null); // NULL! + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGeneratorTest.java index 0e9c2b95ac..739f5b136d 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGeneratorTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/PostgreSQLDMLGeneratorTest.java @@ -24,15 +24,22 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; +import com.google.cloud.teleport.v2.spanner.ddl.Column; import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.Table; import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.migrations.schema.SessionBasedMapper; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; +import com.google.cloud.teleport.v2.spanner.type.Type; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; import com.google.cloud.teleport.v2.templates.utils.SchemaUtils; +import com.google.common.collect.ImmutableList; import com.google.gson.FieldNamingPolicy; import com.google.gson.GsonBuilder; import java.io.InputStream; @@ -327,10 +334,7 @@ public void testNullValiationsRequest() { // Request with null mapper Ddl ddl = Ddl.builder().build(); SourceSchema sourceSchema = - SourceSchema.builder( - com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType.POSTGRESQL) - .databaseName("test") - .build(); + SourceSchema.builder(SourceDatabaseType.POSTGRESQL).databaseName("test").build(); assertThrows( InvalidDMLGenerationException.class, () -> @@ -413,14 +417,12 @@ public void testConvertBase64ToHexEdgeCases() { @Test public void testNonByteaDecodeBranch() { // Mocking Spanner Column - com.google.cloud.teleport.v2.spanner.ddl.Column spannerCol = - mock(com.google.cloud.teleport.v2.spanner.ddl.Column.class); + Column spannerCol = mock(Column.class); when(spannerCol.name()).thenReturn("c"); - when(spannerCol.type()).thenReturn(com.google.cloud.teleport.v2.spanner.type.Type.bytes()); + when(spannerCol.type()).thenReturn(Type.bytes()); // Mocking SourceColumn - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn sourceCol = - mock(com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.class); + SourceColumn sourceCol = mock(SourceColumn.class); when(sourceCol.name()).thenReturn("c"); when(sourceCol.type()).thenReturn("text"); @@ -430,13 +432,11 @@ public void testNonByteaDecodeBranch() { assertTrue(res.contains("decode('YWJj', 'base64')")); // Test string escape - com.google.cloud.teleport.v2.spanner.ddl.Column spannerStrCol = - mock(com.google.cloud.teleport.v2.spanner.ddl.Column.class); + Column spannerStrCol = mock(Column.class); when(spannerStrCol.name()).thenReturn("s"); - when(spannerStrCol.type()).thenReturn(com.google.cloud.teleport.v2.spanner.type.Type.string()); + when(spannerStrCol.type()).thenReturn(Type.string()); - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn sourceStrCol = - mock(com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.class); + SourceColumn sourceStrCol = mock(SourceColumn.class); when(sourceStrCol.type()).thenReturn("text"); JSONObject jsonStr = new JSONObject("{\"s\":\"it\\'s a string\"}"); @@ -445,4 +445,210 @@ public void testNonByteaDecodeBranch() { PostgreSQLDMLGenerator.getMappedColumnValue(spannerStrCol, sourceStrCol, jsonStr, "+00:00"); assertEquals("'it''s a string'", resStr); } + + @Test + public void testByteaTypeDML() throws Exception { + Column spannerCol = mock(Column.class); + when(spannerCol.name()).thenReturn("bytea_column"); + when(spannerCol.type()).thenReturn(Type.bytes()); + + SourceColumn sourceCol = mock(SourceColumn.class); + when(sourceCol.name()).thenReturn("bytea_column"); + when(sourceCol.type()).thenReturn("bytea"); + + JSONObject json = new JSONObject("{\"bytea_column\":\"SGVsbG8=\"}"); // "Hello" in base64 + + String res = PostgreSQLDMLGenerator.getMappedColumnValue(spannerCol, sourceCol, json, "+00:00"); + assertEquals("'\\x48656c6c6f'", res); // "Hello" in hex is 48656c6c6f + } + + @Test + public void testUuidTypeDML() throws Exception { + Column spannerCol = mock(Column.class); + when(spannerCol.name()).thenReturn("uuid_column"); + when(spannerCol.type()).thenReturn(Type.string()); + + SourceColumn sourceCol = mock(SourceColumn.class); + when(sourceCol.name()).thenReturn("uuid_column"); + when(sourceCol.type()).thenReturn("uuid"); + + JSONObject json = new JSONObject("{\"uuid_column\":\"123e4567-e89b-12d3-a456-426614174000\"}"); + + String res = PostgreSQLDMLGenerator.getMappedColumnValue(spannerCol, sourceCol, json, "+00:00"); + assertEquals("'123e4567-e89b-12d3-a456-426614174000'", res); + } + + @Test + public void testJsonbTypeDML() throws Exception { + Column spannerCol = mock(Column.class); + when(spannerCol.name()).thenReturn("jsonb_column"); + when(spannerCol.type()).thenReturn(Type.string()); + + SourceColumn sourceCol = mock(SourceColumn.class); + when(sourceCol.name()).thenReturn("jsonb_column"); + when(sourceCol.type()).thenReturn("jsonb"); + + JSONObject json = new JSONObject("{\"jsonb_column\":\"{\\\"a\\\": 1}\"}"); + + String res = PostgreSQLDMLGenerator.getMappedColumnValue(spannerCol, sourceCol, json, "+00:00"); + assertEquals("'{\"a\": 1}'", res); + } + + @Test + public void testGetDMLStatement_SourceTableNotFound() { + PostgreSQLDMLGenerator generator = new PostgreSQLDMLGenerator(); + DMLGeneratorRequest request = mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = mock(ISchemaMapper.class); + SourceSchema sourceSchema = mock(SourceSchema.class); + + when(request.getSpannerTableName()).thenReturn("Singers"); + when(request.getSchemaMapper()).thenReturn(schemaMapper); + when(request.getSourceSchema()).thenReturn(sourceSchema); + try { + when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + when(sourceSchema.table("Singers")).thenReturn(null); // Not found! + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_NoPrimaryKeys() { + PostgreSQLDMLGenerator generator = new PostgreSQLDMLGenerator(); + DMLGeneratorRequest request = mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = mock(ISchemaMapper.class); + SourceSchema sourceSchema = mock(SourceSchema.class); + SourceTable sourceTable = mock(SourceTable.class); + + when(request.getSpannerTableName()).thenReturn("Singers"); + when(request.getSchemaMapper()).thenReturn(schemaMapper); + when(request.getSourceSchema()).thenReturn(sourceSchema); + try { + when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + when(sourceSchema.table("Singers")).thenReturn(sourceTable); + when(sourceTable.primaryKeyColumns()).thenReturn(ImmutableList.of()); + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetDMLStatement_UpdateMode() { + String sessionFile = "src/test/resources/allMatchSession.json"; + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(sessionFile); + SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(sessionFile); + ISchemaMapper schemaMapper = new SessionBasedMapper(sessionFile, ddl); + + String tableName = "Singers"; + String newValuesString = "{\"FirstName\":\"kk\",\"LastName\":\"ll\"}"; + JSONObject newValuesJson = new JSONObject(newValuesString); + JSONObject keyValuesJson = new JSONObject("{\"SingerId\":\"999\"}"); + String modType = "UPDATE"; + + PostgreSQLDMLGenerator postgreSQLDMLGenerator = new PostgreSQLDMLGenerator(); + DMLGeneratorResponse dmlGeneratorResponse = + postgreSQLDMLGenerator.getDMLStatement( + new DMLGeneratorRequest.Builder( + modType, tableName, newValuesJson, keyValuesJson, "+00:00") + .setSchemaMapper(schemaMapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + String sql = dmlGeneratorResponse.getDmlStatement(); + + assertTrue(sql.contains("\"FirstName\" = EXCLUDED.\"FirstName\"")); + assertTrue(sql.contains("\"LastName\" = EXCLUDED.\"LastName\"")); + assertTrue(sql.contains("ON CONFLICT (\"SingerId\") DO UPDATE SET")); + } + + @Test + public void testGetUpsertStatement_AllColumnsArePKs() { + String sessionFile = "src/test/resources/onlyPKColumnsSession.json"; + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(sessionFile); + SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(sessionFile); + ISchemaMapper schemaMapper = new SessionBasedMapper(sessionFile, ddl); + + String tableName = "resource_access"; + String newValuesString = "{\"user_id\":\"101\",\"group_id\":\"5\",\"resource_id\":\"99\"}"; + JSONObject newValuesJson = new JSONObject(newValuesString); + JSONObject keyValuesJson = new JSONObject(newValuesString); + String modType = "INSERT"; + + PostgreSQLDMLGenerator postgreSQLDMLGenerator = new PostgreSQLDMLGenerator(); + DMLGeneratorResponse dmlGeneratorResponse = + postgreSQLDMLGenerator.getDMLStatement( + new DMLGeneratorRequest.Builder( + modType, tableName, newValuesJson, keyValuesJson, "+00:00") + .setSchemaMapper(schemaMapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + String sql = dmlGeneratorResponse.getDmlStatement(); + + assertTrue(sql.contains("DO NOTHING")); // Line 163! + } + + @Test + public void testGetDMLStatement_NullPrimaryKeys() { + PostgreSQLDMLGenerator generator = new PostgreSQLDMLGenerator(); + DMLGeneratorRequest request = mock(DMLGeneratorRequest.class); + ISchemaMapper schemaMapper = mock(ISchemaMapper.class); + SourceSchema sourceSchema = mock(SourceSchema.class); + SourceTable sourceTable = mock(SourceTable.class); + Ddl ddl = mock(Ddl.class); + Table spannerTable = mock(Table.class); + + when(request.getSpannerTableName()).thenReturn("Singers"); + when(request.getSchemaMapper()).thenReturn(schemaMapper); + when(request.getSourceSchema()).thenReturn(sourceSchema); + when(request.getSpannerDdl()).thenReturn(ddl); + when(ddl.table("Singers")).thenReturn(spannerTable); + + try { + when(schemaMapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + } catch (Exception e) { + // ignore + } + when(sourceSchema.table("Singers")).thenReturn(sourceTable); + when(sourceTable.primaryKeyColumns()).thenReturn(null); // NULL! + + assertThrows(InvalidDMLGenerationException.class, () -> generator.getDMLStatement(request)); + } + + @Test + public void testGetQuotedEscapedString_PgBytea() { + String result = PostgreSQLDMLGenerator.getQuotedEscapedString("Hello", "PG_BYTEA"); + assertEquals("Hello", result); // No quotes! + } + + @Test + public void testGetDMLStatement_NullValueInUpsert() { + String sessionFile = "src/test/resources/allMatchSession.json"; + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(sessionFile); + SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(sessionFile); + ISchemaMapper schemaMapper = new SessionBasedMapper(sessionFile, ddl); + + String tableName = "Singers"; + String newValuesString = "{\"FirstName\":\"kk\",\"LastName\":null}"; // NULL value! + JSONObject newValuesJson = new JSONObject(newValuesString); + JSONObject keyValuesJson = new JSONObject("{\"SingerId\":\"999\"}"); + String modType = "INSERT"; + + PostgreSQLDMLGenerator postgreSQLDMLGenerator = new PostgreSQLDMLGenerator(); + DMLGeneratorResponse dmlGeneratorResponse = + postgreSQLDMLGenerator.getDMLStatement( + new DMLGeneratorRequest.Builder( + modType, tableName, newValuesJson, keyValuesJson, "+00:00") + .setSchemaMapper(schemaMapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + String sql = dmlGeneratorResponse.getDmlStatement(); + + assertTrue(sql.contains("NULL")); // Line 133! + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/AssignShardIdFnTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/AssignShardIdFnTest.java index 14b05ee73a..43c5754a73 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/AssignShardIdFnTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/AssignShardIdFnTest.java @@ -16,11 +16,13 @@ package com.google.cloud.teleport.v2.templates.transforms; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,12 +44,17 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.migrations.schema.SessionBasedMapper; import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; +import com.google.cloud.teleport.v2.spanner.type.Type; +import com.google.cloud.teleport.v2.spanner.utils.IShardIdFetcher; +import com.google.cloud.teleport.v2.spanner.utils.ShardIdResponse; import com.google.cloud.teleport.v2.templates.SpannerToSourceDb.Options; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.constants.Constants; +import com.google.cloud.teleport.v2.templates.utils.SchemaMapperUtils; import com.google.cloud.teleport.v2.templates.utils.SchemaUtils; import com.google.cloud.teleport.v2.templates.utils.ShardingLogicImplFetcher; import com.google.common.collect.ImmutableList; @@ -71,6 +78,7 @@ import org.junit.runners.MethodSorters; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -856,30 +864,25 @@ record = .primaryKeyColumns(ImmutableList.of("accountId")) .columns( ImmutableList.of( - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountId") .type("varchar") .isPrimaryKey(true) .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountName") .type("varchar") .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("migration_shard_id") .type("varchar") .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountNumber") .type("bigint") .build(), // Source column is NOT generated! - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("generated_col") .type("varchar") .isGenerated(false) @@ -936,7 +939,7 @@ record = TrimmedShardedDataChangeRecord outputRecord = captor.getValue().getValue(); assertEquals("shard1", outputRecord.getShard()); - org.junit.Assert.assertTrue( + assertTrue( outputRecord.getMod().getNewValuesJson().contains("\"generated_col\":\"COMPUTED_VAL\"")); } @@ -992,29 +995,24 @@ record = .primaryKeyColumns(ImmutableList.of("accountId")) .columns( ImmutableList.of( - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountId") .type("varchar") .isPrimaryKey(true) .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountName") .type("varchar") .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("migration_shard_id") .type("varchar") .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("accountNumber") .type("bigint") .build(), - com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn.builder( - SourceDatabaseType.MYSQL) + SourceColumn.builder(SourceDatabaseType.MYSQL) .name("generated_col") .type("varchar") .isGenerated(false) @@ -1070,7 +1068,7 @@ record = TrimmedShardedDataChangeRecord outputRecord = captor.getValue().getValue(); assertEquals("shard1", outputRecord.getShard()); - org.junit.Assert.assertTrue( + assertTrue( outputRecord .getMod() .getNewValuesJson() @@ -1199,4 +1197,390 @@ public TrimmedShardedDataChangeRecord getDeleteTrimmedDataChangeRecordAllDatatyp 1, ""); } + + @Test + public void testProcessElementDeletePostgresTypes() throws Exception { + Ddl ddl = + Ddl.builder(com.google.cloud.spanner.Dialect.POSTGRESQL) + .createTable("Users") + .column("id") + .type(Type.pgInt8()) + .endColumn() + .column("bool_field") + .type(Type.pgBool()) + .endColumn() + .column("float_field") + .type(Type.pgFloat8()) + .endColumn() + .column("string_field") + .type(Type.pgVarchar()) + .endColumn() + .primaryKey() + .asc("id") + .end() + .endTable() + .build(); + + TrimmedShardedDataChangeRecord record = + new TrimmedShardedDataChangeRecord( + Timestamp.parseTimestamp("2020-12-01T10:15:30.000Z"), + "serverTxnId", + "recordSeq", + "Users", + new Mod("{\"id\": 1}", "{}", "{}"), + ModType.valueOf("DELETE"), + 1, + ""); + + when(processContext.element()).thenReturn(record); + when(processContext.sideInput(mockDdlView)).thenReturn(ddl); + + com.google.cloud.spanner.ResultSet resultSet = mock(ResultSet.class); + when(mockReadOnlyTransaction.read( + eq("Users"), + any(com.google.cloud.spanner.KeySet.class), + any(Iterable.class), + any(com.google.cloud.spanner.Options.ReadOption.class))) + .thenReturn(resultSet); + when(resultSet.next()).thenReturn(true); + + Struct allDatatypesRow = + Struct.newBuilder() + .set("id") + .to(1L) + .set("bool_field") + .to(true) + .set("float_field") + .to(4.2) + .set("string_field") + .to("abc") + .build(); + when(resultSet.getCurrentRowAsStruct()).thenReturn(allDatatypesRow); + + SourceSchema sourceSchema = + SourceSchema.builder(SourceDatabaseType.POSTGRESQL) + .databaseName("testdb") + .tables(com.google.common.collect.ImmutableMap.of()) + .build(); + + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + sourceSchema, + Constants.SHARDING_MODE_SINGLE_SHARD, + "shard1", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_POSTGRESQL, + "", + "", + "", + ""); + + assignShardIdFn.setSpannerAccessor(spannerAccessor); + assignShardIdFn.setMapper(new ObjectMapper()); + + ISchemaMapper mockSchemaMapper = mock(ISchemaMapper.class); + when(mockSchemaMapper.getSourceTableName(any(), eq("Users"))).thenReturn("Users"); + + try (MockedStatic mockedSchemaMapperUtils = + mockStatic(SchemaMapperUtils.class)) { + mockedSchemaMapperUtils + .when(() -> SchemaMapperUtils.getSchemaMapper(any(), any(), any(), any(), any())) + .thenReturn(mockSchemaMapper); + + assignShardIdFn.processElement(processContext); + + record.setShard("shard1"); + String expectedNewValues = + "{\"float_field\":\"4.2\",\"bool_field\":\"true\",\"string_field\":\"abc\"}"; + record.setMod( + new Mod( + record.getMod().getKeysJson(), + record.getMod().getOldValuesJson(), + expectedNewValues)); + + String keyStr = "Users" + "_" + record.getMod().getKeysJson() + "_" + "shard1"; + Long key = keyStr.hashCode() % 10000L; + verify(processContext, atLeast(1)).output(eq(KV.of(key, record))); + } + } + + @Test + public void testSetup_NullSpannerConfig() { + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + null, + mockDdlView, + mock(SourceSchema.class), + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + "", + "", + "", + ""); + assignShardIdFn.setMapper(new ObjectMapper()); + assignShardIdFn.setup(); + } + + @Test + public void testSetup_NullSessionFilePath() { + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + mock(SourceSchema.class), + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + null, + "", + "", + ""); + assignShardIdFn.setMapper(new ObjectMapper()); + + try (MockedStatic mockedSpannerAccessor = mockStatic(SpannerAccessor.class)) { + mockedSpannerAccessor + .when(() -> SpannerAccessor.getOrCreate(any(SpannerConfig.class))) + .thenReturn(spannerAccessor); + assignShardIdFn.setup(); + } + } + + @Test + public void testSetup_EmptySessionFilePath() { + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + mock(SourceSchema.class), + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + "", + "", + "", + ""); + assignShardIdFn.setMapper(new ObjectMapper()); + + try (MockedStatic mockedSpannerAccessor = mockStatic(SpannerAccessor.class)) { + mockedSpannerAccessor + .when(() -> SpannerAccessor.getOrCreate(any(SpannerConfig.class))) + .thenReturn(spannerAccessor); + assignShardIdFn.setup(); + } + } + + @Test + public void testProcessElement_InvalidShardId() throws Exception { + TrimmedShardedDataChangeRecord record = getInsertTrimmedDataChangeRecord("shard1"); + when(processContext.element()).thenReturn(record); + + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(SESSION_FILE_PATH); + SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(SESSION_FILE_PATH); + + when(processContext.sideInput(mockDdlView)).thenReturn(ddl); + + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + sourceSchema, + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + SESSION_FILE_PATH, + "", + "", + ""); + assignShardIdFn.setSchema(SessionFileReader.read(SESSION_FILE_PATH)); + assignShardIdFn.setSpannerAccessor(spannerAccessor); + assignShardIdFn.setMapper(new ObjectMapper()); + + IShardIdFetcher mockFetcher = mock(IShardIdFetcher.class); + ShardIdResponse mockResponse = mock(ShardIdResponse.class); + when(mockFetcher.getShardId(any())).thenReturn(mockResponse); + when(mockResponse.getLogicalShardId()).thenReturn("invalid/shard"); + + try (MockedStatic mockedFetcher = + mockStatic(ShardingLogicImplFetcher.class)) { + mockedFetcher + .when( + () -> + ShardingLogicImplFetcher.getShardingLogicImpl(any(), any(), any(), any(), any())) + .thenReturn(mockFetcher); + + assignShardIdFn.processElement(processContext); + } + + ArgumentCaptor captor = ArgumentCaptor.forClass(KV.class); + verify(processContext).output(captor.capture()); + TrimmedShardedDataChangeRecord outputRecord = + (TrimmedShardedDataChangeRecord) captor.getValue().getValue(); + assertEquals(Constants.SEVERE_ERROR_SHARD_ID, outputRecord.getShard()); + } + + @Test + public void testProcessElementDeleteNoSpannerRow_CoverNextBranch() throws Exception { + TrimmedShardedDataChangeRecord record = getDeleteTrimmedDataChangeRecordAllDatatypes("shard1"); + when(processContext.element()).thenReturn(record); + + com.google.cloud.spanner.ResultSet resultSet = mock(ResultSet.class); + when(mockReadOnlyTransaction.read( + eq("Users"), + any(com.google.cloud.spanner.KeySet.class), + any(Iterable.class), + any(com.google.cloud.spanner.Options.ReadOption.class))) + .thenReturn(resultSet); + when(resultSet.next()).thenReturn(false); // Simulate no row found in Spanner + + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(ALL_TYPES_SESSION_FILE_PATH); + SourceSchema sourceSchema = + SchemaUtils.buildSourceSchemaFromSessionFile(ALL_TYPES_SESSION_FILE_PATH); + + when(processContext.sideInput(mockDdlView)).thenReturn(ddl); + + AssignShardIdFn assignShardIdFn = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + sourceSchema, + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + ALL_TYPES_SESSION_FILE_PATH, + "", + "", + ""); + assignShardIdFn.setSchema(SessionFileReader.read(ALL_TYPES_SESSION_FILE_PATH)); + + assignShardIdFn.setSpannerAccessor(spannerAccessor); + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + assignShardIdFn.setMapper(mapper); + + assignShardIdFn.processElement(processContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(KV.class); + verify(processContext).output(captor.capture()); + TrimmedShardedDataChangeRecord outputRecord = + (TrimmedShardedDataChangeRecord) captor.getValue().getValue(); + assertEquals(Constants.RETRYABLE_ERROR_SHARD_ID, outputRecord.getShard()); + } + + @Test + public void testUpdateChangeEventToIncludeGeneratedColumns() throws Exception { + TrimmedShardedDataChangeRecord record = getInsertTrimmedDataChangeRecord("shard1"); + when(processContext.element()).thenReturn(record); + + Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(SESSION_FILE_PATH); + SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(SESSION_FILE_PATH); + + when(processContext.sideInput(mockDdlView)).thenReturn(ddl); + + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + + // Mock ISchemaMapper to return generated column! + ISchemaMapper mockSchemaMapper = mock(ISchemaMapper.class); + when(mockSchemaMapper.getSpannerColumns(null, "tableName")) + .thenReturn(com.google.common.collect.ImmutableList.of("migration_shard_id")); + when(mockSchemaMapper.isGeneratedColumn(null, "tableName", "migration_shard_id")) + .thenReturn(true); + when(mockSchemaMapper.colExistsAtSource(null, "tableName", "migration_shard_id")) + .thenReturn(true); + when(mockSchemaMapper.getSourceColumnName(null, "tableName", "migration_shard_id")) + .thenReturn("migration_shard_id"); + when(mockSchemaMapper.getSourceTableName(null, "tableName")).thenReturn("tableName"); + + // Mock SourceSchema to return non-generated source column! + SourceTable sourceTable = mock(SourceTable.class); + SourceColumn sourceColumn = mock(SourceColumn.class); + // Mock SchemaMapperUtils to return the mock schema mapper. + + com.google.cloud.spanner.ResultSet resultSet = mock(ResultSet.class); + when(mockReadOnlyTransaction.read( + eq("tableName"), + any(com.google.cloud.spanner.KeySet.class), + any(Iterable.class), + any(com.google.cloud.spanner.Options.ReadOption.class))) + .thenReturn(resultSet); + when(resultSet.next()).thenReturn(true); + when(resultSet.getCurrentRowAsStruct()).thenReturn(mockRow); + + try (MockedStatic mockedSchemaMapperUtils = + mockStatic(SchemaMapperUtils.class)) { + mockedSchemaMapperUtils + .when(() -> SchemaMapperUtils.getSchemaMapper(any(), any(), any(), any(), any())) + .thenReturn(mockSchemaMapper); + + // Mock SourceSchema to return the mock SourceTable. + + SourceSchema mockSourceSchema = mock(SourceSchema.class); + when(mockSourceSchema.table("tableName")).thenReturn(sourceTable); + when(sourceTable.column("migration_shard_id")).thenReturn(sourceColumn); + when(sourceColumn.isGenerated()).thenReturn(false); // Simulate non-generated column + + AssignShardIdFn assignShardIdFnWithMockSchema = + new AssignShardIdFn( + SpannerConfig.create(), + mockDdlView, + mockSourceSchema, // Use mock schema! + Constants.SHARDING_MODE_MULTI_SHARD, + "test", + "skip", + "", + "", + "", + 10000L, + Constants.SOURCE_MYSQL, + SESSION_FILE_PATH, + "", + "", + ""); + assignShardIdFnWithMockSchema.setSchema(SessionFileReader.read(SESSION_FILE_PATH)); + assignShardIdFnWithMockSchema.setSpannerAccessor(spannerAccessor); + assignShardIdFnWithMockSchema.setMapper(mapper); + + assignShardIdFnWithMockSchema.processElement(processContext); + + // Verify that read was called with "migration_shard_id"! + ArgumentCaptor captor = ArgumentCaptor.forClass(Iterable.class); + verify(mockReadOnlyTransaction) + .read(eq("tableName"), any(KeySet.class), captor.capture(), any(ReadOption.class)); + List requestedCols = + com.google.common.collect.ImmutableList.copyOf(captor.getValue()); + assertTrue(requestedCols.contains("migration_shard_id")); + } + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java index bcfa65f993..ebfe07aa58 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java @@ -22,6 +22,8 @@ import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -65,6 +67,7 @@ import java.sql.SQLSyntaxErrorException; import java.util.HashMap; import java.util.Map; +import org.apache.beam.sdk.io.gcp.spanner.SpannerAccessor; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ModType; @@ -80,6 +83,7 @@ import org.junit.runners.MethodSorters; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -832,6 +836,64 @@ public void testPermanentErrorWithSQLDataException() throws Exception { assertTrue(actualError.getErrorMessage().contains("sql data error")); } + @Test + public void testProcessElementTransactionalCheckException() throws Exception { + TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA"); + record.setShard("shardA"); + when(processContext.element()).thenReturn(KV.of(1L, record)); + + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + ImmutableList.of(testShard), + mockSpannerConfig, + testSourceDbTimezoneOffset, + testSourceSchema, + "shadow_", + "skip", + 500, + "mysql", + null, + mockDdlView, + mockShadowTableDdlView, + "src/test/resources/sourceWriterUTSession.json", + "", + "", + ""); + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + sourceWriterFn.setSchema(testSchema); + sourceWriterFn.setObjectMapper(mapper); + sourceWriterFn.setSourceProcessor(sourceProcessor); + sourceWriterFn.setSpannerDao(mockSpannerDao); + + try (MockedStatic + mockedInputRecordProcessor = + mockStatic( + com.google.cloud.teleport.v2.templates.dbutils.processor.InputRecordProcessor + .class)) { + mockedInputRecordProcessor + .when( + () -> + com.google.cloud.teleport.v2.templates.dbutils.processor.InputRecordProcessor + .processRecord( + any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), + any())) + .thenThrow( + new com.google.cloud.teleport.v2.templates.dbutils.dao.source + .TransactionalCheckException("Shadow table sequence changed during transaction")); + + sourceWriterFn.processElement(processContext); + } + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); + verify(processContext, atLeast(1)) + .output(eq(Constants.RETRYABLE_ERROR_TAG), argumentCaptor.capture()); + ChangeStreamErrorRecord actualError = + gson.fromJson(argumentCaptor.getValue(), ChangeStreamErrorRecord.class); + assertTrue( + actualError.getErrorMessage().contains("Shadow table sequence changed during transaction")); + } + @Test public void testRetryableSentinelShardId() throws Exception { TrimmedShardedDataChangeRecord record = @@ -1344,4 +1406,60 @@ static Ddl testDdlForNullDML() { .build(); return ddl; } + + @Test + public void testSetup_NullSessionFilePath() throws Exception { + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + com.google.common.collect.ImmutableList.of(testShard), + mockSpannerConfig, + testSourceDbTimezoneOffset, + testSourceSchema, + "shadow_", + "skip", + 500, + "mysql", + null, + mockDdlView, + mockShadowTableDdlView, + null, + "", + "", + ""); + + try (MockedStatic mockedSpannerAccessor = mockStatic(SpannerAccessor.class)) { + mockedSpannerAccessor + .when(() -> SpannerAccessor.getOrCreate(any())) + .thenReturn(mock(SpannerAccessor.class)); + sourceWriterFn.setup(); + } + } + + @Test + public void testSetup_EmptySessionFilePath() throws Exception { + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + com.google.common.collect.ImmutableList.of(testShard), + mockSpannerConfig, + testSourceDbTimezoneOffset, + testSourceSchema, + "shadow_", + "skip", + 500, + "mysql", + null, + mockDdlView, + mockShadowTableDdlView, + "", + "", + "", + ""); + + try (MockedStatic mockedSpannerAccessor = mockStatic(SpannerAccessor.class)) { + mockedSpannerAccessor + .when(() -> SpannerAccessor.getOrCreate(any())) + .thenReturn(mock(SpannerAccessor.class)); + sourceWriterFn.setup(); + } + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SpannerInformationSchemaProcessorTransformTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SpannerInformationSchemaProcessorTransformTest.java new file mode 100644 index 0000000000..1c173010c9 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SpannerInformationSchemaProcessorTransformTest.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.transforms; + +import static org.junit.Assert.assertNotNull; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.junit.Test; + +public class SpannerInformationSchemaProcessorTransformTest { + + @Test + public void testExpand() { + PipelineOptions options = PipelineOptionsFactory.create(); + Pipeline pipeline = Pipeline.create(options); + + SpannerConfig spannerConfig = SpannerConfig.create(); + SpannerConfig shadowTableSpannerConfig = SpannerConfig.create(); + String shadowTablePrefix = "shadow_"; + + SpannerInformationSchemaProcessorTransform transform = + new SpannerInformationSchemaProcessorTransform( + spannerConfig, shadowTableSpannerConfig, shadowTablePrefix); + + PCollectionTuple result = pipeline.apply(transform); + + assertNotNull(result); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShadowTableCreatorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShadowTableCreatorTest.java index f67659e817..e1b637b478 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShadowTableCreatorTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShadowTableCreatorTest.java @@ -250,4 +250,73 @@ private Ddl getMetadataDbDdl() { .build(); return ddl; } + + @Test + public void testConstructShadowTable_DescKey() { + Ddl primaryDbDdl = + Ddl.builder() + .createTable("table_desc") + .column("id") + .int64() + .endColumn() + .primaryKey() + .desc("id") + .end() + .endTable() + .build(); + + ShadowTableCreator shadowTableCreator = + new ShadowTableCreator( + Dialect.GOOGLE_STANDARD_SQL, + "shadow_", + primaryDbDdl, + Ddl.builder().build(), // empty metadata ddl + mockSpannerAccessor, + testSpannerConfig); + + Table shadowTable = shadowTableCreator.constructShadowTable("table_desc"); + assertThat(shadowTable.name()).isEqualTo("shadow_table_desc"); + assertThat(shadowTable.primaryKeys().get(0).order()).isEqualTo(IndexColumn.Order.DESC); + } + + @Test + public void testCreateShadowTables_NoNewTables() { + Ddl primaryDbDdl = + Ddl.builder() + .createTable("table1") + .column("id") + .int64() + .endColumn() + .primaryKey() + .asc("id") + .end() + .endTable() + .build(); + + Ddl metadataDbDdl = + Ddl.builder() + .createTable("shadow_table1") + .column("id") + .int64() + .endColumn() + .primaryKey() + .asc("id") + .end() + .endTable() + .build(); + + ShadowTableCreator shadowTableCreator = + new ShadowTableCreator( + Dialect.GOOGLE_STANDARD_SQL, + "shadow_", + primaryDbDdl, + metadataDbDdl, + mockSpannerAccessor, + testSpannerConfig); + + shadowTableCreator.createShadowTablesInSpanner(); + + // Verify that updateDatabaseDdl was NOT called! + verify(mockDatabaseClient, never()).updateDatabaseDdl(any(), any(), any(), any()); + } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShardingLogicImplFetcherTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShardingLogicImplFetcherTest.java index 0c86fcb82c..b75e699726 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShardingLogicImplFetcherTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/utils/ShardingLogicImplFetcherTest.java @@ -20,8 +20,12 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; import com.google.cloud.teleport.v2.spanner.utils.IShardIdFetcher; +import java.io.File; +import java.io.IOException; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; public class ShardingLogicImplFetcherTest { @@ -45,4 +49,15 @@ public void testGetShardingLogicImpl_Custom_Failure() { ShardingLogicImplFetcher.getShardingLogicImpl( "invalid.jar", "InvalidClass", "", mockSchemaMapper, "skip"); } + + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test(expected = RuntimeException.class) + public void testGetShardingLogicImpl_Custom_WithRealFile_Failure() throws IOException { + ISchemaMapper mockSchemaMapper = mock(ISchemaMapper.class); + File dummyJar = tmpFolder.newFile("dummy.jar"); + + ShardingLogicImplFetcher.getShardingLogicImpl( + dummyJar.getAbsolutePath(), "InvalidClass", "", mockSchemaMapper, "skip"); + } }