From bb39d84c990426932499e30c83a33f1b00d415c2 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Mon, 23 Mar 2026 16:22:18 +0530 Subject: [PATCH 01/26] feat: Deletion and renaming of columns with upgrade of delta sharing server --- .../src/assets/adhoc/master_csv_to_gold.py | 8 +- dagster/src/internal/common_assets/staging.py | 23 +- .../src/resources/io_managers/adls_delta.py | 14 + dagster/src/utils/delta.py | 259 ++++++++++++++++-- dagster/src/utils/schema.py | 28 ++ dagster/src/utils/spark.py | 3 + dagster/tests/utils/test_delta_sync_schema.py | 97 +++++++ 7 files changed, 404 insertions(+), 28 deletions(-) create mode 100644 dagster/tests/utils/test_delta_sync_schema.py diff --git a/dagster/src/assets/adhoc/master_csv_to_gold.py b/dagster/src/assets/adhoc/master_csv_to_gold.py index 648fd2a35..da5aeef25 100644 --- a/dagster/src/assets/adhoc/master_csv_to_gold.py +++ b/dagster/src/assets/adhoc/master_csv_to_gold.py @@ -14,6 +14,9 @@ ) from pyspark.sql.types import NullType, StructType from sqlalchemy import select, update + +from azure.core.exceptions import ResourceNotFoundError +from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset from src.constants import DataTier from src.data_quality_checks.utils import ( aggregate_report_json, @@ -56,9 +59,6 @@ from src.utils.sentry import capture_op_exceptions from src.utils.spark import compute_row_hash, transform_types -from azure.core.exceptions import ResourceNotFoundError -from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset - @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @capture_op_exceptions @@ -638,6 +638,7 @@ def adhoc__publish_master_to_gold( updated_schema=updated_schema, spark=spark.spark_session, context=context, + schema_name=config.metastore_schema, ) schema_reference = get_schema_columns_datahub( @@ -720,6 +721,7 @@ def adhoc__publish_reference_to_gold( updated_schema=updated_schema, spark=spark.spark_session, context=context, + schema_name=config.metastore_schema, ) schema_reference = get_schema_columns_datahub( diff --git a/dagster/src/internal/common_assets/staging.py b/dagster/src/internal/common_assets/staging.py index 64b502eb5..8553d53c4 100644 --- a/dagster/src/internal/common_assets/staging.py +++ b/dagster/src/internal/common_assets/staging.py @@ -395,7 +395,14 @@ def create_empty_staging_table(self): ) def sync_schema_staging(self): - """Update the schema of existing delta tables based on the reference schema delta tables.""" + """Update the schema of existing delta tables based on the reference schema delta tables. + + Supports adding, renaming, and deleting columns. Renames and deletes + are detected by comparing stable column UUIDs stored in the table + properties against the latest reference schema CSV. + """ + from src.utils.delta import apply_renames_and_deletes, persist_column_id_map + self.context.log.info("Checking for schema update...") updated_schema = StructType(self.schema_columns) updated_columns = sorted(updated_schema.fieldNames()) @@ -403,6 +410,15 @@ def sync_schema_staging(self): existing_df = DeltaTable.forName(self.spark, self.staging_table_name).toDF() existing_columns = sorted(existing_df.schema.fieldNames()) + any_renames_deletes = apply_renames_and_deletes( + self.spark, self.staging_table_name, self.schema_name, self.context + ) + + # Refresh schemas after rename/delete + if any_renames_deletes: + existing_df = DeltaTable.forName(self.spark, self.staging_table_name).toDF() + existing_columns = sorted(existing_df.schema.fieldNames()) + # Sync changes in nullability flags alter_sql = f"ALTER TABLE {self.staging_table_name}" alter_stmts = [] @@ -437,9 +453,12 @@ def sync_schema_staging(self): for stmnt in alter_sql: self.spark.sql(stmnt).show() - if has_schema_changed or has_nullability_changed: + if has_schema_changed or has_nullability_changed or any_renames_deletes: self.reload_schema() + # Persist column-ID mapping + persist_column_id_map(self.spark, self.staging_table_name, self.schema_name) + def reload_schema(self): self.schema_columns = get_schema_columns(self.spark, self.schema_name) diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index fe7f0687a..35134d11b 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -191,6 +191,8 @@ def _upsert_data( columns = incoming_schema.fields primary_key = "gigasync_id" else: + from src.utils.delta import apply_renames_and_deletes, persist_column_id_map + columns = get_schema_columns(spark, schema_name) primary_key = get_primary_key(spark, schema_name) @@ -203,6 +205,15 @@ def _upsert_data( context.log.info(f"incoming schema {data.schema}") context.log.info(f"existing schema {existing_df.schema}") + any_renames_deletes = apply_renames_and_deletes( + spark, full_table_name, schema_name, context + ) + + # Refresh after rename/delete + if any_renames_deletes: + existing_df = DeltaTable.forName(spark, full_table_name).toDF() + existing_columns = sorted(existing_df.schema.fieldNames()) + if updated_columns != existing_columns: context.log.info("Updating schema...") @@ -218,6 +229,9 @@ def _upsert_data( .saveAsTable(full_table_name) ) + # Persist column-ID mapping + persist_column_id_map(spark, full_table_name, schema_name) + update_columns = [c.name for c in columns if c.name != primary_key] master = DeltaTable.forName(spark, full_table_name) query = build_deduped_merge_query( diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 31b4d790f..5e3d1687c 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -291,13 +291,238 @@ def build_nullability_queries( return alter_stmts +def _enable_column_mapping(spark: SparkSession, table_name: str) -> None: + """Enable column mapping mode on an existing Delta table if not already enabled. + + This is a one-time, irreversible protocol upgrade (reader v2 / writer v5). + It must be executed before any ``RENAME COLUMN`` or ``DROP COLUMN`` operations. + """ + spark.sql( + f"ALTER TABLE {table_name} SET TBLPROPERTIES (" + f" 'delta.columnMapping.mode' = 'name'," + f" 'delta.minReaderVersion' = '2'," + f" 'delta.minWriterVersion' = '5'" + f")" + ) + + +def _get_stored_column_id_map(spark: SparkSession, table_name: str) -> dict[str, str]: + """Retrieve the column-name → schema-CSV-ID mapping stored in table properties. + + Returns ``{column_name: csv_id}`` or an empty dict if no mapping has been + stored yet (e.g. tables created before this feature was added). + """ + detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0] + properties: dict = detail["properties"] if detail["properties"] else {} + result = {} + prefix = "giga.columnId." + for key, value in properties.items(): + if key.startswith(prefix): + col_name = key[len(prefix) :] + result[col_name] = value + return result + + +def _store_column_id_map( + spark: SparkSession, + table_name: str, + column_id_map: dict[str, str], +) -> None: + """Persist the column-name → schema-CSV-ID mapping as Delta table properties.""" + if not column_id_map: + return + props = ", ".join( + f"'giga.columnId.{col_name}' = '{csv_id}'" + for col_name, csv_id in column_id_map.items() + ) + spark.sql(f"ALTER TABLE {table_name} SET TBLPROPERTIES ({props})") + + +def _remove_column_id_props( + spark: SparkSession, + table_name: str, + column_names: list[str], +) -> None: + """Remove column-ID table properties for dropped columns.""" + if not column_names: + return + props = ", ".join(f"'giga.columnId.{name}'" for name in column_names) + spark.sql(f"ALTER TABLE {table_name} UNSET TBLPROPERTIES IF EXISTS ({props})") + + +def _detect_renames_and_deletes( + existing_id_map: dict[str, str], + updated_id_map: dict[str, str], +) -> tuple[dict[str, str], list[str]]: + """Compare old and new column-ID mappings to detect renames and deletes. + + Parameters + ---------- + existing_id_map : dict[str, str] + ``{column_name: csv_id}`` from the current table properties. + updated_id_map : dict[str, str] + ``{column_name: csv_id}`` from the latest schema CSV. + + Returns + ------- + renames : dict[str, str] + ``{old_name: new_name}`` for columns whose ID stayed but name changed. + deletes : list[str] + Column names present in the table but whose ID is no longer in the + updated schema (i.e. the column should be dropped). + """ + # Invert maps: csv_id → column_name + existing_by_id = {v: k for k, v in existing_id_map.items()} + updated_by_id = {v: k for k, v in updated_id_map.items()} + + renames: dict[str, str] = {} + deletes: list[str] = [] + + for csv_id, old_name in existing_by_id.items(): + if csv_id in updated_by_id: + new_name = updated_by_id[csv_id] + if old_name != new_name: + renames[old_name] = new_name + else: + # ID no longer present in the reference schema → column deleted + deletes.append(old_name) + + return renames, deletes + + +def apply_renames_and_deletes( + spark: SparkSession, + table_name: str, + schema_name: str, + context: OpExecutionContext, +) -> bool: + """Detect and apply column renames and deletes to a Delta table based on the reference schema. + + Returns True if any schema change occurred (rename or delete). + """ + from src.utils.schema import get_schema_columns_with_id + + columns_with_id = get_schema_columns_with_id(spark, schema_name) + updated_id_map = {field.name: csv_id for csv_id, field in columns_with_id} + existing_id_map = _get_stored_column_id_map(spark, table_name) + + renames: dict[str, str] = {} + deletes: list[str] = [] + + if existing_id_map: + renames, deletes = _detect_renames_and_deletes(existing_id_map, updated_id_map) + context.log.info(f"Detected renames: {renames}") + context.log.info(f"Detected deletes: {deletes}") + else: + context.log.info( + "No stored column-ID mapping found; initialising mapping from current reference schema." + ) + + if renames or deletes: + context.log.info( + "Enabling column mapping on table for rename/delete support..." + ) + _enable_column_mapping(spark, table_name) + + if renames: + context.log.info(f"Renaming columns: {renames}") + for old_name, new_name in renames.items(): + stmt = ( + f"ALTER TABLE {table_name} RENAME COLUMN `{old_name}` TO `{new_name}`" + ) + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + _remove_column_id_props(spark, table_name, list(renames.keys())) + + if deletes: + context.log.info(f"Dropping columns: {deletes}") + for col_name in deletes: + stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + _remove_column_id_props(spark, table_name, deletes) + + return bool(renames or deletes) + + +def persist_column_id_map( + spark: SparkSession, table_name: str, schema_name: str +) -> None: + """Read the column ID mapping from the schema CSV and store it as table properties.""" + from src.utils.schema import get_schema_columns_with_id + + columns_with_id = get_schema_columns_with_id(spark, schema_name) + new_id_map = {field.name: csv_id for csv_id, field in columns_with_id} + _store_column_id_map(spark, table_name, new_id_map) + + +def apply_datatype_changes( + spark: SparkSession, + table_name: str, + changed_datatypes: dict, + context: OpExecutionContext, +) -> None: + """Apply datatype changes by casting columns and overwriting the table schema.""" + if not changed_datatypes: + return + + context.log.info("Updating datatype...") + context.log.info(f"Changed datatypes: {changed_datatypes}") + existing_dataframe = spark.table(table_name) + updated_df = existing_dataframe + + for column, datatype in changed_datatypes.items(): + updated_df = updated_df.withColumn( + column, existing_dataframe[column].cast(datatype.typeName()) + ) + + ( + updated_df.write.option("overwriteSchema", "true") + .format("delta") + .mode("overwrite") + .saveAsTable(table_name) + ) + + def sync_schema( table_name: str, existing_schema: StructType, updated_schema: StructType, spark: SparkSession, context: OpExecutionContext, + schema_name: str | None = None, ): + """Synchronise a Delta table's schema with the reference schema. + + Supports: + * Adding columns (existing behaviour via ``mergeSchema``) + * Renaming columns (via ``ALTER TABLE RENAME COLUMN``) + * Dropping columns (via ``ALTER TABLE DROP COLUMN``) + * Changing data types (via overwrite with ``overwriteSchema``) + * Changing nullability constraints + + Column renames and deletes require ``schema_name`` so that the stable + UUID column IDs from the schema CSV can be compared against the IDs + stored in the table properties. + """ + # ------------------------------------------------------------------ + # 1. Detect and apply renames & deletes + # ------------------------------------------------------------------ + any_renames_deletes = False + if schema_name is not None: + any_renames_deletes = apply_renames_and_deletes( + spark, table_name, schema_name, context + ) + + # ------------------------------------------------------------------ + # 2. Refresh schemas after rename/delete to get accurate comparison + # ------------------------------------------------------------------ + if any_renames_deletes: + existing_schema = spark.table(table_name).schema + + # ------------------------------------------------------------------ + # 3. Detect added columns & datatype changes (existing logic) + # ------------------------------------------------------------------ alter_stmts = build_nullability_queries( context=context, existing_schema=existing_schema, @@ -307,36 +532,18 @@ def sync_schema( context.log.info(f"alter_stmts {alter_stmts}") has_nullability_changed = len(alter_stmts) > 0 - existing_columns = {field.name: field.name for field in existing_schema} - updated_columns = {field.name: field.name for field in updated_schema} + existing_columns = {field.name for field in existing_schema} + updated_columns_set = {field.name for field in updated_schema} - added_columns = set(updated_columns.keys()) - set(existing_columns.keys()) - removed_columns = set(existing_columns.keys()) - set(updated_columns.keys()) + added_columns = updated_columns_set - existing_columns + removed_columns = existing_columns - updated_columns_set has_schema_changed = len(added_columns) + len(removed_columns) > 0 changed_datatypes = get_changed_datatypes( context=context, existing_schema=existing_schema, updated_schema=updated_schema ) - has_datatype_changed = len(changed_datatypes) > 0 - - if has_datatype_changed: - context.log.info("Updating datatype...") - context.log.info(f"Changed datatypes: {changed_datatypes}") - existing_dataframe = spark.table(table_name) - updated_df = existing_dataframe - - for column, datatype in changed_datatypes.items(): - updated_df = updated_df.withColumn( - column, existing_dataframe[column].cast(datatype.typeName()) - ) - - ( - updated_df.write.option("overwriteSchema", "true") - .format("delta") - .mode("overwrite") - .saveAsTable(table_name) - ) + apply_datatype_changes(spark, table_name, changed_datatypes, context) if has_schema_changed: context.log.info(f"Adding schema columns {added_columns}") @@ -366,3 +573,9 @@ def sync_schema( continue else: raise + + # ------------------------------------------------------------------ + # 4. Persist column-ID mapping for future rename/delete detection + # ------------------------------------------------------------------ + if schema_name is not None: + persist_column_id_map(spark, table_name, schema_name) diff --git a/dagster/src/utils/schema.py b/dagster/src/utils/schema.py index 6fa2e206a..aaff06033 100644 --- a/dagster/src/utils/schema.py +++ b/dagster/src/utils/schema.py @@ -42,6 +42,34 @@ def get_schema_columns(spark: SparkSession, schema_name: str) -> list[StructFiel ] +def get_schema_columns_with_id( + spark: SparkSession, schema_name: str +) -> list[tuple[str, StructField]]: + """Return schema columns paired with their stable UUID id. + + Each tuple is ``(id, StructField)``. The ``id`` is the fixed UUID + assigned to the column in the schema CSV stored in ADLS. By comparing + IDs between the reference schema and an existing Delta table we can + detect: + + * **Renames** – same ID, different ``StructField.name`` + * **Deletes** – ID present in the table but absent from the reference + * **Adds** – ID present in the reference but absent from the table + """ + df = get_schema_table(spark, schema_name) + return [ + ( + row.id, + StructField( + row.name, + getattr(constants.TYPE_MAPPINGS, row.data_type).pyspark(), + row.is_nullable, + ), + ) + for row in df.collect() + ] + + def get_schema_column_descriptions( spark: SparkSession, schema_name: str ) -> dict[str:str]: diff --git a/dagster/src/utils/spark.py b/dagster/src/utils/spark.py index ac66188ad..8c37f7caa 100644 --- a/dagster/src/utils/spark.py +++ b/dagster/src/utils/spark.py @@ -56,6 +56,9 @@ def _get_host_ip() -> str: "spark.databricks.delta.properties.defaults.appendOnly": "false", "spark.databricks.delta.schema.autoMerge.enabled": "false", "spark.databricks.delta.catalog.update.enabled": "true", + "spark.databricks.delta.properties.defaults.columnMapping.mode": "name", + "spark.databricks.delta.properties.defaults.minReaderVersion": "2", + "spark.databricks.delta.properties.defaults.minWriterVersion": "5", } # Configure Azure Storage authentication diff --git a/dagster/tests/utils/test_delta_sync_schema.py b/dagster/tests/utils/test_delta_sync_schema.py new file mode 100644 index 000000000..404e01361 --- /dev/null +++ b/dagster/tests/utils/test_delta_sync_schema.py @@ -0,0 +1,97 @@ +"""Tests for the Delta Lake column rename/delete detection helpers in src.utils.delta.""" + +from src.utils.delta import _detect_renames_and_deletes + + +class TestDetectRenamesAndDeletes: + """Unit tests for _detect_renames_and_deletes.""" + + def test_no_changes(self): + existing = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} + updated = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {} + assert deletes == [] + + def test_column_renamed(self): + existing = {"old_name": "id-1", "col_b": "id-2"} + updated = {"new_name": "id-1", "col_b": "id-2"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {"old_name": "new_name"} + assert deletes == [] + + def test_column_deleted(self): + existing = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} + updated = {"col_a": "id-1", "col_b": "id-2"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {} + assert deletes == ["col_c"] + + def test_column_added_only(self): + """Adding a column (ID in updated but not existing) should not trigger renames or deletes.""" + existing = {"col_a": "id-1"} + updated = {"col_a": "id-1", "col_new": "id-new"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {} + assert deletes == [] + + def test_rename_and_delete_combined(self): + existing = { + "old_name": "id-1", + "col_b": "id-2", + "col_to_drop": "id-3", + } + updated = { + "new_name": "id-1", + "col_b": "id-2", + } + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {"old_name": "new_name"} + assert deletes == ["col_to_drop"] + + def test_rename_delete_and_add(self): + existing = { + "old_name": "id-1", + "col_b": "id-2", + "col_drop": "id-3", + } + updated = { + "new_name": "id-1", + "col_b": "id-2", + "col_new": "id-4", + } + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {"old_name": "new_name"} + assert deletes == ["col_drop"] + + def test_multiple_renames(self): + existing = {"a": "id-1", "b": "id-2", "c": "id-3"} + updated = {"x": "id-1", "y": "id-2", "c": "id-3"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {"a": "x", "b": "y"} + assert deletes == [] + + def test_multiple_deletes(self): + existing = {"a": "id-1", "b": "id-2", "c": "id-3"} + updated = {"a": "id-1"} + renames, deletes = _detect_renames_and_deletes(existing, updated) + assert renames == {} + assert sorted(deletes) == ["b", "c"] + + def test_empty_existing(self): + """If existing is empty, there should be no changes.""" + renames, deletes = _detect_renames_and_deletes({}, {"a": "id-1"}) + assert renames == {} + assert deletes == [] + + def test_empty_updated_deletes_all(self): + """If updated is empty, all existing columns should be deleted.""" + existing = {"a": "id-1", "b": "id-2"} + renames, deletes = _detect_renames_and_deletes(existing, {}) + assert renames == {} + assert sorted(deletes) == ["a", "b"] + + def test_both_empty(self): + renames, deletes = _detect_renames_and_deletes({}, {}) + assert renames == {} + assert deletes == [] From 3524d75f5bb56cff8e8025dab1c8e15a742a7c82 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Mon, 23 Mar 2026 16:28:41 +0530 Subject: [PATCH 02/26] feat: Deletion and renaming of columns with upgrade of delta sharing server --- dagster/src/utils/delta.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 5e3d1687c..f0a58986d 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -292,11 +292,7 @@ def build_nullability_queries( def _enable_column_mapping(spark: SparkSession, table_name: str) -> None: - """Enable column mapping mode on an existing Delta table if not already enabled. - - This is a one-time, irreversible protocol upgrade (reader v2 / writer v5). - It must be executed before any ``RENAME COLUMN`` or ``DROP COLUMN`` operations. - """ + """Enable column mapping mode on an existing Delta table if not already enabled.""" spark.sql( f"ALTER TABLE {table_name} SET TBLPROPERTIES (" f" 'delta.columnMapping.mode' = 'name'," From efcb2f1dd1ace95ca39def9eb1c35fbba283ef07 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Wed, 1 Apr 2026 14:04:49 +0530 Subject: [PATCH 03/26] chore: pre commit issue fixed --- dagster/src/assets/adhoc/master_csv_to_gold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagster/src/assets/adhoc/master_csv_to_gold.py b/dagster/src/assets/adhoc/master_csv_to_gold.py index da5aeef25..5f07d03d8 100644 --- a/dagster/src/assets/adhoc/master_csv_to_gold.py +++ b/dagster/src/assets/adhoc/master_csv_to_gold.py @@ -14,9 +14,6 @@ ) from pyspark.sql.types import NullType, StructType from sqlalchemy import select, update - -from azure.core.exceptions import ResourceNotFoundError -from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset from src.constants import DataTier from src.data_quality_checks.utils import ( aggregate_report_json, @@ -59,6 +56,9 @@ from src.utils.sentry import capture_op_exceptions from src.utils.spark import compute_row_hash, transform_types +from azure.core.exceptions import ResourceNotFoundError +from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset + @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @capture_op_exceptions From 96970d40c0cc16ff92784205aa7625e8988cf2d3 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 1 Apr 2026 15:05:51 +0200 Subject: [PATCH 04/26] fix: improve processing time for large files (#445) --- .../src/assets/school_geolocation/assets.py | 164 ++++++++---------- dagster/src/resources/__init__.py | 8 + .../src/resources/io_managers/adls_spark.py | 73 ++++++++ .../io_managers/adls_spark_single_file.py | 59 +++++++ dagster/src/resources/io_managers/base.py | 2 + dagster/src/sensors/school_geolocation.py | 28 ++- dagster/src/utils/op_config.py | 11 ++ 7 files changed, 236 insertions(+), 109 deletions(-) create mode 100644 dagster/src/resources/io_managers/adls_spark.py create mode 100644 dagster/src/resources/io_managers/adls_spark_single_file.py diff --git a/dagster/src/assets/school_geolocation/assets.py b/dagster/src/assets/school_geolocation/assets.py index fe1532659..a5368818d 100644 --- a/dagster/src/assets/school_geolocation/assets.py +++ b/dagster/src/assets/school_geolocation/assets.py @@ -11,7 +11,7 @@ SparkSession, functions as f, ) -from pyspark.sql.types import StringType, StructType +from pyspark.sql.types import StringType, StructField, StructType from sqlalchemy import select from src.constants import DataTier from src.data_quality_checks.utils import ( @@ -52,7 +52,14 @@ from src.utils.send_email_dq_report import send_email_dq_report_with_config from src.utils.sentry import capture_op_exceptions -from dagster import MetadataValue, OpExecutionContext, Output, asset +from dagster import ( + AssetOut, + MetadataValue, + OpExecutionContext, + Output, + asset, + multi_asset, +) @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @@ -159,14 +166,14 @@ def geolocation_metadata( return Output(None) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_bronze( context: OpExecutionContext, geolocation_raw: bytes, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session country_code = config.country_code mode = config.metadata["mode"] @@ -227,31 +234,28 @@ def geolocation_bronze( if column in df.columns: df = df.withColumn(column, f.initcap(f.col(column))) - ## at this point it's already gone - context.log.info("BEFORE DF TO PANDAS") - df_pandas = df.toPandas() - context.log.info("AFTER DF TO PANDAS") - context.log.info(df_pandas) + df.cache() + row_count = df.count() return Output( - df_pandas, + df, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), + "row_count": row_count, "column_mapping": column_mapping_filtered, - "preview": get_table_preview(df_pandas), + "preview": get_table_preview(df), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_data_quality_results( context: OpExecutionContext, config: FileConfig, geolocation_bronze: sql.DataFrame, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session country_code = config.country_code schema_name = config.metastore_schema @@ -315,9 +319,10 @@ def geolocation_data_quality_results( dq_results_schema_name = f"{schema_name}_dq_results" table_name = f"{id}_{country_code}_{current_timestamp}" - schema_columns = dq_results.schema.fields - for col in schema_columns: - col.nullable = True + schema_columns = [ + StructField(field.name, field.dataType, nullable=True) + for field in dq_results.schema.fields + ] dq_results_table_name = construct_full_table_name( dq_results_schema_name, @@ -333,10 +338,9 @@ def geolocation_data_quality_results( context, if_not_exists=True, ) + dq_results.cache() dq_results.write.format("delta").mode("append").saveAsTable(dq_results_table_name) - dq_pandas = dq_results.toPandas() - datahub_emit_metadata_with_exception_catcher( context=context, config=config, @@ -344,22 +348,31 @@ def geolocation_data_quality_results( ) return Output( - dq_pandas, + dq_results.coalesce(1), metadata={ **get_output_metadata(config), - "row_count": len(dq_pandas), - "preview": get_table_preview(dq_pandas), + "row_count": dq_results.count(), + "preview": get_table_preview(dq_results), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@multi_asset( + outs={ + "geolocation_dq_schools_passed_human_readable": AssetOut( + io_manager_key=ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value, + ), + "geolocation_dq_schools_failed_human_readable": AssetOut( + io_manager_key=ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value, + ), + }, +) @capture_op_exceptions -async def geolocation_data_quality_results_human_readable( +def geolocation_data_quality_results_human_readable( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, -) -> Output[pd.DataFrame]: +): context.log.info("Get the file upload object from the database") with get_db_context() as db: file_upload = db.scalar( @@ -381,9 +394,6 @@ async def geolocation_data_quality_results_human_readable( df, human_readable_mappings = dq_geolocation_extract_relevant_columns( geolocation_data_quality_results, uploaded_columns, mode ) - # human_readable_mappings keys are map keys (without "dq_" prefix). - # Expand each relevant map entry into its own column with Yes/No values, - # then drop the map so the output is flat like before. for map_key, human_name in human_readable_mappings.items(): df = df.withColumn( human_name, @@ -393,64 +403,32 @@ async def geolocation_data_quality_results_human_readable( ) df = df.drop("dq_results") - context.log.info("Convert the dataframe to a pandas object to save it locally") - df_pandas = df.toPandas() + # Cache once — both filters read from the same plan + df.cache() - return Output( - df_pandas, - metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), - }, + df_passed = df.filter(df.dq_has_critical_error == 0).drop( + "dq_has_critical_error", "failure_reason" ) + df_failed = df.filter(df.dq_has_critical_error == 1).drop("dq_has_critical_error") + output_metadata = get_output_metadata(config) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) -@capture_op_exceptions -async def geolocation_dq_schools_passed_human_readable( - context: OpExecutionContext, - geolocation_data_quality_results_human_readable: sql.DataFrame, - config: FileConfig, -) -> Output[pd.DataFrame]: - context.log.info("Filter and keep schools that do not have a critical error") - df = geolocation_data_quality_results_human_readable.filter( - geolocation_data_quality_results_human_readable.dq_has_critical_error == 0 - ) - df = df.drop("dq_has_critical_error", "failure_reason") - df_pandas = df.toPandas() - - return Output( - df_pandas, + yield Output( + df_passed, + output_name="geolocation_dq_schools_passed_human_readable", metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + **output_metadata, + "row_count": df_passed.count(), + "preview": get_table_preview(df_passed), }, ) - - -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) -@capture_op_exceptions -async def geolocation_dq_schools_failed_human_readable( - context: OpExecutionContext, - geolocation_data_quality_results_human_readable: sql.DataFrame, - config: FileConfig, -) -> Output[pd.DataFrame]: - context.log.info("Filter and keep schools that have a critical error") - df = geolocation_data_quality_results_human_readable.filter( - geolocation_data_quality_results_human_readable.dq_has_critical_error == 1 - ) - - df = df.drop("dq_has_critical_error") - df_pandas = df.toPandas() - - return Output( - df_pandas, + yield Output( + df_failed, + output_name="geolocation_dq_schools_failed_human_readable", metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + **output_metadata, + "row_count": df_failed.count(), + "preview": get_table_preview(df_failed), }, ) @@ -559,14 +537,14 @@ def geolocation_data_quality_report( return Output(dq_report) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_dq_passed_rows( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: df_passed = dq_split_passed_rows( geolocation_data_quality_results, config.dataset_type, @@ -583,25 +561,27 @@ def geolocation_dq_passed_rows( schema_reference=schema_reference, ) - df_pandas = df_passed.toPandas() + df_passed.cache() + row_count = df_passed.count() + return Output( - df_pandas, + df_passed, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + "row_count": row_count, + "preview": get_table_preview(df_passed), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_dq_failed_rows( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: df_failed = dq_split_failed_rows( geolocation_data_quality_results, config.dataset_type, @@ -619,13 +599,15 @@ def geolocation_dq_failed_rows( df_failed=df_failed, ) - df_pandas = df_failed.toPandas() + df_failed.cache() + row_count = df_failed.count() + return Output( - df_pandas, + df_failed, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + "row_count": row_count, + "preview": get_table_preview(df_failed), }, ) @@ -639,7 +621,7 @@ def geolocation_staging( spark: PySparkResource, config: FileConfig, ) -> Output[None]: - if geolocation_dq_passed_rows.count() == 0: + if geolocation_dq_passed_rows.isEmpty(): context.log.warning("Skipping staging as there are no rows passing DQ checks") return Output(None) diff --git a/dagster/src/resources/__init__.py b/dagster/src/resources/__init__.py index b75687146..34204c5e0 100644 --- a/dagster/src/resources/__init__.py +++ b/dagster/src/resources/__init__.py @@ -8,6 +8,8 @@ from .io_managers.adls_json import ADLSJSONIOManager from .io_managers.adls_pandas import ADLSPandasIOManager from .io_managers.adls_passthrough import ADLSPassthroughIOManager +from .io_managers.adls_spark import ADLSSparkIOManager +from .io_managers.adls_spark_single_file import ADLSSparkSingleFileIOManager class ResourceKey(Enum): @@ -16,6 +18,8 @@ class ResourceKey(Enum): ADLS_JSON_IO_MANAGER = "adls_json_io_manager" ADLS_PANDAS_IO_MANAGER = "adls_pandas_io_manager" ADLS_PASSTHROUGH_IO_MANAGER = "adls_passthrough_io_manager" + ADLS_SPARK_IO_MANAGER = "adls_spark_io_manager" + ADLS_SPARK_SINGLE_FILE_IO_MANAGER = "adls_spark_single_file_io_manager" ADLS_FILE_CLIENT = "adls_file_client" SPARK = "spark" @@ -26,6 +30,10 @@ class ResourceKey(Enum): ResourceKey.ADLS_JSON_IO_MANAGER.value: ADLSJSONIOManager(), ResourceKey.ADLS_PANDAS_IO_MANAGER.value: ADLSPandasIOManager(pyspark=pyspark), ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value: ADLSPassthroughIOManager(), + ResourceKey.ADLS_SPARK_IO_MANAGER.value: ADLSSparkIOManager(pyspark=pyspark), + ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value: ADLSSparkSingleFileIOManager( + pyspark=pyspark + ), ResourceKey.ADLS_FILE_CLIENT.value: ADLSFileClient(), ResourceKey.SPARK.value: pyspark, } diff --git a/dagster/src/resources/io_managers/adls_spark.py b/dagster/src/resources/io_managers/adls_spark.py new file mode 100644 index 000000000..dc727b352 --- /dev/null +++ b/dagster/src/resources/io_managers/adls_spark.py @@ -0,0 +1,73 @@ +import pandas as pd +from dagster_pyspark import PySparkResource +from pyspark import sql +from pyspark.sql import SparkSession +from pyspark.sql.types import NullType, StringType + +from azure.core.exceptions import ResourceNotFoundError +from dagster import InputContext, OutputContext +from src.settings import settings +from src.utils.adls import ADLSFileClient + +from .base import BaseConfigurableIOManager + +adls_client = ADLSFileClient() + + +class ADLSSparkIOManager(BaseConfigurableIOManager): + """Writes Spark DataFrames natively to ADLS (parquet or csv directory). + + Uses a cache → isEmpty → write → unpersist pattern so the plan executes + once and executor memory is freed immediately after the write. + """ + + pyspark: PySparkResource + + def handle_output(self, context: OutputContext, output: sql.DataFrame): + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + # Cast NullType columns to StringType — schema-only, no action triggered yet + for field in output.schema.fields: + if isinstance(field.dataType, NullType): + output = output.withColumn( + field.name, output[field.name].cast(StringType()) + ) + + # cache → isEmpty (populates cache) → write (reads from cache) → unpersist + output.cache() + output.isEmpty() + + match path.suffix: + case ".parquet": + output.write.mode("overwrite").parquet(adls_path) + case ".csv": + output.write.mode("overwrite").csv(adls_path, header=True) + case _: + raise OSError(f"Unsupported format for Spark write: {path.suffix}") + + output.unpersist() + context.log.info(f"Uploaded {path.name} to {path.parent} in ADLS.") + + def load_input(self, context: InputContext) -> sql.DataFrame: + spark: SparkSession = self.pyspark.spark_session + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + try: + match path.suffix: + case ".parquet": + data = spark.read.parquet(adls_path) + case ".csv": + data = adls_client.download_csv_as_spark_dataframe(str(path), spark) + case ".xls" | ".xlsx": + # No native Spark Excel support — bridge via pandas + pdf = pd.read_excel(adls_path) + data = spark.createDataFrame(pdf.astype(str)) + case _: + raise OSError(f"Unsupported format for Spark read: {path.suffix}") + except ResourceNotFoundError as e: + raise e + + context.log.info(f"Downloaded {path.name} from {path.parent} in ADLS.") + return data diff --git a/dagster/src/resources/io_managers/adls_spark_single_file.py b/dagster/src/resources/io_managers/adls_spark_single_file.py new file mode 100644 index 000000000..8ce82d555 --- /dev/null +++ b/dagster/src/resources/io_managers/adls_spark_single_file.py @@ -0,0 +1,59 @@ +import pandas as pd +from dagster_pyspark import PySparkResource +from pyspark import sql +from pyspark.sql import SparkSession + +from azure.core.exceptions import ResourceNotFoundError +from dagster import InputContext, OutputContext +from src.settings import settings +from src.utils.adls import ADLSFileClient + +from .base import BaseConfigurableIOManager + +adls_client = ADLSFileClient() + + +class ADLSSparkSingleFileIOManager(BaseConfigurableIOManager): + """Writes Spark DataFrames as a single file to ADLS via a pandas bridge. + + Use this for human-readable output assets (CSV exports for end users) where + a single flat file is required rather than a partitioned Spark output directory. + The asset returns a sql.DataFrame; conversion to pandas happens here, keeping + asset code Spark-native. + + load_input returns a sql.DataFrame so downstream assets remain Spark-native. + """ + + pyspark: PySparkResource + + def handle_output(self, context: OutputContext, output: sql.DataFrame): + path = self._get_filepath(context) + pdf = output.toPandas() + adls_client.upload_pandas_dataframe_as_file( + context=context, + data=pdf, + filepath=str(path), + ) + context.log.info(f"Uploaded {path.name} to {path.parent} in ADLS.") + + def load_input(self, context: InputContext) -> sql.DataFrame: + spark: SparkSession = self.pyspark.spark_session + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + try: + match path.suffix: + case ".parquet": + data = spark.read.parquet(adls_path) + case ".csv": + data = adls_client.download_csv_as_spark_dataframe(str(path), spark) + case ".xls" | ".xlsx": + pdf = pd.read_excel(adls_path) + data = spark.createDataFrame(pdf.astype(str)) + case _: + raise OSError(f"Unsupported format for Spark read: {path.suffix}") + except ResourceNotFoundError as e: + raise e + + context.log.info(f"Downloaded {path.name} from {path.parent} in ADLS.") + return data diff --git a/dagster/src/resources/io_managers/base.py b/dagster/src/resources/io_managers/base.py index f402bc05b..d36b82f8f 100644 --- a/dagster/src/resources/io_managers/base.py +++ b/dagster/src/resources/io_managers/base.py @@ -37,6 +37,8 @@ def _get_filepath(context: InputContext | OutputContext) -> Path: return config.destination_filepath_object config = FileConfig(**context.step_context.op_config) + if config.output_filepaths and context.name in config.output_filepaths: + return Path(config.output_filepaths[context.name]) return config.destination_filepath_object @staticmethod diff --git a/dagster/src/sensors/school_geolocation.py b/dagster/src/sensors/school_geolocation.py index 7d4992921..51dfe784e 100644 --- a/dagster/src/sensors/school_geolocation.py +++ b/dagster/src/sensors/school_geolocation.py @@ -80,19 +80,11 @@ def school_master_geolocation__raw_file_uploads_sensor( ), "geolocation_data_quality_results_human_readable": OpDestinationMapping( source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - metastore_schema=METASTORE_SCHEMA, - tier=DataTier.DATA_QUALITY_CHECKS, - ), - "geolocation_dq_schools_passed_human_readable": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows-human-readable/{country_code}/{stem}.csv", - metastore_schema=METASTORE_SCHEMA, - tier=DataTier.DATA_QUALITY_CHECKS, - ), - "geolocation_dq_schools_failed_human_readable": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows-human-readable/{country_code}/{stem}.csv", + destination_filepath="", + output_filepaths={ + "geolocation_dq_schools_passed_human_readable": f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows-human-readable/{country_code}/{stem}.csv", + "geolocation_dq_schools_failed_human_readable": f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows-human-readable/{country_code}/{stem}.csv", + }, metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), @@ -109,19 +101,19 @@ def school_master_geolocation__raw_file_uploads_sensor( tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_dq_passed_rows": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.csv", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", + destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.parquet", metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_dq_failed_rows": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.csv", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", + destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows/{country_code}/{stem}.parquet", metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_staging": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.parquet", destination_filepath=f"{settings.SPARK_WAREHOUSE_PATH}/school_geolocation_staging.db/{country_code.lower()}", metastore_schema=METASTORE_SCHEMA, tier=DataTier.STAGING, diff --git a/dagster/src/utils/op_config.py b/dagster/src/utils/op_config.py index 7ed3dce04..dc287d4a4 100644 --- a/dagster/src/utils/op_config.py +++ b/dagster/src/utils/op_config.py @@ -51,6 +51,15 @@ class FileConfig(Config): For regular assets, simply pass in the destination path as a string. """, ) + output_filepaths: dict[str, str] = Field( + default_factory=dict, + description=""" + Per-output destination paths for multi-output assets (@multi_asset). + Keys are output names; values are ADLS-relative destination paths. + When non-empty, the IO manager uses the output name to select the path + instead of destination_filepath. + """, + ) dq_target_filepath: str = Field( description=""" The path of the file inside the ADLS container where we run data quality checks on. @@ -137,6 +146,7 @@ class OpDestinationMapping(BaseModel): metastore_schema: str tier: DataTier table_name: Optional[str] = None + output_filepaths: dict[str, str] = Field(default_factory=dict) def generate_run_ops( @@ -155,6 +165,7 @@ def generate_run_ops( file_config = FileConfig( filepath=op_mapping.source_filepath, destination_filepath=op_mapping.destination_filepath, + output_filepaths=op_mapping.output_filepaths, metastore_schema=op_mapping.metastore_schema, tier=op_mapping.tier, table_name=op_mapping.table_name, From cb122b5df3969902f229b892e538b55ae6550e8f Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Tue, 31 Mar 2026 12:21:33 +0200 Subject: [PATCH 05/26] feat: staging table restructure (#443) * feat: optimize memory footprint and tables * feat: update requests flow and master updates to match staging changes * feat: make post approval changes * fix: ensure approvals are ingested into silver and master * fix: coalesce admin correctly * fix: update connectivity_govt checks and make them safe when empty * fix: improve safety around connectivity columns * fix: remove reference to download_speed_govt * fix: cast the schema * fix: ensure the table is added * ifx: simplify logic when staging table exists * feat: add approval_request_log_id link * fix: ensure admin column are not erroneously updated --- dagster/src/assets/common/assets.py | 447 +++++++++--- .../src/assets/school_geolocation/assets.py | 125 ++-- .../data_quality_checks/column_relation.py | 15 +- dagster/src/data_quality_checks/duplicates.py | 4 +- dagster/src/data_quality_checks/precision.py | 10 +- dagster/src/data_quality_checks/standard.py | 57 +- dagster/src/data_quality_checks/utils.py | 189 +++-- dagster/src/internal/common_assets/staging.py | 682 +++++++++--------- dagster/src/sensors/school_geolocation.py | 4 +- dagster/src/spark/transform_functions.py | 129 +++- dagster/src/spark/user_defined_functions.py | 2 +- dagster/src/utils/delta.py | 3 + 12 files changed, 1013 insertions(+), 654 deletions(-) diff --git a/dagster/src/assets/common/assets.py b/dagster/src/assets/common/assets.py index d8b441ad0..214245a96 100644 --- a/dagster/src/assets/common/assets.py +++ b/dagster/src/assets/common/assets.py @@ -21,17 +21,15 @@ from src.constants import DataTier from src.internal.common_assets.master_release_notes import send_master_release_notes from src.internal.merge import ( + core_merge_logic, full_in_cluster_merge, - manual_review_dedupe_strat, - partial_cdf_in_cluster_merge, + partial_in_cluster_merge, ) from src.resources import ResourceKey from src.spark.transform_functions import ( add_missing_columns, ) -from src.utils.adls import ( - ADLSFileClient, -) +from src.utils.adls import ADLSFileClient from src.utils.datahub.emit_dataset_metadata import ( datahub_emit_metadata_with_exception_catcher, ) @@ -50,9 +48,51 @@ from src.utils.sentry import capture_op_exceptions from src.utils.spark import compute_row_hash, transform_types -from azure.core.exceptions import ResourceNotFoundError from dagster import OpExecutionContext, Output, asset +# Pending-changes status constants (mirror staging.py) +_STATUS_PENDING = "PENDING" +_STATUS_APPROVED = "APPROVED" +_STATUS_REJECTED = "REJECTED" +_STATUS_PROCESSED = "PROCESSED" + +# Pending-changes change_type constants (mirror staging.py) +_CHANGE_INSERT = "INSERT" +_CHANGE_UPDATE = "UPDATE" +_CHANGE_DELETE = "DELETE" +_CHANGE_UNCHANGED = "UNCHANGED" + +# Approval file shorthands +_APPROVE_ALL = "__all__" + + +def _parse_approval_file(approval_data: dict) -> tuple[str, list, list, str | None]: + return ( + approval_data.get("upload_id", ""), + approval_data.get("approved_change_ids", []), + approval_data.get("rejected_change_ids", []), + approval_data.get("approval_request_log_id"), + ) + + +def _resolve_change_ids( + upload_rows: sql.DataFrame, + change_ids: list, + spark_context, +) -> sql.DataFrame: + """Filter upload_rows by change_ids. + + ["__all__"] → all rows + [] → no rows + [id, ...] → filter to matching change_ids + """ + if change_ids == [_APPROVE_ALL]: + return upload_rows + if not change_ids: + return upload_rows.limit(0) + bc = spark_context.broadcast(change_ids) + return upload_rows.filter(f.col("change_id").isin(bc.value)) + @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value, deps=["silver"]) @capture_op_exceptions @@ -79,18 +119,18 @@ def manual_review_passed_rows( @capture_op_exceptions def manual_review_failed_rows( context: OpExecutionContext, - adls_file_client: ADLSFileClient, spark: PySparkResource, config: FileConfig, + adls_file_client: ADLSFileClient, ) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session - passing_rows_change_ids = adls_file_client.download_json(config.filepath) schema_name = config.metastore_schema country_code = config.country_code schema_columns = get_schema_columns(s, schema_name) column_names = [c.name for c in schema_columns] primary_key = get_primary_key(s, schema_name) + staging_tier_schema_name = construct_schema_name_for_tier( schema_name, DataTier.STAGING ) @@ -104,44 +144,58 @@ def manual_review_failed_rows( rejected_tier_schema_name, country_code ) - staging_cdf = ( - s.read.format("delta") - .option("readChangeFeed", "true") - .option("startingVersion", 0) - .table(staging_table_name) + # Download and parse the approval file written by the portal + approval_data = adls_file_client.download_json(config.filepath) + upload_id, _, rejected_change_ids, _ = _parse_approval_file(approval_data) + context.log.info( + f"[manual_review_failed_rows] approval file: upload_id={upload_id!r}, " + f"rejected_change_ids={rejected_change_ids!r}" ) - staging_cdf = staging_cdf.withColumn( - "change_id", - f.concat_ws( - "|", - f.col("school_id_giga"), - f.col("_change_type"), - f.col("_commit_version").cast(StringType()), - ), + + # Read PENDING non-UNCHANGED rows for this upload from pending_changes + s.catalog.refreshTable(staging_table_name) + staging_df = DeltaTable.forName(s, staging_table_name).toDF() + upload_rows = staging_df.filter( + (f.col("upload_id") == upload_id) + & (f.col("change_type") != _CHANGE_UNCHANGED) + & (f.col("status") == _STATUS_PENDING) ) - if len(passing_rows_change_ids) == 0: - df_failed = staging_cdf - elif len(passing_rows_change_ids) == 1 and passing_rows_change_ids[0] == "__all__": - df_failed = staging_cdf.limit(0) + rejected_rows = _resolve_change_ids( + upload_rows, rejected_change_ids, s.sparkContext + ).select(*column_names) + + # Mark rejected rows as REJECTED in pending_changes + if rejected_change_ids == [_APPROVE_ALL]: + reject_condition = ( + (f.col("upload_id") == upload_id) + & (f.col("change_type") != _CHANGE_UNCHANGED) + & (f.col("status") == _STATUS_PENDING) + ) + elif not rejected_change_ids: + reject_condition = f.lit(False) else: - bc_passing_rows_change_ids = s.sparkContext.broadcast(passing_rows_change_ids) - df_failed = staging_cdf.filter( - ~f.col("change_id").isin(bc_passing_rows_change_ids.value) + reject_condition = ( + (f.col("upload_id") == upload_id) + & f.col("change_id").isin(rejected_change_ids) + & (f.col("status") == _STATUS_PENDING) ) + DeltaTable.forName(s, staging_table_name).update( + condition=reject_condition, + set={"status": f.lit(_STATUS_REJECTED)}, + ) + context.log.info("Marked rejected staging rows as REJECTED.") - # Check if a rejects table already exists - # If yes, do an in-cluster merge - # Else, the current df_failed is the initial rejects table + create_schema(s, rejected_tier_schema_name) if check_table_exists(s, schema_name, country_code, DataTier.MANUAL_REJECTED): s.catalog.refreshTable(rejected_table_name) - rejected = DeltaTable.forName(s, rejected_table_name).toDF() - rejected = add_missing_columns(rejected, schema_columns) - new_rejected = partial_cdf_in_cluster_merge( - rejected, df_failed, column_names, primary_key, context + current_rejected = DeltaTable.forName(s, rejected_table_name).toDF() + current_rejected = add_missing_columns(current_rejected, schema_columns) + new_rejected = partial_in_cluster_merge( + current_rejected, rejected_rows, primary_key, column_names ) else: - new_rejected = df_failed.select(*column_names) + new_rejected = rejected_rows schema_reference = get_schema_columns_datahub(s, schema_name) datahub_emit_metadata_with_exception_catcher( @@ -161,23 +215,79 @@ def manual_review_failed_rows( ) +def _log_staging_diagnostics( + context: OpExecutionContext, + staging_df: sql.DataFrame, + upload_id: str, +) -> None: + total_staging = staging_df.count() + context.log.info(f"[silver] total rows in staging table: {total_staging}") + + if total_staging > 0: + sample_upload_ids = [ + r.upload_id + for r in staging_df.select("upload_id").distinct().limit(5).collect() + ] + sample_statuses = [ + r.status for r in staging_df.select("status").distinct().limit(5).collect() + ] + sample_change_types = [ + r.change_type + for r in staging_df.select("change_type").distinct().limit(5).collect() + ] + context.log.info(f"[silver] distinct upload_ids (up to 5): {sample_upload_ids}") + context.log.info(f"[silver] distinct statuses: {sample_statuses}") + context.log.info(f"[silver] distinct change_types: {sample_change_types}") + + matched_upload = staging_df.filter(f.col("upload_id") == upload_id).count() + matched_non_unchanged = staging_df.filter( + (f.col("upload_id") == upload_id) & (f.col("change_type") != _CHANGE_UNCHANGED) + ).count() + matched_all = staging_df.filter( + (f.col("upload_id") == upload_id) + & (f.col("change_type") != _CHANGE_UNCHANGED) + & (f.col("status") == _STATUS_PENDING) + ).count() + context.log.info( + f"[silver] filter breakdown: " + f"upload_id match={matched_upload}, " + f"+non-unchanged={matched_non_unchanged}, " + f"+status=PENDING={matched_all}" + ) + + +def _build_processed_condition(approved_change_ids: list, upload_id: str): + if approved_change_ids == [_APPROVE_ALL]: + return ( + (f.col("upload_id") == upload_id) + & (f.col("change_type") != _CHANGE_UNCHANGED) + & (f.col("status") == _STATUS_PENDING) + ) + if not approved_change_ids: + return f.lit(False) + return ( + (f.col("upload_id") == upload_id) + & f.col("change_id").isin(approved_change_ids) + & (f.col("status") == _STATUS_PENDING) + ) + + @asset(io_manager_key=ResourceKey.ADLS_DELTA_IO_MANAGER.value) @capture_op_exceptions def silver( context: OpExecutionContext, - adls_file_client: ADLSFileClient, spark: PySparkResource, config: FileConfig, + adls_file_client: ADLSFileClient, ) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session - passing_rows_change_ids = adls_file_client.download_json(config.filepath) - context.log.info(f"{len(passing_rows_change_ids)=}") schema_name = config.metastore_schema country_code = config.country_code schema_columns = get_schema_columns(s, schema_name) column_names = [c.name for c in schema_columns] primary_key = get_primary_key(s, schema_name) + staging_tier_schema_name = construct_schema_name_for_tier( schema_name, DataTier.STAGING ) @@ -189,80 +299,170 @@ def silver( ) silver_table_name = construct_full_table_name(silver_tier_schema_name, country_code) - staging_cdf = ( - s.read.format("delta") - .option("readChangeFeed", "true") - .option("startingVersion", 0) - .table(staging_table_name) - ) - staging_cdf = staging_cdf.withColumn( - "change_id", - f.concat_ws( - "|", - f.col("school_id_giga"), - f.col("_change_type"), - f.col("_commit_version").cast(StringType()), - ), - ) - - if len(passing_rows_change_ids) == 0: - df_passed = staging_cdf.limit(0) - elif len(passing_rows_change_ids) == 1 and passing_rows_change_ids[0] == "__all__": - df_passed = staging_cdf - else: - bc_passing_rows_change_ids = s.sparkContext.broadcast(passing_rows_change_ids) - df_passed = staging_cdf.filter( - f.col("change_id").isin(bc_passing_rows_change_ids.value) + # Download and parse the approval file written by the portal + approval_data = adls_file_client.download_json(config.filepath) + upload_id, approved_change_ids, _, approval_request_log_id = _parse_approval_file( + approval_data + ) + context.log.info( + f"[silver] approval file: upload_id={upload_id!r}, " + f"approved_change_ids={approved_change_ids!r}, " + f"approval_request_log_id={approval_request_log_id!r}" + ) + + # Read PENDING non-UNCHANGED rows for this upload from pending_changes + s.catalog.refreshTable(staging_table_name) + staging_df = DeltaTable.forName(s, staging_table_name).toDF() + + _log_staging_diagnostics(context, staging_df, upload_id) + + upload_rows = staging_df.filter( + (f.col("upload_id") == upload_id) + & (f.col("change_type") != _CHANGE_UNCHANGED) + & (f.col("status") == _STATUS_PENDING) + ) + + approved = _resolve_change_ids(upload_rows, approved_change_ids, s.sparkContext) + + # Break the lazy lineage back to the staging Delta table immediately. + # Delta Lake invalidates Spark's DataFrame cache when the table is written to + # (e.g. our PENDING→PROCESSED update), so .cache() is insufficient. + # localCheckpoint() materialises the rows into executor memory and severs the + # dependency on staging, ensuring the IO manager's isEmpty() call cannot + # accidentally re-read the now-PROCESSED rows and find 0 results. + approved = approved.localCheckpoint() + + if approved.isEmpty(): + context.log.info( + "No approved rows for this upload. Returning current silver unchanged." + ) + if check_table_exists(s, schema_name, country_code, DataTier.SILVER): + s.catalog.refreshTable(silver_table_name) + current_silver = DeltaTable.forName(s, silver_table_name).toDF() + return Output( + current_silver, + metadata={ + **get_output_metadata(config), + "preview": get_table_preview(current_silver), + "row_count": current_silver.count(), + }, + ) + empty = s.createDataFrame([], StructType(schema_columns)) + return Output( + empty, + metadata={ + **get_output_metadata(config), + "preview": get_table_preview(empty), + "row_count": 0, + }, ) - context.log.info(f"{df_passed.count()=}") - df_passed = manual_review_dedupe_strat(df_passed) + inserts = approved.filter(f.col("change_type") == _CHANGE_INSERT).select( + *column_names + ) + updates = approved.filter(f.col("change_type") == _CHANGE_UPDATE).select( + *column_names + ) + deletes = approved.filter(f.col("change_type") == _CHANGE_DELETE).select( + *column_names + ) - # Extract delete IDs before merge for post-merge verification - deletes_df = df_passed.filter(f.col("_change_type") == "delete") - delete_ids = [row[primary_key] for row in deletes_df.select(primary_key).collect()] - has_deletes = len(delete_ids) > 0 - context.log.info(f"Approved deletes count: {len(delete_ids)}") + insert_count = inserts.count() + update_count = updates.count() + delete_count = deletes.count() + context.log.info( + f"Approved rows: {insert_count} inserts, {update_count} updates, " + f"{delete_count} deletes" + ) + + delete_ids = [row[primary_key] for row in deletes.select(primary_key).collect()] if check_table_exists(s, schema_name, country_code, DataTier.SILVER): s.catalog.refreshTable(silver_table_name) current_silver = DeltaTable.forName(s, silver_table_name).toDF() current_silver = add_missing_columns(current_silver, schema_columns) - new_silver = partial_cdf_in_cluster_merge( - current_silver, df_passed, column_names, primary_key, context + new_silver = core_merge_logic( + current_silver, + inserts, + updates, + deletes, + primary_key, + column_names, + update_join_type="left", ) - # Post-merge verification: Verify approved deletes were actually removed - if has_deletes: - remaining_deletes = new_silver.filter( - new_silver[primary_key].isin(delete_ids) - ) - remaining_count = remaining_deletes.count() - + # Verify deletes were applied + if delete_ids: + remaining = new_silver.filter(new_silver[primary_key].isin(delete_ids)) + remaining_count = remaining.count() if remaining_count > 0: context.log.error( - f"Delete verification failed: {remaining_count} out of {len(delete_ids)} " - f"approved delete rows still exist in silver table. " - f"Sample IDs: {[row[primary_key] for row in remaining_deletes.limit(5).collect()]}" + f"Delete verification failed: {remaining_count} of {len(delete_ids)} " + f"approved deletes still in silver." ) from dagster import DagsterExecutionInterruptError raise DagsterExecutionInterruptError( - f"Deletes not applied: {remaining_count} rows still in silver table" - ) - else: - context.log.info( - f"Delete verification passed: All {len(delete_ids)} approved deletes " - f"successfully removed from silver table." + f"Deletes not applied: {remaining_count} rows still in silver" ) + context.log.info( + f"Delete verification passed: all {len(delete_ids)} deletes removed." + ) else: - new_silver = df_passed + new_silver = inserts if new_silver.isEmpty(): - context.log.info( - "Silver table is empty after merge, returning empty DataFrame." + context.log.info("Silver is empty after merge.") + new_silver = s.createDataFrame([], StructType(schema_columns)) + + new_silver = compute_row_hash(new_silver) + + # Mark approved rows as PROCESSED in the staging table + processed_condition = _build_processed_condition(approved_change_ids, upload_id) + DeltaTable.forName(s, staging_table_name).update( + condition=processed_condition, + set={ + "status": f.lit(_STATUS_PROCESSED), + "processed_at": f.current_timestamp(), + "approval_request_log_id": f.lit(approval_request_log_id), + }, + ) + context.log.info("Marked approved staging rows as PROCESSED.") + + # Check if any PENDING non-UNCHANGED rows remain across all uploads + remaining_pending = ( + DeltaTable.forName(s, staging_table_name) + .toDF() + .filter( + (f.col("status") == _STATUS_PENDING) + & (f.col("change_type") != _CHANGE_UNCHANGED) ) - new_silver = s.createDataFrame([], schema=StructType(schema_columns)) + .count() + ) + + # Reset ApprovalRequest: always clear is_merge_processing; disable only when + # no other upload is waiting for approval + formatted_dataset = f"School {config.dataset_type.capitalize()}" + update_values = {ApprovalRequest.is_merge_processing: False} + if remaining_pending == 0: + update_values[ApprovalRequest.enabled] = False + + with get_db_context() as db: + try: + with db.begin(): + db.execute( + update(ApprovalRequest) + .where( + (ApprovalRequest.country == country_code) + & (ApprovalRequest.dataset == formatted_dataset) + ) + .values(update_values) + ) + except Exception as e: + context.log.error( + f"Failed to reset ApprovalRequest for {country_code} - " + f"{formatted_dataset}: {e}" + ) schema_reference = get_schema_columns_datahub(s, schema_name) datahub_emit_metadata_with_exception_catcher( @@ -278,6 +478,9 @@ def silver( **get_output_metadata(config), "preview": get_table_preview(new_silver), "row_count": new_silver.count(), + "insert_count": insert_count, + "update_count": update_count, + "delete_count": delete_count, }, ) @@ -288,8 +491,23 @@ def reset_staging_table( context: OpExecutionContext, spark: PySparkResource, config: FileConfig, - adls_file_client: ADLSFileClient, ) -> None: + """ + No-op for the geolocation pipeline: the pending_changes staging table is a + persistent history log and does not need to be reset between merge cycles. + + For other pipelines (coverage) this asset still performs the silver-clone reset. + """ + if config.dataset_type == "geolocation": + context.log.info( + "Geolocation uses the pending_changes staging design; skipping reset." + ) + return + + from src.utils.adls import ADLSFileClient + + from azure.core.exceptions import ResourceNotFoundError + s: SparkSession = spark.spark_session country_code = config.country_code staging_tier_schema_name = construct_schema_name_for_tier( @@ -304,7 +522,6 @@ def reset_staging_table( ) silver_table_name = construct_full_table_name(silver_tier_schema_name, country_code) - # State check: Verify enabled=False and is_merge_processing=False before reset formatted_dataset = f"School {config.dataset_type.capitalize()}" with get_db_context() as db: current_request = db.scalar( @@ -316,23 +533,25 @@ def reset_staging_table( if current_request is None: context.log.warning( - f"No ApprovalRequest found for {country_code} - {formatted_dataset}. Proceeding with reset." + f"No ApprovalRequest found for {country_code} - {formatted_dataset}. " + "Proceeding with reset." ) elif current_request.enabled: context.log.warning( - f"Reset blocked: ApprovalRequest is enabled (enabled=True) for {country_code} - {formatted_dataset}. " - f"Expected enabled=False before reset." + f"Reset blocked: ApprovalRequest is enabled for {country_code} - " + f"{formatted_dataset}." ) return elif current_request.is_merge_processing: context.log.warning( - f"Reset blocked: Merge is still processing (is_merge_processing=True) for {country_code} - {formatted_dataset}. " - f"Expected is_merge_processing=False before reset." + f"Reset blocked: merge still processing for {country_code} - " + f"{formatted_dataset}." ) return + # Lazily import ADLSFileClient to avoid initialising it for the geolocation no-op path + adls_file_client = ADLSFileClient() s.sql(f"DROP TABLE IF EXISTS {staging_table_name}") - try: adls_file_client.delete(staging_table_path, is_directory=True) except ResourceNotFoundError as e: @@ -340,7 +559,7 @@ def reset_staging_table( schema_columns = get_schema_columns(s, config.metastore_schema) s.catalog.refreshTable(silver_table_name) - silver = DeltaTable.forName(s, silver_table_name).alias("silver").toDF() + silver_df = DeltaTable.forName(s, silver_table_name).alias("silver").toDF() create_schema(s, staging_tier_schema_name) create_delta_table( s, @@ -350,13 +569,11 @@ def reset_staging_table( context, if_not_exists=True, ) - silver.write.format("delta").mode("append").saveAsTable(staging_table_name) + silver_df.write.format("delta").mode("append").saveAsTable(staging_table_name) - formatted_dataset = f"School {config.dataset_type.capitalize()}" with get_db_context() as db: try: with db.begin(): - # Ensure enabled=False and is_merge_processing=False after reset result = db.execute( update(ApprovalRequest) .where( @@ -372,23 +589,24 @@ def reset_staging_table( ) if result.rowcount == 0: context.log.warning( - f"No ApprovalRequest found for {country_code} - {formatted_dataset}." + f"No ApprovalRequest found for {country_code} - " + f"{formatted_dataset}." ) except Exception as e: context.log.error( - f"Failed to update ApprovalRequest for {country_code} - {formatted_dataset}: {e}" + f"Failed to update ApprovalRequest for {country_code} - " + f"{formatted_dataset}: {e}" ) raise -def _handle_null_columns(schema_columns, primary_key, silver_columns): +def _handle_null_columns(schema_columns, primary_key): """Handle null columns by providing default values based on data type. If the column value is NULL, add a placeholder value if the following conditions are met: - The column is not nullable - The column is not the primary key - - The column is not in the silver table Default values by type: - String: "Unknown" @@ -398,11 +616,7 @@ def _handle_null_columns(schema_columns, primary_key, silver_columns): """ column_actions = {} for col in schema_columns: - if ( - not col.nullable - and col.name != primary_key - and col.name not in [c.name for c in silver_columns] - ): + if not col.nullable and col.name != primary_key: if col.dataType == StringType(): column_actions[col.name] = f.coalesce(f.col(col.name), f.lit("Unknown")) elif isinstance( @@ -435,8 +649,6 @@ def master( s.catalog.refreshTable(silver_table_name) silver = DeltaTable.forName(s, silver_table_name).alias("silver").toDF() - silver_columns = get_schema_columns(s, f"school_{config.dataset_type}") - schema_columns = get_schema_columns(s, schema_name) column_names = [c.name for c in schema_columns] primary_key = get_primary_key(s, schema_name) @@ -457,7 +669,7 @@ def master( else: new_master = silver - column_actions = _handle_null_columns(schema_columns, primary_key, silver_columns) + column_actions = _handle_null_columns(schema_columns, primary_key) new_master = new_master.withColumns(column_actions) new_master = compute_row_hash(new_master) @@ -499,7 +711,6 @@ def reference( schema_columns = get_schema_columns(s, schema_name) column_names = [c.name for c in schema_columns] primary_key = get_primary_key(s, schema_name) - silver_columns = silver.schema.fields silver = add_missing_columns(silver, schema_columns) silver = transform_types(silver, schema_name, context) @@ -516,7 +727,7 @@ def reference( else: new_reference = silver - column_actions = _handle_null_columns(schema_columns, primary_key, silver_columns) + column_actions = _handle_null_columns(schema_columns, primary_key) new_reference = new_reference.withColumns(column_actions) new_reference = compute_row_hash(new_reference) diff --git a/dagster/src/assets/school_geolocation/assets.py b/dagster/src/assets/school_geolocation/assets.py index 055d93a7a..fe1532659 100644 --- a/dagster/src/assets/school_geolocation/assets.py +++ b/dagster/src/assets/school_geolocation/assets.py @@ -26,15 +26,10 @@ from src.internal.common_assets.staging import StagingChangeTypeEnum, StagingStep from src.resources import ResourceKey from src.schemas.file_upload import FileUploadConfig -from src.settings import DeploymentEnvironment, settings from src.spark.config_expectations import config as config_expectations from src.spark.transform_functions import ( add_missing_columns, - column_mapping_rename, - create_bronze_layer_columns, - get_country_rt_schools, - merge_connectivity_to_master as merge_connectivity_to_df, - standardize_connectivity_type, + create_bronze_layer_columns_updated, ) from src.utils.adls import ( ADLSFileClient, @@ -174,7 +169,6 @@ def geolocation_bronze( ) -> Output[pd.DataFrame]: s: SparkSession = spark.spark_session country_code = config.country_code - schema_name = config.metastore_schema mode = config.metadata["mode"] with get_db_context() as db: @@ -206,58 +200,22 @@ def geolocation_bronze( ).map(str) pdf.rename(lambda name: name.strip(), axis="columns", inplace=True) + column_mapping_filtered = { + k.strip(): v + for k, v in column_to_schema_mapping.items() + if (k is not None) and (v is not None) + } + pdf = pdf[column_to_schema_mapping.keys()] + pdf.rename(column_mapping_filtered, axis="columns", inplace=True) df = s.createDataFrame(pdf) - df, column_mapping = column_mapping_rename(df, column_to_schema_mapping) - context.log.info("COLUMN MAPPING") - context.log.info(column_mapping) - context.log.info("COLUMN MAPPING DATAFRAME") - context.log.info(df) uploaded_columns = df.columns - columns = get_schema_columns(s, schema_name) - context.log.info("schema columns") - context.log.info(columns) - - schema = StructType(columns) - - # Create empty base schema DataFrame - geolocation_base = s.createDataFrame(s.sparkContext.emptyRDD(), schema=schema) + df = df.withColumn("school_id_govt", f.col("school_id_govt").cast(StringType())) - casted_geolocation_base = geolocation_base.withColumn( - "school_id_govt", f.col("school_id_govt").cast(StringType()) + df = create_bronze_layer_columns_updated( + df, mode, uploaded_columns, country_code, s ) - context.log.info("Casted Geolocation") - context.log.info(casted_geolocation_base) - - casted_bronze = df.withColumn( - "school_id_govt", f.col("school_id_govt").cast(StringType()) - ) - - context.log.info("Casted Bronze") - context.log.info(casted_bronze) - - df = create_bronze_layer_columns( - casted_bronze, casted_geolocation_base, country_code, mode, uploaded_columns - ) - context.log.info("DF from create_bronze_layer_columns") - context.log.info(df) - - config.metadata.update({"column_mapping": column_mapping}) - context.log.info("After config metadata update") - - if settings.DEPLOY_ENV != DeploymentEnvironment.LOCAL: - # RT Columns - connectivity = get_country_rt_schools(s, country_code) - df = merge_connectivity_to_df(df, connectivity, uploaded_columns, mode) - else: - # On local, we can't retrieve the connectivity data - df = df.withColumn("connectivity", f.lit("Unknown")) - df = df.withColumn("connectivity_RT", f.lit("Unknown")) - - # standardize the connectivity type - df = standardize_connectivity_type(df, mode, uploaded_columns) - datahub_emit_metadata_with_exception_catcher( context=context, config=config, @@ -280,7 +238,7 @@ def geolocation_bronze( metadata={ **get_output_metadata(config), "row_count": len(df_pandas), - "column_mapping": column_mapping, + "column_mapping": column_mapping_filtered, "preview": get_table_preview(df_pandas), }, ) @@ -337,6 +295,23 @@ def geolocation_data_quality_results( dq_results = dq_results.withColumnRenamed("dq_signature", "signature") + # Collapse all individual dq_ check columns into a single map column. + # This reduces ~120+ columns to one, cutting the DataFrame width by ~60 %. + # dq_has_critical_error and failure_reason remain as top-level columns because + # they are used for row-level filtering throughout the pipeline. + # In Trino the map is queryable as: dq_results['is_null_optional-latitude'] + dq_flag_cols = [ + c + for c in dq_results.columns + if c.startswith("dq_") and c != "dq_has_critical_error" + ] + map_args = [] + for col_name in dq_flag_cols: + map_args.extend([f.lit(col_name[len("dq_") :]), f.col(col_name).cast("int")]) + dq_results = dq_results.withColumn("dq_results", f.create_map(*map_args)).drop( + *dq_flag_cols + ) + dq_results_schema_name = f"{schema_name}_dq_results" table_name = f"{id}_{country_code}_{current_timestamp}" @@ -406,21 +381,18 @@ async def geolocation_data_quality_results_human_readable( df, human_readable_mappings = dq_geolocation_extract_relevant_columns( geolocation_data_quality_results, uploaded_columns, mode ) - # replace the dq_column column binary values with Yes/No depending on if they passed or failed the check - dq_column_names = [ - col - for col in df.columns - if (col.startswith("dq_") and col != "dq_has_critical_error") - ] - for column in dq_column_names: + # human_readable_mappings keys are map keys (without "dq_" prefix). + # Expand each relevant map entry into its own column with Yes/No values, + # then drop the map so the output is flat like before. + for map_key, human_name in human_readable_mappings.items(): df = df.withColumn( - column, - f.when(f.col(column) == 1, "No").otherwise( - f.when(f.col(column) == 0, "Yes") + human_name, + f.when(f.element_at(f.col("dq_results"), map_key) == 1, "No").otherwise( + f.when(f.element_at(f.col("dq_results"), map_key) == 0, "Yes") ), ) + df = df.drop("dq_results") - df = df.withColumnsRenamed(human_readable_mappings) context.log.info("Convert the dataframe to a pandas object to save it locally") df_pandas = df.toPandas() @@ -688,15 +660,30 @@ def geolocation_staging( spark.spark_session, StagingChangeTypeEnum.UPDATE, ) - staging = staging_step(geolocation_dq_passed_rows) - row_count = 0 if staging is None else staging.count() + pending = staging_step(geolocation_dq_passed_rows) + + if pending is None: + return Output( + None, + metadata={ + **get_output_metadata(config), + "insert_count": MetadataValue.int(0), + "update_count": MetadataValue.int(0), + "unchanged_count": MetadataValue.int(0), + "delete_count": MetadataValue.int(0), + }, + ) + counts = pending.groupBy("change_type").count().collect() + count_map = {row["change_type"]: row["count"] for row in counts} return Output( None, metadata={ **get_output_metadata(config), - "row_count": MetadataValue.int(row_count), - "preview": get_table_preview(staging), + "insert_count": MetadataValue.int(count_map.get("INSERT", 0)), + "update_count": MetadataValue.int(count_map.get("UPDATE", 0)), + "unchanged_count": MetadataValue.int(count_map.get("UNCHANGED", 0)), + "delete_count": MetadataValue.int(count_map.get("DELETE", 0)), }, ) diff --git a/dagster/src/data_quality_checks/column_relation.py b/dagster/src/data_quality_checks/column_relation.py index 5f2f80758..07fc2b341 100644 --- a/dagster/src/data_quality_checks/column_relation.py +++ b/dagster/src/data_quality_checks/column_relation.py @@ -5,6 +5,13 @@ from src.utils.logger import get_context_with_fallback_logger +def _col_if_exists(df: sql.DataFrame, col_name: str): + """Return f.col(col_name) if it exists in df, else f.lit(None).""" + if col_name in df.columns: + return f.col(col_name) + return f.lit(None) + + def column_relation_checks( df: sql.DataFrame, dataset_type: str, @@ -127,8 +134,8 @@ def column_relation_checks( transforms[ "dq_column_relation_checks-connectivity_govt_download_speed_contracted" ] = f.when( - (f.col("download_speed_contracted").isNotNull()) - & (f.col("connectivity_govt").isNull()), + (_col_if_exists(df, "download_speed_contracted").isNotNull()) + & (_col_if_exists(df, "connectivity_govt").isNull()), 1, ).otherwise(0) @@ -136,8 +143,8 @@ def column_relation_checks( transforms[ "dq_column_relation_checks-electricity_availability_electricity_type" ] = f.when( - (f.lower(f.col("electricity_availability")) == "yes") - & (f.col("electricity_type").isNull()), + (f.lower(_col_if_exists(df, "electricity_availability")) == "yes") + & (_col_if_exists(df, "electricity_type").isNull()), 1, ).otherwise(0) diff --git a/dagster/src/data_quality_checks/duplicates.py b/dagster/src/data_quality_checks/duplicates.py index c65f4f966..b0c76687e 100644 --- a/dagster/src/data_quality_checks/duplicates.py +++ b/dagster/src/data_quality_checks/duplicates.py @@ -55,10 +55,12 @@ def duplicate_all_except_checks( logger = get_context_with_fallback_logger(context) logger.info("Running duplicate all except checks...") + existing_columns = [col for col in config_column_list if col in df.columns] + df = df.withColumn( "dq_duplicate_all_except_school_code", f.when( - f.count("*").over(Window.partitionBy(config_column_list)) > 1, + f.count("*").over(Window.partitionBy(existing_columns)) > 1, 1, ).otherwise(0), ) diff --git a/dagster/src/data_quality_checks/precision.py b/dagster/src/data_quality_checks/precision.py index e31b9d44c..f38e755ea 100644 --- a/dagster/src/data_quality_checks/precision.py +++ b/dagster/src/data_quality_checks/precision.py @@ -4,7 +4,7 @@ from pyspark.sql import functions as f from dagster import OpExecutionContext -from src.spark.user_defined_functions import get_decimal_places_updated +from src.spark.user_defined_functions import get_decimal_places_udf_factory from src.utils.logger import get_context_with_fallback_logger @@ -18,10 +18,8 @@ def precision_check( column_actions = {} for column in config_column_list: - # precision = config_column_list[column]["min"] - # get_decimal_places = get_decimal_places_udf_factory(precision) - column_actions[f"dq_precision-{column}"] = get_decimal_places_updated( - f.col(column) - ) + precision = config_column_list[column]["min"] + get_decimal_places = get_decimal_places_udf_factory(precision) + column_actions[f"dq_precision-{column}"] = get_decimal_places(f.col(column)) return df.withColumns(column_actions) diff --git a/dagster/src/data_quality_checks/standard.py b/dagster/src/data_quality_checks/standard.py index 8b8ae5cb6..fa1424e8a 100644 --- a/dagster/src/data_quality_checks/standard.py +++ b/dagster/src/data_quality_checks/standard.py @@ -130,31 +130,46 @@ def format_validation_checks(df, context: OpExecutionContext = None): column_actions = {} for column, dtype in config.DATA_TYPES: if column in df.columns and dtype == "STRING": - column_actions[f"dq_is_not_alphanumeric-{column}"] = f.when( - f.regexp_extract(f.col(column), ".*[A-Za-z0-9].*", 0) != "", - 0, - ).otherwise(1) + # Null values are already captured by completeness checks; skip them here + # to avoid double-flagging a column as both null and not-alphanumeric. + column_actions[f"dq_is_not_alphanumeric-{column}"] = ( + f.when(f.col(column).isNull(), f.lit(0)) + .when( + f.regexp_extract(f.col(column), ".*[A-Za-z0-9].*", 0) != "", + f.lit(0), + ) + .otherwise(f.lit(1)) + ) if column in df.columns and dtype in [ "INT", "DOUBLE", "LONG", "TIMESTAMP", ]: # included timestamp based on luke's code - column_actions[f"dq_is_not_numeric-{column}"] = f.when( - f.regexp_extract(f.col(column), r"^-?\d+(\.\d+)?$", 0) != "", - 0, - ).otherwise(1) + # Null values are already captured by completeness checks; skip them here. + column_actions[f"dq_is_not_numeric-{column}"] = ( + f.when(f.col(column).isNull(), f.lit(0)) + .when( + f.regexp_extract(f.col(column), r"^-?\d+(\.\d+)?$", 0) != "", + f.lit(0), + ) + .otherwise(f.lit(1)) + ) # special format validation for school_id_giga if column in df.columns and column == "school_id_giga": - column_actions[f"dq_is_not_36_character_hash-{column}"] = f.when( - f.regexp_extract( - f.col(column), - r"\b([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})\b", - 0, + column_actions[f"dq_is_not_36_character_hash-{column}"] = ( + f.when(f.col(column).isNull(), f.lit(0)) + .when( + f.regexp_extract( + f.col(column), + r"\b([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})\b", + 0, + ) + != "", + f.lit(0), ) - != "", - 0, - ).otherwise(1) + .otherwise(f.lit(1)) + ) return df.withColumns(column_actions) @@ -166,9 +181,17 @@ def is_string_more_than_255_characters_check( logger = get_context_with_fallback_logger(context) logger.info("Running character length checks...") + # Skip columns whose expected type is non-string — their string representation + # can never exceed 255 characters, so the check always returns 0 for them. + non_string_cols = {col for col, dtype in config.DATA_TYPES if dtype != "STRING"} + column_actions = {} for column in df.columns: - if column != "school_name" and not column.startswith("dq"): + if ( + column != "school_name" + and not column.startswith("dq") + and column not in non_string_cols + ): column_actions[f"dq_is_string_more_than_255_characters-{column}"] = f.when( f.length(column) > 255, 1, diff --git a/dagster/src/data_quality_checks/utils.py b/dagster/src/data_quality_checks/utils.py index 5a08deace..0522eae17 100644 --- a/dagster/src/data_quality_checks/utils.py +++ b/dagster/src/data_quality_checks/utils.py @@ -8,6 +8,7 @@ functions as f, window as w, ) +from pyspark.sql.types import IntegerType, MapType, StringType from dagster import OpExecutionContext from src.constants import UploadMode @@ -49,21 +50,33 @@ def aggregate_report_spark_df( spark: SparkSession, df: sql.DataFrame, ): # input df == row level checks results - dq_columns = [col for col in df.columns if col.startswith("dq_")] - - df = df.select(*dq_columns) - - for column_name in df.columns: - df = df.withColumn(column_name, f.col(column_name).cast("int")) - - # Unpivot Row Level Checks - stack_expr = ", ".join([f"'{col.split('_', 1)[1]}', `{col}`" for col in dq_columns]) - unpivoted_df = df.selectExpr( - f"stack({len(dq_columns)}, {stack_expr}) as (assertion, value)", - ) - # unpivoted_df.show() + # Geolocation DQ results store individual checks in a single dq_results map column. + # Other pipelines (master, coverage, connectivity) still use flat dq_* columns. + # Detect which format is present and unify into (check_key, value) rows. + if "dq_results" in df.columns: + # Map path: explode map + merge dq_has_critical_error + all_checks_map = f.map_concat( + f.col("dq_results"), + f.create_map( + f.lit("has_critical_error"), + f.col("dq_has_critical_error").cast("int"), + ), + ) + unpivoted_df = df.select(f.explode(all_checks_map).alias("check_key", "value")) + else: + # Flat path: stack all dq_* columns into (check_key, value) rows + dq_columns = [col for col in df.columns if col.startswith("dq_")] + df_flat = df.select(*dq_columns) + for col_name in dq_columns: + df_flat = df_flat.withColumn(col_name, f.col(col_name).cast("int")) + stack_expr = ", ".join( + [f"'{col.split('_', 1)[1]}', `{col}`" for col in dq_columns] + ) + unpivoted_df = df_flat.selectExpr( + f"stack({len(dq_columns)}, {stack_expr}) as (check_key, value)", + ) - agg_df = unpivoted_df.groupBy("assertion").agg( + agg_df = unpivoted_df.groupBy("check_key").agg( f.expr("count(CASE WHEN value = 1 THEN value END) as count_failed"), f.expr("count(CASE WHEN value = 0 THEN value END) as count_passed"), f.expr("count(value) as count_overall"), @@ -83,12 +96,9 @@ def aggregate_report_spark_df( ) # Processing for Human Readable Report - agg_df = agg_df.withColumn("dq_column", f.col("assertion")) - agg_df = agg_df.withColumn("column", (f.split(f.col("assertion"), "-").getItem(1))) - agg_df = agg_df.withColumn( - "assertion", - (f.split(f.col("assertion"), "-").getItem(0)), - ) + agg_df = agg_df.withColumn("dq_column", f.col("check_key")) + agg_df = agg_df.withColumn("column", f.split(f.col("check_key"), "-").getItem(1)) + agg_df = agg_df.withColumn("assertion", f.split(f.col("check_key"), "-").getItem(0)) # get data descriptions from nocodb dq_column_name_table_id = get_nocodb_table_id_from_name( @@ -195,12 +205,18 @@ def aggregate_report_json( def aggregate_report_statistics(df: sql.DataFrame, upload_details: dict): - # add necessary columns + df = _normalize_dq_results_map(df) + + # dq_results is now a map; access specific keys via element_at(). + def _check(key): + return f.element_at(f.col("dq_results"), key) + + # Computed composite columns derived from map values df = df.withColumn( "dq_missing_location", f.when( - (f.col("dq_is_null_optional-latitude") == 1) - | (f.col("dq_is_null_optional-latitude") == 1), + (_check("is_null_optional-latitude") == 1) + | (_check("is_null_optional-longitude") == 1), 1, ).otherwise(0), ) @@ -208,42 +224,49 @@ def aggregate_report_statistics(df: sql.DataFrame, upload_details: dict): df = df.withColumn( "dq_is_null_connectivity_type_when_connectivity_govt", f.when( - (f.col("dq_is_null_optional-connectivity_type_govt") == 1) - & (f.col("dq_is_null_optional-connectivity_govt") == 0), + (_check("is_null_optional-connectivity_type_govt") == 1) + & (_check("is_null_optional-connectivity_govt") == 0), 1, ).otherwise(0), ) count_schools_raw_file = df.count() - dq_report_columns = [ - "dq_duplicate-school_id_govt", - "dq_duplicate_all_except_school_code", - "dq_duplicate_name_level_within_110m_radius", - "dq_duplicate_set-location_id", - "dq_duplicate_set-education_level_location_id", - "dq_duplicate_set-school_id_govt_school_name_education_level_location_id", - "dq_duplicate_set-school_name_education_level_location_id", - "dq_duplicate_similar_name_same_level_within_110m_radius", - "dq_has_critical_error", - "dq_is_not_within_country", - "dq_is_not_alphanumeric-school_name", - "dq_is_null_connectivity_type_when_connectivity_govt", - "dq_is_null_mandatory-school_id_govt", - "dq_is_null_optional-computer_availability", - "dq_is_null_optional-connectivity_govt", - "dq_is_null_optional-education_level_govt", - "dq_is_null_optional-school_name", - "dq_is_school_density_greater_than_5", - "dq_missing_location", - "dq_precision-latitude", - "dq_precision-longitude", + # Keys in the dq_results map needed for the statistics report + static_check_keys = [ + "duplicate-school_id_govt", + "duplicate_all_except_school_code", + "duplicate_name_level_within_110m_radius", + "duplicate_set-location_id", + "duplicate_set-education_level_location_id", + "duplicate_set-school_id_govt_school_name_education_level_location_id", + "duplicate_set-school_name_education_level_location_id", + "duplicate_similar_name_same_level_within_110m_radius", + "is_not_within_country", + "is_not_alphanumeric-school_name", + "is_null_mandatory-school_id_govt", + "is_null_optional-computer_availability", + "is_null_optional-connectivity_govt", + "is_null_optional-education_level_govt", + "is_null_optional-school_name", + "is_school_density_greater_than_5", + "precision-latitude", + "precision-longitude", ] - df_report = df.select(*dq_report_columns) + # Flatten only the needed map keys into individual columns for the stack/agg step + df_report = df.select( + *[_check(key).cast("int").alias(f"dq_{key}") for key in static_check_keys], + f.col("dq_has_critical_error").cast("int"), + f.col("dq_missing_location").cast("int"), + f.col("dq_is_null_connectivity_type_when_connectivity_govt").cast("int"), + ) - for column_name in df_report.columns: - df_report = df_report.withColumn(column_name, f.col(column_name).cast("int")) + dq_report_columns = [f"dq_{key}" for key in static_check_keys] + [ + "dq_has_critical_error", + "dq_missing_location", + "dq_is_null_connectivity_type_when_connectivity_govt", + ] dq_duplicate_columns = [ col for col in dq_report_columns if col.startswith("dq_duplicate") @@ -493,9 +516,33 @@ def dq_split_failed_rows(df: sql.DataFrame, dataset_type: str): return df +def _normalize_dq_results_map(df: sql.DataFrame) -> sql.DataFrame: + """Ensure dq_results is MapType(string, int). + + ADLSPandasIOManager round-trips through Pandas/PyArrow, which writes the map as a + Parquet STRUCT (fixed keys, more compact). Spark reads it back as StructType, which + breaks map_filter / element_at. Converting via to_json → from_json restores the + proper MapType regardless of the original column type. + """ + if "dq_results" not in df.columns: + return df + from pyspark.sql.types import StructType as _StructType + + if isinstance(df.schema["dq_results"].dataType, _StructType): + df = df.withColumn( + "dq_results", + f.from_json( + f.to_json(f.col("dq_results")), MapType(StringType(), IntegerType()) + ), + ) + return df + + def dq_geolocation_extract_relevant_columns( df: sql.DataFrame, uploaded_columns: list[str], mode: str ): + df = _normalize_dq_results_map(df) + dq_column_name_table_id = get_nocodb_table_id_from_name( table_name="SchoolGeolocationMasterDQChecks" ) @@ -543,16 +590,42 @@ def dq_geolocation_extract_relevant_columns( dq_columns_list = dq_table_all.sort_values("Related Check ID")[ "DQ Table Column Name" ].tolist() - dq_columns_list = ["dq_has_critical_error", "failure_reason", *dq_columns_list] - admin_columns = ["admin1", "admin2", "admin3", "admin4"] - columns_to_keep = [*uploaded_columns, *admin_columns, *dq_columns_list] - columns_to_keep = [col for col in columns_to_keep if col in df.columns] + + # The dq_results map keys are the check names without the "dq_" prefix. + relevant_map_keys = [col.replace("dq_", "", 1) for col in dq_columns_list] + + columns_to_keep = [ + col + for col in [ + *uploaded_columns, + "dq_has_critical_error", + "failure_reason", + "dq_results", + ] + if col in df.columns + ] df = df.select(*columns_to_keep) - # column name mappings - human_readable_mappings = dq_table_all.set_index("DQ Table Column Name")[ - "Human Readable Name" - ].to_dict() + # Narrow the dq_results map to only checks that are relevant for the uploaded + # columns and mode. This avoids exposing checks that do not apply. + if relevant_map_keys: + relevant_keys_set = set(relevant_map_keys) + df = df.withColumn( + "dq_results", + f.map_filter( + f.col("dq_results"), lambda k, _v: k.isin(list(relevant_keys_set)) + ), + ) + + # human_readable_mappings: map key (without "dq_" prefix) -> human-readable label + human_readable_mappings = { + col.replace("dq_", "", 1): name + for col, name in dq_table_all.set_index("DQ Table Column Name")[ + "Human Readable Name" + ] + .to_dict() + .items() + } return df, human_readable_mappings diff --git a/dagster/src/internal/common_assets/staging.py b/dagster/src/internal/common_assets/staging.py index 8553d53c4..130b76df3 100644 --- a/dagster/src/internal/common_assets/staging.py +++ b/dagster/src/internal/common_assets/staging.py @@ -4,25 +4,24 @@ from models.approval_requests import ApprovalRequest from models.file_upload import FileUpload from pyspark import sql -from pyspark.sql import SparkSession -from pyspark.sql.types import StructType -from sqlalchemy import select, text, update +from pyspark.sql import ( + SparkSession, + functions as f, +) +from pyspark.sql.types import ArrayType, StringType, StructField, TimestampType +from sqlalchemy import select, update from dagster import OpExecutionContext from src.constants import DataTier -from src.internal.merge import partial_in_cluster_merge +from src.schemas.file_upload import FileUploadConfig from src.spark.transform_functions import add_missing_columns from src.utils.adls import ADLSFileClient from src.utils.datahub.emit_lineage import emit_lineage_base from src.utils.db.primary import get_db_context -from src.utils.db.trino import get_db_context as get_trino_context from src.utils.delta import ( - build_deduped_delete_query, - build_deduped_merge_query, check_table_exists, create_delta_table, create_schema, - execute_query_with_error_handler, ) from src.utils.op_config import FileConfig from src.utils.schema import ( @@ -56,6 +55,18 @@ class StagingChangeTypeEnum(enum.Enum): DELETE = "DELETE" +# change_type values written to the pending_changes table +_CHANGE_INSERT = "INSERT" +_CHANGE_UPDATE = "UPDATE" +_CHANGE_UNCHANGED = "UNCHANGED" +_CHANGE_DELETE = "DELETE" + +# status values written alongside each pending_changes row +_STATUS_PENDING = "PENDING" +_STATUS_APPROVED = "APPROVED" +_STATUS_REJECTED = "REJECTED" + + class StagingStep: def __init__( self, @@ -92,103 +103,251 @@ def __init__( ) def __call__(self, upstream_df: sql.DataFrame | list[str]) -> sql.DataFrame | None: - staging = self._process_staging_changes(upstream_df) - if staging is None: + if self.change_type == StagingChangeTypeEnum.DELETE: + pending = self._build_delete_records(upstream_df) + else: + pending = self._build_upsert_records(upstream_df) + + if pending is None or pending.isEmpty(): return None - self._update_approval_request_status(staging) + self._write_pending_records(pending) + self._update_approval_request_status() self._emit_lineage() - return staging + return pending + + def _build_upsert_records(self, df: sql.DataFrame) -> sql.DataFrame | None: + """Build pending_changes rows for an upsert (INSERT/UPDATE/UNCHANGED).""" + uploaded_columns = self._get_uploaded_columns() + df = self._prepare_df(df) + schema_col_names = [c.name for c in self.schema_columns] + upload_id = self.config.filename_components.id + + if not self.silver_table_exists: + # No silver table yet — every row is an INSERT + df = compute_row_hash(df) + df = self._select_schema_cols(df) + df = ( + df.withColumn("change_type", f.lit(_CHANGE_INSERT)) + .withColumn("upload_id", f.lit(upload_id)) + .withColumn( + "uploaded_columns", + f.array(*[f.lit(c) for c in uploaded_columns]), + ) + .withColumn("status", f.lit(_STATUS_PENDING)) + .withColumn( + "change_id", + f.concat_ws( + "|", + f.col(self.primary_key), + f.lit(upload_id), + f.col("change_type"), + ), + ) + .withColumn("created_at", f.current_timestamp()) + .withColumn("processed_at", f.lit(None).cast(TimestampType())) + ) + self.context.log.info(f"No silver table; all {df.count()} rows are INSERT") + return df - def _process_staging_changes( - self, upstream_df: sql.DataFrame | list[str] - ) -> sql.DataFrame | None: - """Process staging changes based on silver table existence.""" - if self.silver_table_exists: - return self._process_with_silver_table(upstream_df) - else: - return self._process_without_silver_table(upstream_df) + # Silver exists: left-join and fill non-uploaded cols from silver + silver_df = DeltaTable.forName(self.spark, self.silver_table_name).toDF() - def _process_with_silver_table( - self, upstream_df: sql.DataFrame | list[str] - ) -> sql.DataFrame: - """Process staging changes when silver table exists. + # Prefix all silver columns to avoid name conflicts + silver_prefixed = silver_df.select( + *[f.col(c).alias(f"_s_{c}") for c in silver_df.columns] + ) + joined = df.join( + silver_prefixed, + df[self.primary_key] == silver_prefixed[f"_s_{self.primary_key}"], + "left", + ) - The staging table contains a working copy of the entire silver table plus - any pending changes (upserts/deletes) that are awaiting approval. This allows - reviewers to see the complete dataset with pending changes applied, not just - the diff of pending changes. + # For columns not in the upload file: use silver's value for existing rows + row_in_silver = f.col(f"_s_{self.primary_key}").isNotNull() + for col_name in schema_col_names: + s_col = f"_s_{col_name}" + if col_name not in uploaded_columns and s_col in joined.columns: + joined = joined.withColumn( + col_name, + f.when(row_in_silver, f.col(s_col)).otherwise(f.col(col_name)), + ) - When changes are approved, the staging table is merged back to silver. - """ - if not self.staging_table_exists: - # Clone entire silver table to staging to create a working copy - self.create_staging_table_from_silver() + # Admin columns may be null per-row when lat/lon is absent for that row, + # even though lat/lon is present in the file (and admin columns therefore + # appear in uploaded_columns). For existing rows, fall back to silver so + # we preserve a previously-computed admin value and avoid spurious UPDATEs. + _admin_cols = ( + "admin1", + "admin2", + "admin3", + "admin4", + "admin1_id_giga", + "admin2_id_giga", + "admin3_id_giga", + "admin4_id_giga", + ) + for col_name in _admin_cols: + s_col = f"_s_{col_name}" + if col_name in uploaded_columns and s_col in joined.columns: + joined = joined.withColumn( + col_name, + f.when( + row_in_silver & f.col(col_name).isNull(), f.col(s_col) + ).otherwise(f.col(col_name)), + ) - self.sync_schema_staging() + # Drop all silver-prefixed columns (including _s_signature) + s_cols_to_drop = [c for c in joined.columns if c.startswith("_s_")] + joined = joined.drop(*s_cols_to_drop) - # If silver table exists and staging table exists, merge files for review to existing staging table - if self.change_type != StagingChangeTypeEnum.DELETE: - staging = self.standard_transforms(upstream_df) - staging = self.upsert_rows(staging) - else: - staging = self.delete_rows(upstream_df) + # Compute hash over the fully-merged row (same column set as when silver was written) + joined = compute_row_hash(joined) - return staging + # Join with silver signatures to determine INSERT / UPDATE / UNCHANGED + # (must happen before _select_schema_cols so that 'signature' is still present) + silver_sigs = silver_df.select( + f.col(self.primary_key).alias("_sig_pk"), + f.col("signature").alias("_silver_sig"), + ) + joined = joined.join( + silver_sigs, + joined[self.primary_key] == f.col("_sig_pk"), + "left", + ) + joined = joined.withColumn( + "change_type", + f.when(f.col("_sig_pk").isNull(), f.lit(_CHANGE_INSERT)) + .when( + f.col("signature") == f.col("_silver_sig"), + f.lit(_CHANGE_UNCHANGED), + ) + .otherwise(f.lit(_CHANGE_UPDATE)), + ) + joined = joined.drop("_sig_pk", "_silver_sig") + + # Trim to schema columns + change_type before persisting. + # Inline the select instead of calling _select_schema_cols so that + # change_type (not a schema column) is preserved alongside them. + available = [c for c in schema_col_names if c in joined.columns] + joined = joined.select(*available, "change_type") + + joined = ( + joined.withColumn("upload_id", f.lit(upload_id)) + .withColumn( + "uploaded_columns", + f.array(*[f.lit(c) for c in uploaded_columns]), + ) + .withColumn("status", f.lit(_STATUS_PENDING)) + .withColumn( + "change_id", + f.concat_ws( + "|", f.col(self.primary_key), f.lit(upload_id), f.col("change_type") + ), + ) + .withColumn("created_at", f.current_timestamp()) + .withColumn("processed_at", f.lit(None).cast(TimestampType())) + ) + return joined - def _process_without_silver_table( - self, upstream_df: sql.DataFrame | list[str] - ) -> sql.DataFrame | None: - """Process staging changes when silver table does not exist.""" - if self.change_type != StagingChangeTypeEnum.UPDATE: - return ( - None # Cannot delete rows when silver and staging tables do not exist + def _build_delete_records(self, delete_ids: list[str]) -> sql.DataFrame | None: + """Build pending_changes rows for a DELETE operation.""" + if not self.silver_table_exists: + self.context.log.warning( + "Silver table does not exist; cannot stage DELETE records." ) + return None - staging = self.standard_transforms(upstream_df) + silver_df = DeltaTable.forName(self.spark, self.silver_table_name).toDF() + rows = silver_df.filter(f.col(self.primary_key).isin(delete_ids)) - if self.staging_table_exists: - self.sync_schema_staging() - staging = self.upsert_rows(staging) - else: - self.create_empty_staging_table() - ( - staging.write.option("mergeSchema", "true") - .format("delta") - .mode("append") - .saveAsTable(self.staging_table_name) + if rows.isEmpty(): + self.context.log.warning( + f"None of {len(delete_ids)} delete IDs found in silver. Skipping." ) + return None - self.context.log.info(f"Full {staging.count()=}") - return staging + upload_id = self.config.filename_components.id + rows = ( + rows.withColumn("change_type", f.lit(_CHANGE_DELETE)) + .withColumn("upload_id", f.lit(upload_id)) + .withColumn("uploaded_columns", f.array(f.lit(self.primary_key))) + .withColumn("status", f.lit(_STATUS_PENDING)) + .withColumn( + "change_id", + f.concat_ws( + "|", f.col(self.primary_key), f.lit(upload_id), f.col("change_type") + ), + ) + .withColumn("created_at", f.current_timestamp()) + .withColumn("processed_at", f.lit(None).cast(TimestampType())) + ) + return rows - def _get_pre_update_row_count(self) -> int | None: - """Get row count from staging table before update.""" - if not self.staging_table_exists: - return None + def _write_pending_records(self, pending: sql.DataFrame) -> None: + """Append pending_changes rows to the staging Delta table.""" + create_schema(self.spark, self.staging_tier_schema_name) - try: - with get_trino_context() as trino_db: - # Table name is constructed from trusted config sources, not user input - result = trino_db.execute( - text(f"SELECT COUNT(*) as count FROM {self.staging_table_name}") # nosec B608 - ) - row_count = result.scalar() - self.context.log.info(f"Pre-update staging row count: {row_count}") - return row_count - except Exception as e: - self.context.log.warning( - f"Failed to query staging row count: {e}. Proceeding with DB update." + pending_extra_fields = [ + StructField("upload_id", StringType(), nullable=False), + StructField("change_type", StringType(), nullable=False), + StructField("uploaded_columns", ArrayType(StringType()), nullable=False), + StructField("status", StringType(), nullable=False), + StructField("change_id", StringType(), nullable=False), + StructField("created_at", TimestampType(), nullable=True), + StructField("processed_at", TimestampType(), nullable=True), + StructField("approval_request_log_id", StringType(), nullable=True), + ] + pending_schema = list(self.schema_columns) + pending_extra_fields + + if not self.spark.catalog.tableExists(self.staging_table_name): + create_delta_table( + self.spark, + self.staging_tier_schema_name, + self.country_code, + pending_schema, + self.context, + replace=True, + partition_by=["upload_id"], ) - return None - def _update_approval_request_status(self, staging: sql.DataFrame) -> None: - """Update ApprovalRequest status if conditions are met.""" - pre_update_row_count = self._get_pre_update_row_count() - formatted_dataset = f"School {self.config.dataset_type.capitalize()}" - result = None + # Cast pending columns to the expected schema types before writing + schema_type_map = {field.name: field.dataType for field in pending_schema} + pending = pending.withColumns( + { + col: f.col(col).cast(dtype) + for col, dtype in schema_type_map.items() + if col in pending.columns + } + ) + + ( + pending.write.format("delta") + .mode("append") + .option("mergeSchema", "true") + .saveAsTable(self.staging_table_name) + ) + + def _update_approval_request_status(self) -> None: + """Enable the ApprovalRequest if any actionable (non-UNCHANGED) rows exist.""" + actionable = ( + DeltaTable.forName(self.spark, self.staging_table_name) + .toDF() + .filter( + (f.col("status") == _STATUS_PENDING) + & (f.col("change_type") != _CHANGE_UNCHANGED) + ) + .count() + ) + if actionable == 0: + self.context.log.info( + "No actionable changes in pending_changes (only UNCHANGED). " + "Skipping ApprovalRequest update." + ) + return + formatted_dataset = f"School {self.config.dataset_type.capitalize()}" with get_db_context() as db: try: with db.begin(): @@ -197,13 +356,16 @@ def _update_approval_request_status(self, staging: sql.DataFrame) -> None: ) if current_request is None: self.context.log.warning( - f"No ApprovalRequest found for {self.country_code} - {formatted_dataset}" + f"No ApprovalRequest found for " + f"{self.country_code} - {formatted_dataset}" ) return - if not self._should_update_enabled( - current_request, pre_update_row_count, formatted_dataset - ): + if current_request.enabled: + self.context.log.info( + f"ApprovalRequest already enabled for " + f"{self.country_code} - {formatted_dataset}. Skipping." + ) return result = db.execute( @@ -211,7 +373,7 @@ def _update_approval_request_status(self, staging: sql.DataFrame) -> None: .where( (ApprovalRequest.country == self.country_code) & (ApprovalRequest.dataset == formatted_dataset) - & (~ApprovalRequest.enabled) # Only update if False + & (~ApprovalRequest.enabled) ) .values( { @@ -222,28 +384,20 @@ def _update_approval_request_status(self, staging: sql.DataFrame) -> None: ) if result.rowcount > 0: self.context.log.info( - f"Successfully set enabled=True for {self.country_code} - {formatted_dataset}" + f"Successfully set enabled=True for " + f"{self.country_code} - {formatted_dataset}" ) else: self.context.log.info( - "No rows updated (already enabled or state changed). Skipping commit." + "No rows updated (already enabled or state changed)." ) - - # Post-delete validation: Verify delete CDF entries were created - if ( - result is not None - and self.change_type == StagingChangeTypeEnum.DELETE - and result.rowcount > 0 - ): - self._validate_delete_cdf() - except Exception as e: self.context.log.error( - f"Failed to update ApprovalRequest for {self.country_code} - {formatted_dataset}: {e}" + f"Failed to update ApprovalRequest for " + f"{self.country_code} - {formatted_dataset}: {e}" ) def _get_current_approval_request(self, db, formatted_dataset: str): - """Get current ApprovalRequest from database.""" return db.scalar( select(ApprovalRequest).where( (ApprovalRequest.country == self.country_code) @@ -251,84 +405,94 @@ def _get_current_approval_request(self, db, formatted_dataset: str): ) ) - def _should_update_enabled( - self, current_request, pre_update_row_count: int | None, formatted_dataset: str - ) -> bool: - """Check if enabled flag should be updated.""" - if current_request.enabled: - self.context.log.info( - f"ApprovalRequest already enabled for {self.country_code} - {formatted_dataset}. Skipping update." - ) - return False - - if ( - pre_update_row_count is not None - and pre_update_row_count == 0 - and self.change_type - in [StagingChangeTypeEnum.UPDATE, StagingChangeTypeEnum.DELETE] - ): - change_type_label = ( - "changes" - if self.change_type == StagingChangeTypeEnum.UPDATE - else "rows to delete" - ) - self.context.log.info( - f"No {change_type_label} detected (row count: {pre_update_row_count}). Skipping enabled=True update." - ) - return False + def _get_uploaded_columns(self) -> list[str]: + """Return the list of schema column names present in the upload file. - return True + Expands the raw upload columns to include columns that are derived + from uploaded columns and computed in bronze (admin, connectivity_type, + connectivity_govt_ingestion_timestamp). Without this expansion the + staging step would copy the old silver value for these derived columns + instead of using the freshly-computed bronze values. - def _validate_delete_cdf(self) -> None: - """Validate that delete CDF entries were created after delete operation.""" + Falls back to all schema column names if no FileUpload record is found, + which preserves backward-compatible behaviour for non-geolocation pipelines. + """ try: - with get_trino_context() as trino_db: - # Query CDF for delete entries - # Table name is constructed from trusted config sources, not user input - result = trino_db.execute( - text( - f"SELECT COUNT(*) as count FROM {self.staging_table_name} " # nosec B608 - f"FOR VERSION AS OF (SELECT MAX(version) FROM " - f'{self.staging_table_name}."$history") ' - f"WHERE _change_type = 'delete'" + with get_db_context() as db: + file_upload = db.scalar( + select(FileUpload).where( + FileUpload.id == self.config.filename_components.id ) ) - delete_count = result.scalar() - - if delete_count == 0: - self.context.log.error( - f"Delete CDF validation failed: No delete entries found in CDF " - f"for {self.country_code}. This indicates the delete operation " - f"may not have been successful despite enabled=True being set." - ) - # Rollback enabled flag - formatted_dataset = ( - f"School {self.config.dataset_type.capitalize()}" - ) - with get_db_context() as db: - with db.begin(): - db.execute( - update(ApprovalRequest) - .where( - (ApprovalRequest.country == self.country_code) - & (ApprovalRequest.dataset == formatted_dataset) - ) - .values({ApprovalRequest.enabled: False}) - ) - raise RuntimeError("Delete CDF empty after delete operation") - else: - self.context.log.info( - f"Delete CDF validation passed: Found {delete_count} delete entries." + if file_upload is None: + raise FileNotFoundError( + f"FileUpload with id `{self.config.filename_components.id}` not found" ) + file_upload = FileUploadConfig.from_orm(file_upload) + columns = set(file_upload.column_to_schema_mapping.values()) + + # Admin columns are derived from lat/lon in geolocation_bronze. + # Treat them as uploaded so staging uses the freshly-computed values. + if {"latitude", "longitude"}.issubset(columns): + columns.update( + { + "admin1", + "admin1_id_giga", + "admin2", + "admin2_id_giga", + "admin3", + "admin3_id_giga", + "admin4", + "admin4_id_giga", + "disputed_region", + } + ) + + # connectivity_type/root are derived from connectivity_type_govt. + if "connectivity_type_govt" in columns: + columns.update({"connectivity_type", "connectivity_type_root"}) + + # connectivity_govt_ingestion_timestamp is derived from connectivity_govt. + if "connectivity_govt" in columns: + columns.add("connectivity_govt_ingestion_timestamp") + + return list(columns) except Exception as e: - self.context.log.error( - f"Failed to validate delete CDF: {e}. Manual verification recommended." + self.context.log.warning( + f"Could not retrieve uploaded_columns from FileUpload: {e}. " + "Falling back to treating all schema columns as uploaded." ) - raise + return [c.name for c in self.schema_columns] + + def _prepare_df(self, df: sql.DataFrame) -> sql.DataFrame: + """Add missing columns and cast types — does NOT compute row hash.""" + df = add_missing_columns(df, self.schema_columns) + df = transform_types(df, self.schema_name, self.context) + # Fill nulls in NOT NULL STRING schema columns with "Unknown" so that + # Delta NOT NULL constraints are never violated on write. + unknown_fills = { + col.name: f.coalesce(f.col(col.name), f.lit("Unknown")) + for col in self.schema_columns + if not col.nullable + and isinstance(col.dataType, StringType) + and col.name in df.columns + } + if unknown_fills: + df = df.withColumns(unknown_fills) + return df + + def _select_schema_cols(self, df: sql.DataFrame) -> sql.DataFrame: + """Select only schema columns from df, dropping any extra bronze columns.""" + schema_col_names = [c.name for c in self.schema_columns] + available = [c for c in schema_col_names if c in df.columns] + return df.select(*available) + + def standard_transforms(self, df: sql.DataFrame) -> sql.DataFrame: + """Backward-compatible wrapper used by other pipelines.""" + df = self._prepare_df(df) + return compute_row_hash(df) def _emit_lineage(self) -> None: - """Emit lineage information.""" - # Get files for review in a separate DB context with get_db_context() as db: files_for_review = db.scalars( select(FileUpload).where( @@ -336,7 +500,7 @@ def _emit_lineage(self) -> None: & (FileUpload.dataset == self.config.dataset_type) ) ) - upstream_filepaths = [f.upload_path for f in files_for_review] + upstream_filepaths = [fu.upload_path for fu in files_for_review] emit_lineage_base( upstream_datasets=upstream_filepaths, @@ -346,185 +510,17 @@ def _emit_lineage(self) -> None: @property def silver_table_exists(self) -> bool: - # Metastore entry must be present AND ADLS path must be a valid Delta Table return check_table_exists( self.spark, self.schema_name, self.country_code, DataTier.SILVER ) @property - def staging_table_exists(self) -> bool: - # Metastore entry must be present AND ADLS path must be a valid Delta Table + def pending_changes_table_exists(self) -> bool: return check_table_exists( self.spark, self.schema_name, self.country_code, DataTier.STAGING ) - def create_staging_table_from_silver(self): - """Create staging table as a complete clone of the silver table. - - This creates a working copy containing all rows from silver. Subsequent changes - will be applied to this staging table via upserts/deletes. The staging table - represents what the silver table will look like after approval. - """ - self.context.log.info("Creating staging from silver if not exists...") - silver = ( - DeltaTable.forName(self.spark, self.silver_table_name) - .alias("silver") - .toDF() - ) - create_schema(self.spark, self.staging_tier_schema_name) - create_delta_table( - self.spark, - self.staging_tier_schema_name, - self.country_code, - self.schema_columns, - self.context, - if_not_exists=True, - ) - silver.write.format("delta").mode("append").saveAsTable(self.staging_table_name) - - def create_empty_staging_table(self): - self.context.log.info("Creating empty staging table...") - create_schema(self.spark, self.staging_tier_schema_name) - create_delta_table( - self.spark, - self.staging_tier_schema_name, - self.country_code, - self.schema_columns, - self.context, - if_not_exists=True, - ) - - def sync_schema_staging(self): - """Update the schema of existing delta tables based on the reference schema delta tables. - - Supports adding, renaming, and deleting columns. Renames and deletes - are detected by comparing stable column UUIDs stored in the table - properties against the latest reference schema CSV. - """ - from src.utils.delta import apply_renames_and_deletes, persist_column_id_map - - self.context.log.info("Checking for schema update...") - updated_schema = StructType(self.schema_columns) - updated_columns = sorted(updated_schema.fieldNames()) - - existing_df = DeltaTable.forName(self.spark, self.staging_table_name).toDF() - existing_columns = sorted(existing_df.schema.fieldNames()) - - any_renames_deletes = apply_renames_and_deletes( - self.spark, self.staging_table_name, self.schema_name, self.context - ) - - # Refresh schemas after rename/delete - if any_renames_deletes: - existing_df = DeltaTable.forName(self.spark, self.staging_table_name).toDF() - existing_columns = sorted(existing_df.schema.fieldNames()) - - # Sync changes in nullability flags - alter_sql = f"ALTER TABLE {self.staging_table_name}" - alter_stmts = [] - for column in existing_df.schema: - if ( - match_ := next( - (c for c in updated_schema if c.name == column.name), None - ) - ) is not None: - if match_.nullable != column.nullable: - if match_.nullable: - alter_stmts.append(f"ALTER COLUMN {column.name} DROP NOT NULL") - else: - alter_stmts.append(f"ALTER COLUMN {column.name} SET NOT NULL") - - has_nullability_changed = len(alter_stmts) > 0 - has_schema_changed = updated_columns != existing_columns - - # Sync changes in columns & data types - if has_schema_changed: - self.context.log.info("Updating schema...") - updated_schema_df = self.spark.createDataFrame([], schema=updated_schema) - ( - updated_schema_df.write.option("mergeSchema", "true") - .format("delta") - .mode("append") - .saveAsTable(self.staging_table_name) - ) - - if has_nullability_changed: - alter_sql = [f"{alter_sql} {alter_stmt}" for alter_stmt in alter_stmts] - for stmnt in alter_sql: - self.spark.sql(stmnt).show() - - if has_schema_changed or has_nullability_changed or any_renames_deletes: - self.reload_schema() - - # Persist column-ID mapping - persist_column_id_map(self.spark, self.staging_table_name, self.schema_name) - - def reload_schema(self): - self.schema_columns = get_schema_columns(self.spark, self.schema_name) - - def standard_transforms(self, df: sql.DataFrame): - self.context.log.info("Performing standard transforms...") - df = add_missing_columns(df, self.schema_columns) - df = transform_types(df, self.schema_name, self.context) - return compute_row_hash(df) - - def upsert_rows(self, df: sql.DataFrame): - self.context.log.info("Performing upsert...") - staging_dt = DeltaTable.forName(self.spark, self.staging_table_name) - update_columns = [ - c.name for c in self.schema_columns if c.name != self.primary_key - ] - df = partial_in_cluster_merge( - staging_dt.toDF(), - df, - self.primary_key, - column_names=[c.name for c in self.schema_columns], - ) - query = build_deduped_merge_query( - staging_dt, - df, - self.primary_key, - update_columns, - ) - - if query is not None: - execute_query_with_error_handler( - self.spark, - query, - self.staging_tier_schema_name, - self.country_code, - self.context, - ) - return staging_dt.toDF() - - def delete_rows(self, df: list[str]): - self.context.log.info("Performing delete...") - staging_dt = DeltaTable.forName(self.spark, self.staging_table_name) - - # Pre-delete validation: Check if target rows exist in staging table - staging_df = staging_dt.toDF() - existing_rows = staging_df.filter(staging_df[self.primary_key].isin(df)) - existing_count = existing_rows.count() - - if existing_count == 0: - self.context.log.warning( - f"No target rows found in staging table for deletion. " - f"Skipping delete operation for {len(df)} row IDs." - ) - return staging_dt.toDF() - - self.context.log.info( - f"Found {existing_count} out of {len(df)} row IDs in staging table for deletion." - ) - - query = build_deduped_delete_query(staging_dt, df, self.primary_key) - - if query is not None: - execute_query_with_error_handler( - self.spark, - query, - self.staging_tier_schema_name, - self.country_code, - self.context, - ) - return staging_dt.toDF() + # Keep old property name as alias for callers that used staging_table_exists + @property + def staging_table_exists(self) -> bool: + return self.pending_changes_table_exists diff --git a/dagster/src/sensors/school_geolocation.py b/dagster/src/sensors/school_geolocation.py index c19c3b13c..7d4992921 100644 --- a/dagster/src/sensors/school_geolocation.py +++ b/dagster/src/sensors/school_geolocation.py @@ -180,7 +180,7 @@ def school_master_geolocation__post_manual_checks_sensor( continue else: country_code = filename_components.country_code - metadata = adls_file_client.fetch_metadata_for_blob(adls_filepath) + metadata = adls_file_client.fetch_metadata_for_blob(adls_filepath) or {} ops_destination_mapping = { "manual_review_passed_rows": OpDestinationMapping( @@ -276,7 +276,7 @@ def school_master_geolocation__admin_delete_rows_sensor( continue else: country_code = filename_components.country_code - metadata = adls_file_client.fetch_metadata_for_blob(adls_filepath) + metadata = adls_file_client.fetch_metadata_for_blob(adls_filepath) or {} ops_destination_mapping = { "geolocation_delete_staging": OpDestinationMapping( diff --git a/dagster/src/spark/transform_functions.py b/dagster/src/spark/transform_functions.py index 5aa6e091c..c21750439 100644 --- a/dagster/src/spark/transform_functions.py +++ b/dagster/src/spark/transform_functions.py @@ -24,7 +24,7 @@ from dagster import OpExecutionContext from src.constants import UploadMode from src.internal.connectivity_queries import get_qos_tables -from src.settings import settings +from src.settings import DeploymentEnvironment, settings from src.spark.udf_dependencies import get_point from src.utils.adls import get_blob_service_client from src.utils.logger import get_context_with_fallback_logger @@ -50,9 +50,9 @@ def generate_uuid(identifier_concat: str) -> str: def create_school_id_giga(df: sql.DataFrame) -> sql.DataFrame: # Create school_id_giga column if not exists, otherwise use provided values - df = df.withColumn( - "school_id_giga", f.coalesce(f.col("school_id_giga"), f.lit(None)) - ) + # df = df.withColumn( + # "school_id_giga", f.coalesce(f.col("school_id_giga"), f.lit(None)) + # ) school_id_giga_prereqs = [ "school_id_govt", @@ -76,6 +76,11 @@ def create_school_id_giga(df: sql.DataFrame) -> sql.DataFrame: ), ) + # If school_id_giga was not part of the upload, add it as null so the coalesce + # below can reference it safely regardless of whether the column already exists. + if "school_id_giga" not in df.columns: + df = df.withColumn("school_id_giga", f.lit(None).cast(StringType())) + # Use existing school_id_giga if provided, otherwise use system-generated value df = df.withColumn( "school_id_giga", @@ -482,6 +487,62 @@ def create_bronze_layer_columns( return df +def create_bronze_layer_columns_updated( + df: sql.DataFrame, + mode: str, + uploaded_columns: list[str], + country_code_iso3: str, + spark: SparkSession = None, +): + # standardize education level + if mode == UploadMode.CREATE.value or "education_level_govt" in uploaded_columns: + df = create_education_level(df, mode, uploaded_columns) + + # Generate school_id_giga for new schools using the dedicated function + if mode == UploadMode.CREATE.value: + df = create_school_id_giga(df) + + # Admin columns: re-compute whenever lat/lon are part of the upload + if "latitude" in uploaded_columns and "longitude" in uploaded_columns: + for admin_level in ("admin1", "admin2", "admin3", "admin4"): + df = add_admin_columns(df, country_code_iso3, admin_level) + df = add_disputed_region_column(df) + + missing_location_condition = ( + f.col("latitude").isNull() + | f.isnan(f.col("latitude")) + | f.col("longitude").isNull() + | f.isnan(f.col("longitude")) + ) + for column in ("admin1", "admin1_id_giga", "admin2", "admin2_id_giga"): + df = df.withColumn( + column, + f.when( + missing_location_condition, f.lit(None).cast(StringType()) + ).otherwise(f.col(column)), + ) + + # Connectivity type: re-compute whenever connectivity_type_govt is part of the upload + if mode == UploadMode.CREATE.value or "connectivity_type_govt" in uploaded_columns: + df = standardize_connectivity_type(df, mode, uploaded_columns) + + # Connectivity govt ingestion timestamp: set when connectivity_govt is uploaded. + if "connectivity_govt" in df.columns: + df = df.withColumn( + "connectivity_govt_ingestion_timestamp", + f.when( + f.col("connectivity_govt").isNotNull(), f.current_timestamp() + ).otherwise(f.lit(None).cast(TimestampType())), + ) + + # RT connectivity columns: merge from the realtime schools table + if settings.DEPLOY_ENV != DeploymentEnvironment.LOCAL: + connectivity = get_country_rt_schools(spark, country_code_iso3) + df = merge_connectivity_to_master(df, connectivity, uploaded_columns, mode) + + return df + + def get_admin_boundaries( country_code_iso3: str, admin_level: str, @@ -560,14 +621,17 @@ def get_admin_id_giga(latitude, longitude) -> str | None: ), } ) + coalesce_args = [ + f.col(f"{admin_level}_en"), + f.col(f"{admin_level}_native"), + ] + if admin_level in df.columns: + coalesce_args.append(f.col(admin_level)) + coalesce_args.append(f.lit("Unknown")) + return df.withColumn( admin_level, - f.coalesce( - f.col(f"{admin_level}_en"), - f.col(f"{admin_level}_native"), - f.col(admin_level), - f.lit("Unknown"), - ), + f.coalesce(*coalesce_args), ).drop(f"{admin_level}_en", f"{admin_level}_native") @@ -775,13 +839,16 @@ def merge_connectivity_to_master( "connectivity_RT", f.coalesce(f.col("connectivity_RT"), f.lit("No")) ) - # make sure connectivity_govt is standardized - master = master.withColumn( - "connectivity_govt", - f.when( - f.isnan(f.col("connectivity_govt")), f.lit(None).cast(StringType()) - ).otherwise(f.initcap(f.trim(f.col("connectivity_govt")))), - ) + # standardize connectivity_govt only when it was uploaded; for CREATE mode ensure it exists + if "connectivity_govt" in uploaded_columns: + master = master.withColumn( + "connectivity_govt", + f.when( + f.isnan(f.col("connectivity_govt")), f.lit(None).cast(StringType()) + ).otherwise(f.initcap(f.trim(f.col("connectivity_govt")))), + ) + elif mode == UploadMode.CREATE.value and "connectivity_govt" not in master.columns: + master = master.withColumn("connectivity_govt", f.lit(None).cast(StringType())) # determine the value of connectivity if mode == UploadMode.CREATE.value or { @@ -795,20 +862,11 @@ def merge_connectivity_to_master( "connectivity", f.when( (f.lower(f.col("connectivity_RT")) == "yes") - | ( - (f.lower(f.col("connectivity_govt")) == "yes") - & ( - (f.col("download_speed_govt") != 0) - | f.col("download_speed_govt").isNull() - | f.isnan(f.col("download_speed_govt")) - ) - ) - | (f.col("download_speed_govt") > 0), + | (f.lower(f.col("connectivity_govt")) == "yes"), "Yes", ) .when( - (f.lower("connectivity_govt") == "no") - | (f.col("download_speed_govt") == 0), + f.lower(f.col("connectivity_govt")) == "no", "No", ) .otherwise( @@ -835,13 +893,14 @@ def merge_connectivity_to_master( .otherwise(f.lit(None).cast(StringType())), ) - # add the time connectivity_govt was ingested - master = master.withColumn( - "connectivity_govt_ingestion_timestamp", - f.when(f.col("connectivity_govt").isNotNull(), f.current_timestamp()).otherwise( - f.lit(None).cast(TimestampType()) - ), - ) + # add the time connectivity_govt was ingested (only when it was part of the upload) + if "connectivity_govt" in uploaded_columns or mode == UploadMode.CREATE.value: + master = master.withColumn( + "connectivity_govt_ingestion_timestamp", + f.when( + f.col("connectivity_govt").isNotNull(), f.current_timestamp() + ).otherwise(f.lit(None).cast(TimestampType())), + ) master_cols_to_drop = [ col diff --git a/dagster/src/spark/user_defined_functions.py b/dagster/src/spark/user_defined_functions.py index 81c62bcb1..699c88d95 100644 --- a/dagster/src/spark/user_defined_functions.py +++ b/dagster/src/spark/user_defined_functions.py @@ -23,7 +23,7 @@ def get_decimal_places(value) -> int | None: return None try: decimal_places = -Decimal(str(value)).as_tuple().exponent - except TypeError: + except (TypeError, InvalidOperation): return None return int(decimal_places < precision) diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index f0a58986d..50d28ec33 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -51,6 +51,7 @@ def create_delta_table( *, if_not_exists: bool = False, replace: bool = False, + partition_by: list[str] | None = None, ) -> None: if if_not_exists and replace: raise MutexException( @@ -71,6 +72,8 @@ def create_delta_table( .addColumns(columns) .property("delta.enableChangeDataFeed", "true") ) + if partition_by: + query = query.partitionedBy(*partition_by) execute_query_with_error_handler(spark, query, schema_name, table_name, context) From 81d581c44fd9078845ea88e2e6df484417e03e3e Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 1 Apr 2026 15:05:51 +0200 Subject: [PATCH 06/26] fix: improve processing time for large files (#445) --- .../src/assets/school_geolocation/assets.py | 164 ++++++++---------- dagster/src/resources/__init__.py | 8 + .../src/resources/io_managers/adls_spark.py | 73 ++++++++ .../io_managers/adls_spark_single_file.py | 59 +++++++ dagster/src/resources/io_managers/base.py | 2 + dagster/src/sensors/school_geolocation.py | 28 ++- dagster/src/utils/op_config.py | 11 ++ 7 files changed, 236 insertions(+), 109 deletions(-) create mode 100644 dagster/src/resources/io_managers/adls_spark.py create mode 100644 dagster/src/resources/io_managers/adls_spark_single_file.py diff --git a/dagster/src/assets/school_geolocation/assets.py b/dagster/src/assets/school_geolocation/assets.py index fe1532659..a5368818d 100644 --- a/dagster/src/assets/school_geolocation/assets.py +++ b/dagster/src/assets/school_geolocation/assets.py @@ -11,7 +11,7 @@ SparkSession, functions as f, ) -from pyspark.sql.types import StringType, StructType +from pyspark.sql.types import StringType, StructField, StructType from sqlalchemy import select from src.constants import DataTier from src.data_quality_checks.utils import ( @@ -52,7 +52,14 @@ from src.utils.send_email_dq_report import send_email_dq_report_with_config from src.utils.sentry import capture_op_exceptions -from dagster import MetadataValue, OpExecutionContext, Output, asset +from dagster import ( + AssetOut, + MetadataValue, + OpExecutionContext, + Output, + asset, + multi_asset, +) @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @@ -159,14 +166,14 @@ def geolocation_metadata( return Output(None) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_bronze( context: OpExecutionContext, geolocation_raw: bytes, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session country_code = config.country_code mode = config.metadata["mode"] @@ -227,31 +234,28 @@ def geolocation_bronze( if column in df.columns: df = df.withColumn(column, f.initcap(f.col(column))) - ## at this point it's already gone - context.log.info("BEFORE DF TO PANDAS") - df_pandas = df.toPandas() - context.log.info("AFTER DF TO PANDAS") - context.log.info(df_pandas) + df.cache() + row_count = df.count() return Output( - df_pandas, + df, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), + "row_count": row_count, "column_mapping": column_mapping_filtered, - "preview": get_table_preview(df_pandas), + "preview": get_table_preview(df), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_data_quality_results( context: OpExecutionContext, config: FileConfig, geolocation_bronze: sql.DataFrame, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: s: SparkSession = spark.spark_session country_code = config.country_code schema_name = config.metastore_schema @@ -315,9 +319,10 @@ def geolocation_data_quality_results( dq_results_schema_name = f"{schema_name}_dq_results" table_name = f"{id}_{country_code}_{current_timestamp}" - schema_columns = dq_results.schema.fields - for col in schema_columns: - col.nullable = True + schema_columns = [ + StructField(field.name, field.dataType, nullable=True) + for field in dq_results.schema.fields + ] dq_results_table_name = construct_full_table_name( dq_results_schema_name, @@ -333,10 +338,9 @@ def geolocation_data_quality_results( context, if_not_exists=True, ) + dq_results.cache() dq_results.write.format("delta").mode("append").saveAsTable(dq_results_table_name) - dq_pandas = dq_results.toPandas() - datahub_emit_metadata_with_exception_catcher( context=context, config=config, @@ -344,22 +348,31 @@ def geolocation_data_quality_results( ) return Output( - dq_pandas, + dq_results.coalesce(1), metadata={ **get_output_metadata(config), - "row_count": len(dq_pandas), - "preview": get_table_preview(dq_pandas), + "row_count": dq_results.count(), + "preview": get_table_preview(dq_results), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@multi_asset( + outs={ + "geolocation_dq_schools_passed_human_readable": AssetOut( + io_manager_key=ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value, + ), + "geolocation_dq_schools_failed_human_readable": AssetOut( + io_manager_key=ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value, + ), + }, +) @capture_op_exceptions -async def geolocation_data_quality_results_human_readable( +def geolocation_data_quality_results_human_readable( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, -) -> Output[pd.DataFrame]: +): context.log.info("Get the file upload object from the database") with get_db_context() as db: file_upload = db.scalar( @@ -381,9 +394,6 @@ async def geolocation_data_quality_results_human_readable( df, human_readable_mappings = dq_geolocation_extract_relevant_columns( geolocation_data_quality_results, uploaded_columns, mode ) - # human_readable_mappings keys are map keys (without "dq_" prefix). - # Expand each relevant map entry into its own column with Yes/No values, - # then drop the map so the output is flat like before. for map_key, human_name in human_readable_mappings.items(): df = df.withColumn( human_name, @@ -393,64 +403,32 @@ async def geolocation_data_quality_results_human_readable( ) df = df.drop("dq_results") - context.log.info("Convert the dataframe to a pandas object to save it locally") - df_pandas = df.toPandas() + # Cache once — both filters read from the same plan + df.cache() - return Output( - df_pandas, - metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), - }, + df_passed = df.filter(df.dq_has_critical_error == 0).drop( + "dq_has_critical_error", "failure_reason" ) + df_failed = df.filter(df.dq_has_critical_error == 1).drop("dq_has_critical_error") + output_metadata = get_output_metadata(config) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) -@capture_op_exceptions -async def geolocation_dq_schools_passed_human_readable( - context: OpExecutionContext, - geolocation_data_quality_results_human_readable: sql.DataFrame, - config: FileConfig, -) -> Output[pd.DataFrame]: - context.log.info("Filter and keep schools that do not have a critical error") - df = geolocation_data_quality_results_human_readable.filter( - geolocation_data_quality_results_human_readable.dq_has_critical_error == 0 - ) - df = df.drop("dq_has_critical_error", "failure_reason") - df_pandas = df.toPandas() - - return Output( - df_pandas, + yield Output( + df_passed, + output_name="geolocation_dq_schools_passed_human_readable", metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + **output_metadata, + "row_count": df_passed.count(), + "preview": get_table_preview(df_passed), }, ) - - -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) -@capture_op_exceptions -async def geolocation_dq_schools_failed_human_readable( - context: OpExecutionContext, - geolocation_data_quality_results_human_readable: sql.DataFrame, - config: FileConfig, -) -> Output[pd.DataFrame]: - context.log.info("Filter and keep schools that have a critical error") - df = geolocation_data_quality_results_human_readable.filter( - geolocation_data_quality_results_human_readable.dq_has_critical_error == 1 - ) - - df = df.drop("dq_has_critical_error") - df_pandas = df.toPandas() - - return Output( - df_pandas, + yield Output( + df_failed, + output_name="geolocation_dq_schools_failed_human_readable", metadata={ - **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + **output_metadata, + "row_count": df_failed.count(), + "preview": get_table_preview(df_failed), }, ) @@ -559,14 +537,14 @@ def geolocation_data_quality_report( return Output(dq_report) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_dq_passed_rows( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: df_passed = dq_split_passed_rows( geolocation_data_quality_results, config.dataset_type, @@ -583,25 +561,27 @@ def geolocation_dq_passed_rows( schema_reference=schema_reference, ) - df_pandas = df_passed.toPandas() + df_passed.cache() + row_count = df_passed.count() + return Output( - df_pandas, + df_passed, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + "row_count": row_count, + "preview": get_table_preview(df_passed), }, ) -@asset(io_manager_key=ResourceKey.ADLS_PANDAS_IO_MANAGER.value) +@asset(io_manager_key=ResourceKey.ADLS_SPARK_IO_MANAGER.value) @capture_op_exceptions def geolocation_dq_failed_rows( context: OpExecutionContext, geolocation_data_quality_results: sql.DataFrame, config: FileConfig, spark: PySparkResource, -) -> Output[pd.DataFrame]: +) -> Output[sql.DataFrame]: df_failed = dq_split_failed_rows( geolocation_data_quality_results, config.dataset_type, @@ -619,13 +599,15 @@ def geolocation_dq_failed_rows( df_failed=df_failed, ) - df_pandas = df_failed.toPandas() + df_failed.cache() + row_count = df_failed.count() + return Output( - df_pandas, + df_failed, metadata={ **get_output_metadata(config), - "row_count": len(df_pandas), - "preview": get_table_preview(df_pandas), + "row_count": row_count, + "preview": get_table_preview(df_failed), }, ) @@ -639,7 +621,7 @@ def geolocation_staging( spark: PySparkResource, config: FileConfig, ) -> Output[None]: - if geolocation_dq_passed_rows.count() == 0: + if geolocation_dq_passed_rows.isEmpty(): context.log.warning("Skipping staging as there are no rows passing DQ checks") return Output(None) diff --git a/dagster/src/resources/__init__.py b/dagster/src/resources/__init__.py index b75687146..34204c5e0 100644 --- a/dagster/src/resources/__init__.py +++ b/dagster/src/resources/__init__.py @@ -8,6 +8,8 @@ from .io_managers.adls_json import ADLSJSONIOManager from .io_managers.adls_pandas import ADLSPandasIOManager from .io_managers.adls_passthrough import ADLSPassthroughIOManager +from .io_managers.adls_spark import ADLSSparkIOManager +from .io_managers.adls_spark_single_file import ADLSSparkSingleFileIOManager class ResourceKey(Enum): @@ -16,6 +18,8 @@ class ResourceKey(Enum): ADLS_JSON_IO_MANAGER = "adls_json_io_manager" ADLS_PANDAS_IO_MANAGER = "adls_pandas_io_manager" ADLS_PASSTHROUGH_IO_MANAGER = "adls_passthrough_io_manager" + ADLS_SPARK_IO_MANAGER = "adls_spark_io_manager" + ADLS_SPARK_SINGLE_FILE_IO_MANAGER = "adls_spark_single_file_io_manager" ADLS_FILE_CLIENT = "adls_file_client" SPARK = "spark" @@ -26,6 +30,10 @@ class ResourceKey(Enum): ResourceKey.ADLS_JSON_IO_MANAGER.value: ADLSJSONIOManager(), ResourceKey.ADLS_PANDAS_IO_MANAGER.value: ADLSPandasIOManager(pyspark=pyspark), ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value: ADLSPassthroughIOManager(), + ResourceKey.ADLS_SPARK_IO_MANAGER.value: ADLSSparkIOManager(pyspark=pyspark), + ResourceKey.ADLS_SPARK_SINGLE_FILE_IO_MANAGER.value: ADLSSparkSingleFileIOManager( + pyspark=pyspark + ), ResourceKey.ADLS_FILE_CLIENT.value: ADLSFileClient(), ResourceKey.SPARK.value: pyspark, } diff --git a/dagster/src/resources/io_managers/adls_spark.py b/dagster/src/resources/io_managers/adls_spark.py new file mode 100644 index 000000000..dc727b352 --- /dev/null +++ b/dagster/src/resources/io_managers/adls_spark.py @@ -0,0 +1,73 @@ +import pandas as pd +from dagster_pyspark import PySparkResource +from pyspark import sql +from pyspark.sql import SparkSession +from pyspark.sql.types import NullType, StringType + +from azure.core.exceptions import ResourceNotFoundError +from dagster import InputContext, OutputContext +from src.settings import settings +from src.utils.adls import ADLSFileClient + +from .base import BaseConfigurableIOManager + +adls_client = ADLSFileClient() + + +class ADLSSparkIOManager(BaseConfigurableIOManager): + """Writes Spark DataFrames natively to ADLS (parquet or csv directory). + + Uses a cache → isEmpty → write → unpersist pattern so the plan executes + once and executor memory is freed immediately after the write. + """ + + pyspark: PySparkResource + + def handle_output(self, context: OutputContext, output: sql.DataFrame): + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + # Cast NullType columns to StringType — schema-only, no action triggered yet + for field in output.schema.fields: + if isinstance(field.dataType, NullType): + output = output.withColumn( + field.name, output[field.name].cast(StringType()) + ) + + # cache → isEmpty (populates cache) → write (reads from cache) → unpersist + output.cache() + output.isEmpty() + + match path.suffix: + case ".parquet": + output.write.mode("overwrite").parquet(adls_path) + case ".csv": + output.write.mode("overwrite").csv(adls_path, header=True) + case _: + raise OSError(f"Unsupported format for Spark write: {path.suffix}") + + output.unpersist() + context.log.info(f"Uploaded {path.name} to {path.parent} in ADLS.") + + def load_input(self, context: InputContext) -> sql.DataFrame: + spark: SparkSession = self.pyspark.spark_session + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + try: + match path.suffix: + case ".parquet": + data = spark.read.parquet(adls_path) + case ".csv": + data = adls_client.download_csv_as_spark_dataframe(str(path), spark) + case ".xls" | ".xlsx": + # No native Spark Excel support — bridge via pandas + pdf = pd.read_excel(adls_path) + data = spark.createDataFrame(pdf.astype(str)) + case _: + raise OSError(f"Unsupported format for Spark read: {path.suffix}") + except ResourceNotFoundError as e: + raise e + + context.log.info(f"Downloaded {path.name} from {path.parent} in ADLS.") + return data diff --git a/dagster/src/resources/io_managers/adls_spark_single_file.py b/dagster/src/resources/io_managers/adls_spark_single_file.py new file mode 100644 index 000000000..8ce82d555 --- /dev/null +++ b/dagster/src/resources/io_managers/adls_spark_single_file.py @@ -0,0 +1,59 @@ +import pandas as pd +from dagster_pyspark import PySparkResource +from pyspark import sql +from pyspark.sql import SparkSession + +from azure.core.exceptions import ResourceNotFoundError +from dagster import InputContext, OutputContext +from src.settings import settings +from src.utils.adls import ADLSFileClient + +from .base import BaseConfigurableIOManager + +adls_client = ADLSFileClient() + + +class ADLSSparkSingleFileIOManager(BaseConfigurableIOManager): + """Writes Spark DataFrames as a single file to ADLS via a pandas bridge. + + Use this for human-readable output assets (CSV exports for end users) where + a single flat file is required rather than a partitioned Spark output directory. + The asset returns a sql.DataFrame; conversion to pandas happens here, keeping + asset code Spark-native. + + load_input returns a sql.DataFrame so downstream assets remain Spark-native. + """ + + pyspark: PySparkResource + + def handle_output(self, context: OutputContext, output: sql.DataFrame): + path = self._get_filepath(context) + pdf = output.toPandas() + adls_client.upload_pandas_dataframe_as_file( + context=context, + data=pdf, + filepath=str(path), + ) + context.log.info(f"Uploaded {path.name} to {path.parent} in ADLS.") + + def load_input(self, context: InputContext) -> sql.DataFrame: + spark: SparkSession = self.pyspark.spark_session + path = self._get_filepath(context) + adls_path = f"{settings.AZURE_BLOB_CONNECTION_URI}/{path}" + + try: + match path.suffix: + case ".parquet": + data = spark.read.parquet(adls_path) + case ".csv": + data = adls_client.download_csv_as_spark_dataframe(str(path), spark) + case ".xls" | ".xlsx": + pdf = pd.read_excel(adls_path) + data = spark.createDataFrame(pdf.astype(str)) + case _: + raise OSError(f"Unsupported format for Spark read: {path.suffix}") + except ResourceNotFoundError as e: + raise e + + context.log.info(f"Downloaded {path.name} from {path.parent} in ADLS.") + return data diff --git a/dagster/src/resources/io_managers/base.py b/dagster/src/resources/io_managers/base.py index f402bc05b..d36b82f8f 100644 --- a/dagster/src/resources/io_managers/base.py +++ b/dagster/src/resources/io_managers/base.py @@ -37,6 +37,8 @@ def _get_filepath(context: InputContext | OutputContext) -> Path: return config.destination_filepath_object config = FileConfig(**context.step_context.op_config) + if config.output_filepaths and context.name in config.output_filepaths: + return Path(config.output_filepaths[context.name]) return config.destination_filepath_object @staticmethod diff --git a/dagster/src/sensors/school_geolocation.py b/dagster/src/sensors/school_geolocation.py index 7d4992921..51dfe784e 100644 --- a/dagster/src/sensors/school_geolocation.py +++ b/dagster/src/sensors/school_geolocation.py @@ -80,19 +80,11 @@ def school_master_geolocation__raw_file_uploads_sensor( ), "geolocation_data_quality_results_human_readable": OpDestinationMapping( source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - metastore_schema=METASTORE_SCHEMA, - tier=DataTier.DATA_QUALITY_CHECKS, - ), - "geolocation_dq_schools_passed_human_readable": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows-human-readable/{country_code}/{stem}.csv", - metastore_schema=METASTORE_SCHEMA, - tier=DataTier.DATA_QUALITY_CHECKS, - ), - "geolocation_dq_schools_failed_human_readable": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-human-readable/{country_code}/{stem}.parquet", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows-human-readable/{country_code}/{stem}.csv", + destination_filepath="", + output_filepaths={ + "geolocation_dq_schools_passed_human_readable": f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows-human-readable/{country_code}/{stem}.csv", + "geolocation_dq_schools_failed_human_readable": f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows-human-readable/{country_code}/{stem}.csv", + }, metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), @@ -109,19 +101,19 @@ def school_master_geolocation__raw_file_uploads_sensor( tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_dq_passed_rows": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.csv", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", + destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.parquet", metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_dq_failed_rows": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.csv", - destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-overall/{country_code}/{stem}.parquet", + destination_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-failed-rows/{country_code}/{stem}.parquet", metastore_schema=METASTORE_SCHEMA, tier=DataTier.DATA_QUALITY_CHECKS, ), "geolocation_staging": OpDestinationMapping( - source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.csv", + source_filepath=f"{constants.dq_results_folder}/{DOMAIN_DATASET_TYPE}/dq-passed-rows/{country_code}/{stem}.parquet", destination_filepath=f"{settings.SPARK_WAREHOUSE_PATH}/school_geolocation_staging.db/{country_code.lower()}", metastore_schema=METASTORE_SCHEMA, tier=DataTier.STAGING, diff --git a/dagster/src/utils/op_config.py b/dagster/src/utils/op_config.py index 7ed3dce04..dc287d4a4 100644 --- a/dagster/src/utils/op_config.py +++ b/dagster/src/utils/op_config.py @@ -51,6 +51,15 @@ class FileConfig(Config): For regular assets, simply pass in the destination path as a string. """, ) + output_filepaths: dict[str, str] = Field( + default_factory=dict, + description=""" + Per-output destination paths for multi-output assets (@multi_asset). + Keys are output names; values are ADLS-relative destination paths. + When non-empty, the IO manager uses the output name to select the path + instead of destination_filepath. + """, + ) dq_target_filepath: str = Field( description=""" The path of the file inside the ADLS container where we run data quality checks on. @@ -137,6 +146,7 @@ class OpDestinationMapping(BaseModel): metastore_schema: str tier: DataTier table_name: Optional[str] = None + output_filepaths: dict[str, str] = Field(default_factory=dict) def generate_run_ops( @@ -155,6 +165,7 @@ def generate_run_ops( file_config = FileConfig( filepath=op_mapping.source_filepath, destination_filepath=op_mapping.destination_filepath, + output_filepaths=op_mapping.output_filepaths, metastore_schema=op_mapping.metastore_schema, tier=op_mapping.tier, table_name=op_mapping.table_name, From 7869c37a7c2927b60ff332b8761003f1434416da Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Mon, 23 Mar 2026 16:22:18 +0530 Subject: [PATCH 07/26] feat: Deletion and renaming of columns with upgrade of delta sharing server --- dagster/src/assets/adhoc/master_csv_to_gold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagster/src/assets/adhoc/master_csv_to_gold.py b/dagster/src/assets/adhoc/master_csv_to_gold.py index 5f07d03d8..da5aeef25 100644 --- a/dagster/src/assets/adhoc/master_csv_to_gold.py +++ b/dagster/src/assets/adhoc/master_csv_to_gold.py @@ -14,6 +14,9 @@ ) from pyspark.sql.types import NullType, StructType from sqlalchemy import select, update + +from azure.core.exceptions import ResourceNotFoundError +from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset from src.constants import DataTier from src.data_quality_checks.utils import ( aggregate_report_json, @@ -56,9 +59,6 @@ from src.utils.sentry import capture_op_exceptions from src.utils.spark import compute_row_hash, transform_types -from azure.core.exceptions import ResourceNotFoundError -from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset - @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @capture_op_exceptions From c90589cde0ae8abed298a2d38bc4bfb38bd6713a Mon Sep 17 00:00:00 2001 From: sharky93 Date: Mon, 6 Apr 2026 12:29:43 +0530 Subject: [PATCH 08/26] update the setup to upgrade the Hive metastore --- hive/hms-entrypoint.sh | 4 +++- hive/prod.Dockerfile | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/hive/hms-entrypoint.sh b/hive/hms-entrypoint.sh index 0b9bb7c7d..d4f6fe157 100755 --- a/hive/hms-entrypoint.sh +++ b/hive/hms-entrypoint.sh @@ -36,7 +36,9 @@ export METASTORE_PORT=${METASTORE_PORT:-9083} if $HIVE_HOME/bin/schematool --verbose -dbType postgres -validate | grep 'Done with metastore validation' | grep '[SUCCESS]'; then echo 'Database OK' else - $HIVE_HOME/bin/schematool --verbose -dbType postgres -initSchema + # $HIVE_HOME/bin/schematool --verbose -dbType postgres -initSchema + $HIVE_HOME/bin/schematool --verbose -dbType postgres -upgradeSchema || true + $HIVE_HOME/bin/schematool --verbose -dbType postgres -validate fi exec $HIVE_HOME/bin/hive --skiphadoopversion --skiphbasecp --service $SERVICE_NAME diff --git a/hive/prod.Dockerfile b/hive/prod.Dockerfile index bada84afb..9e6ffac16 100644 --- a/hive/prod.Dockerfile +++ b/hive/prod.Dockerfile @@ -1,4 +1,4 @@ -FROM apache/hive:3.1.3 +FROM apache/hive:4.0.0 USER root @@ -10,7 +10,8 @@ WORKDIR /opt/hive/lib RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ https://repo1.maven.org/maven2/com/microsoft/azure/azure-storage/8.6.6/azure-storage-8.6.6.jar \ - https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar + https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ + https://repo1.maven.org/maven2/org/postgresql/postgresql/42.7.3/postgresql-42.7.3.jar COPY ./hms-entrypoint.sh /opt/hive/bin/hms-entrypoint.sh COPY ./metastore-site.template.xml /opt/hive/tpl/metastore-site.template.xml From adebbfec81c06614c490f5f7d01c94ea854ca953 Mon Sep 17 00:00:00 2001 From: Rishabh Raj Date: Mon, 6 Apr 2026 18:08:57 +0530 Subject: [PATCH 09/26] Update metastore-site.template.xml --- hive/metastore-site.template.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hive/metastore-site.template.xml b/hive/metastore-site.template.xml index b6d0054be..24a7fbe50 100644 --- a/hive/metastore-site.template.xml +++ b/hive/metastore-site.template.xml @@ -6,6 +6,10 @@ Thrift URI for the remote metastore. Used by metastore client to connect to remote metastore. + + fs.azure.account.key.{ENV:STORAGE_ACCOUNT_NAME}.dfs.core.windows.net + {ENV:AZURE_STORAGE_ACCOUNT_KEY} + metastore.warehouse.dir {ENV:METASTORE_WAREHOUSE_DIR} From 398a987e35d03dca82173fba409b09a9f7fc2a77 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 8 Apr 2026 09:49:50 +0200 Subject: [PATCH 10/26] feat: migrate from wasbs to abfss and vectorize admin data functions (#446) * fix: improve processing time for large files (#445) * feat: migrate from wasbs to abfss * fix: include wasbs spark config for reading old tables * chore: set up FixedSAS class locally * fix: removing leading ? for token * chore: add timing logs * chore/add logging within bronze asset * fix: vectorize the admin column addition step * fix: change the asset key source for multi assets --- dagster/Dockerfile | 16 ++- dagster/java/FixedSASTokenProvider.java | 34 +++++++ dagster/prod.Dockerfile | 16 ++- .../migrate_hms_table_locations_to_abfss.py | 99 +++++++++++++++++++ dagster/src/jobs/adhoc.py | 6 ++ dagster/src/settings.py | 4 +- dagster/src/spark/transform_functions.py | 79 +++++++-------- dagster/src/utils/sentry.py | 2 +- dagster/src/utils/spark.py | 19 +++- spark/Dockerfile | 18 +++- spark/java/FixedSASTokenProvider.java | 34 +++++++ spark/prod.Dockerfile | 18 +++- 12 files changed, 286 insertions(+), 59 deletions(-) create mode 100644 dagster/java/FixedSASTokenProvider.java create mode 100644 dagster/src/assets/adhoc/migrate_hms_table_locations_to_abfss.py create mode 100644 spark/java/FixedSASTokenProvider.java diff --git a/dagster/Dockerfile b/dagster/Dockerfile index d95a5ab75..ce995776e 100644 --- a/dagster/Dockerfile +++ b/dagster/Dockerfile @@ -23,13 +23,25 @@ RUN apt-get update && \ # Move to Spark JARs directory and download additional dependencies WORKDIR /opt/spark/jars -RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ +RUN wget --tries=3 \ + https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ https://repo1.maven.org/maven2/com/microsoft/azure/azure-storage/8.6.6/azure-storage-8.6.6.jar \ https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util/9.4.51.v20230217/jetty-util-9.4.51.v20230217.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util-ajax/11.0.14/jetty-util-ajax-11.0.14.jar \ https://repo1.maven.org/maven2/io/delta/delta-spark_2.12/3.0.0/delta-spark_2.12-3.0.0.jar \ - https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar + https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar \ + && unzip -t hadoop-azure-3.3.4.jar > /dev/null + +# Compile FixedSASTokenProvider (not included in the standard hadoop-azure JAR) +COPY java/FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java +RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ + cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ + javac -cp "/opt/spark/jars/*" \ + -d /tmp/sas/classes \ + /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ + jar cf /opt/spark/jars/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ + rm -rf /tmp/sas # Install poetry and setuptools (required for building some packages) # Set virtualenv to be in-project for easier volume mounting diff --git a/dagster/java/FixedSASTokenProvider.java b/dagster/java/FixedSASTokenProvider.java new file mode 100644 index 000000000..9c4598b11 --- /dev/null +++ b/dagster/java/FixedSASTokenProvider.java @@ -0,0 +1,34 @@ +package org.apache.hadoop.fs.azurebfs.sas; + +import java.io.IOException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider; + +/** + * A SASTokenProvider implementation that returns a fixed SAS token read from + * the Hadoop configuration key: + * fs.azure.sas.fixed.token. + * + * This class is not included in the standard Apache hadoop-azure JAR and must + * be compiled and packaged separately. See docs/abfss-migration.md. + */ +public class FixedSASTokenProvider implements SASTokenProvider { + + private String sasToken; + + @Override + public void initialize(Configuration conf, String accountName) throws IOException { + sasToken = conf.get("fs.azure.sas.fixed.token." + accountName); + if (sasToken == null || sasToken.isEmpty()) { + throw new IOException( + "No SAS token configured for account: " + accountName + + ". Set fs.azure.sas.fixed.token." + accountName + ); + } + } + + @Override + public String getSASToken(String account, String fileSystem, String path, String operation) { + return sasToken; + } +} diff --git a/dagster/prod.Dockerfile b/dagster/prod.Dockerfile index 4e73e60ac..763aebd8f 100644 --- a/dagster/prod.Dockerfile +++ b/dagster/prod.Dockerfile @@ -23,13 +23,25 @@ RUN apt-get update && \ # Move to Spark JARs directory and download additional dependencies WORKDIR /opt/spark/jars -RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ +RUN wget --tries=3 \ + https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ https://repo1.maven.org/maven2/com/microsoft/azure/azure-storage/8.6.6/azure-storage-8.6.6.jar \ https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util/9.4.51.v20230217/jetty-util-9.4.51.v20230217.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util-ajax/11.0.14/jetty-util-ajax-11.0.14.jar \ https://repo1.maven.org/maven2/io/delta/delta-spark_2.12/3.0.0/delta-spark_2.12-3.0.0.jar \ - https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar + https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar \ + && unzip -t hadoop-azure-3.3.4.jar > /dev/null + +# Compile FixedSASTokenProvider (not included in the standard hadoop-azure JAR) +COPY java/FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java +RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ + cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ + javac -cp "/opt/spark/jars/*" \ + -d /tmp/sas/classes \ + /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ + jar cf /opt/spark/jars/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ + rm -rf /tmp/sas FROM base AS deps diff --git a/dagster/src/assets/adhoc/migrate_hms_table_locations_to_abfss.py b/dagster/src/assets/adhoc/migrate_hms_table_locations_to_abfss.py new file mode 100644 index 000000000..286506c93 --- /dev/null +++ b/dagster/src/assets/adhoc/migrate_hms_table_locations_to_abfss.py @@ -0,0 +1,99 @@ +from dagster_pyspark import PySparkResource +from pyspark.sql import SparkSession +from src.settings import settings +from src.utils.sentry import capture_op_exceptions + +from dagster import OpExecutionContext, Output, asset + + +@asset +@capture_op_exceptions +def adhoc__migrate_hms_table_locations_to_abfss( + context: OpExecutionContext, + spark: PySparkResource, +) -> Output[None]: + """ + One-shot migration asset that updates all Hive Metastore table locations + from wasbs:// (Azure Blob Storage) to abfss:// (Azure Data Lake Storage Gen2). + + Run this once after deploying the abfss:// storage migration to update + existing Delta table locations registered in the metastore. Once all + locations are migrated, the WASBS SAS credential in spark.py can be removed. + """ + s: SparkSession = spark.spark_session + + old_prefix = ( + f"wasbs://{settings.AZURE_BLOB_CONTAINER_NAME}@{settings.AZURE_BLOB_SAS_HOST}" + ) + new_prefix = ( + f"abfss://{settings.AZURE_BLOB_CONTAINER_NAME}@{settings.AZURE_DFS_SAS_HOST}" + ) + + context.log.info(f"Migrating HMS locations: {old_prefix} -> {new_prefix}") + + databases = [row.namespace for row in s.sql("SHOW DATABASES").collect()] + context.log.info(f"Found {len(databases)} databases: {databases}") + + migrated = [] + skipped = [] + errors = [] + + for db in databases: + tables = s.sql(f"SHOW TABLES IN `{db}`").collect() + context.log.info(f"Database `{db}`: {len(tables)} tables") + + for table_row in tables: + table_name = table_row.tableName + full_name = f"`{db}`.`{table_name}`" + + try: + detail = s.sql(f"DESCRIBE DETAIL {full_name}").collect() + location = detail[0]["location"] if detail else None + except Exception as e: + context.log.warning(f"Could not DESCRIBE DETAIL {full_name}: {e}") + errors.append({"table": full_name, "error": str(e)}) + continue + + if not location: + context.log.info(f" {full_name}: no location, skipping") + skipped.append(full_name) + continue + + if not location.startswith(old_prefix): + context.log.info( + f" {full_name}: location already uses correct scheme, skipping ({location[:60]}...)" + ) + skipped.append(full_name) + continue + + new_location = location.replace(old_prefix, new_prefix, 1) + context.log.info(f" {full_name}: updating location") + context.log.info(f" old: {location}") + context.log.info(f" new: {new_location}") + + try: + s.sql(f"ALTER TABLE {full_name} SET LOCATION '{new_location}'") + migrated.append( + {"table": full_name, "old": location, "new": new_location} + ) + except Exception as e: + context.log.error(f" {full_name}: ALTER TABLE failed: {e}") + errors.append({"table": full_name, "error": str(e)}) + + context.log.info( + f"Migration complete: {len(migrated)} migrated, {len(skipped)} skipped, {len(errors)} errors" + ) + + if errors: + context.log.error(f"Tables with errors: {[e['table'] for e in errors]}") + + return Output( + None, + metadata={ + "migrated_count": len(migrated), + "skipped_count": len(skipped), + "error_count": len(errors), + "migrated_tables": [m["table"] for m in migrated], + "error_tables": [e["table"] for e in errors], + }, + ) diff --git a/dagster/src/jobs/adhoc.py b/dagster/src/jobs/adhoc.py index 77eabda70..aecf7b7c8 100644 --- a/dagster/src/jobs/adhoc.py +++ b/dagster/src/jobs/adhoc.py @@ -61,3 +61,9 @@ ], tags={"dagster/max_runtime": settings.DEFAULT_MAX_RUNTIME}, ) + +hms_migrate_table_locations_to_abfss_job = define_asset_job( + name="hms_migrate_table_locations_to_abfss_job", + selection="adhoc__migrate_hms_table_locations_to_abfss", + tags={"dagster/max_runtime": settings.DEFAULT_MAX_RUNTIME}, +) diff --git a/dagster/src/settings.py b/dagster/src/settings.py index d9309a678..184210b96 100644 --- a/dagster/src/settings.py +++ b/dagster/src/settings.py @@ -132,9 +132,9 @@ def AZURE_DFS_SAS_HOST(self) -> str: @property def AZURE_BLOB_CONNECTION_URI(self) -> str: if self.USE_AZURITE: - # Use wasb:// (HTTP) with Azurite - wasbs:// (HTTPS) doesn't work + # Use wasb:// (HTTP) with Azurite - abfss:// is not supported by Azurite return f"wasb://{self.AZURE_BLOB_CONTAINER_NAME}@{self.AZURE_BLOB_SAS_HOST}" - return f"wasbs://{self.AZURE_BLOB_CONTAINER_NAME}@{self.AZURE_BLOB_SAS_HOST}" + return f"abfss://{self.AZURE_BLOB_CONTAINER_NAME}@{self.AZURE_DFS_SAS_HOST}" @property def AZURE_STORAGE_CONNECTION_STRING(self) -> str: diff --git a/dagster/src/spark/transform_functions.py b/dagster/src/spark/transform_functions.py index c21750439..8d3ebe410 100644 --- a/dagster/src/spark/transform_functions.py +++ b/dagster/src/spark/transform_functions.py @@ -13,6 +13,7 @@ SparkSession, functions as f, ) +from pyspark.sql.functions import pandas_udf from pyspark.sql.types import ( FloatType, StringType, @@ -562,7 +563,7 @@ def get_admin_boundaries( return None -def add_admin_columns( # noqa: C901 +def add_admin_columns( df: sql.DataFrame, country_code_iso3: str, admin_level: str, @@ -583,56 +584,50 @@ def add_admin_columns( # noqa: C901 spark = df.sparkSession broadcasted_admin_boundaries = spark.sparkContext.broadcast(admin_boundaries) - def get_admin_en(latitude, longitude) -> str | None: - point = get_point(longitude=longitude, latitude=latitude) - for _, row in broadcasted_admin_boundaries.value.iterrows(): - if row.geometry.contains(point): - return row.get("name_en") - return None - - get_admin_en_udf = f.udf(get_admin_en, StringType()) - - def get_admin_native(latitude, longitude) -> str | None: - point = get_point(longitude=longitude, latitude=latitude) - for _, row in broadcasted_admin_boundaries.value.iterrows(): - if row.geometry.contains(point): - return row.get("name") - return None - - get_admin_native_udf = f.udf(get_admin_native, StringType()) + result_schema = StructType( + [ + StructField("name_en", StringType()), + StructField("name_native", StringType()), + StructField("id_giga", StringType()), + ] + ) - def get_admin_id_giga(latitude, longitude) -> str | None: - point = get_point(longitude=longitude, latitude=latitude) - for _, row in broadcasted_admin_boundaries.value.iterrows(): - if row.geometry.contains(point): - return row.get(f"{admin_level}_id_giga") - return None + @pandas_udf(result_schema) + def get_admin_info(latitude: pd.Series, longitude: pd.Series) -> pd.DataFrame: + boundaries = broadcasted_admin_boundaries.value + gdf_points = gpd.GeoDataFrame( + {"orig_index": range(len(latitude))}, + geometry=gpd.points_from_xy(longitude, latitude), + crs="epsg:4326", + ) + joined = gpd.sjoin(gdf_points, boundaries, how="left", predicate="within") + # sjoin may produce duplicates for points on shared boundaries — keep first match + joined = joined[~joined.index.duplicated(keep="first")] + return pd.DataFrame( + { + "name_en": joined.get("name_en"), + "name_native": joined.get("name"), + "id_giga": joined.get(f"{admin_level}_id_giga"), + } + ) - get_admin_id_giga_udf = f.udf(get_admin_id_giga, StringType()) + result_col = f"_admin_info_{admin_level}" + df = df.withColumn(result_col, get_admin_info(df["latitude"], df["longitude"])) - df = df.withColumns( - { - f"{admin_level}_en": get_admin_en_udf(df["latitude"], df["longitude"]), - f"{admin_level}_native": get_admin_native_udf( - df["latitude"], df["longitude"] - ), - f"{admin_level}_id_giga": get_admin_id_giga_udf( - df["latitude"], df["longitude"] - ), - } - ) coalesce_args = [ - f.col(f"{admin_level}_en"), - f.col(f"{admin_level}_native"), + f.col(f"{result_col}.name_en"), + f.col(f"{result_col}.name_native"), ] if admin_level in df.columns: coalesce_args.append(f.col(admin_level)) coalesce_args.append(f.lit("Unknown")) - return df.withColumn( - admin_level, - f.coalesce(*coalesce_args), - ).drop(f"{admin_level}_en", f"{admin_level}_native") + return df.withColumns( + { + admin_level: f.coalesce(*coalesce_args), + f"{admin_level}_id_giga": f.col(f"{result_col}.id_giga"), + } + ).drop(result_col) def add_disputed_region_column(df: sql.DataFrame) -> sql.DataFrame: diff --git a/dagster/src/utils/sentry.py b/dagster/src/utils/sentry.py index 7d5aa85b4..d3402c9f6 100644 --- a/dagster/src/utils/sentry.py +++ b/dagster/src/utils/sentry.py @@ -51,7 +51,7 @@ def log_op_context(context: OpExecutionContext) -> None: "run_id": context.run_id, "run_tags": context.run_tags, "retry_number": context.retry_number, - "asset_key": context.asset_key, + "asset_key": context.asset_keys_for_node, }, ) diff --git a/dagster/src/utils/spark.py b/dagster/src/utils/spark.py index 8c37f7caa..16788ec97 100644 --- a/dagster/src/utils/spark.py +++ b/dagster/src/utils/spark.py @@ -73,10 +73,21 @@ def _get_host_ip() -> str: settings.AZURE_STORAGE_ACCOUNT_NAME ) else: - # SAS token authentication for Azure cloud - spark_common_config[ - f"spark.hadoop.fs.azure.sas.{settings.AZURE_BLOB_CONTAINER_NAME}.{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" - ] = settings.AZURE_SAS_TOKEN + spark_common_config.update( + { + # ABFS driver credentials (abfss://) — used for all new reads/writes + f"spark.hadoop.fs.azure.account.auth.type.{settings.AZURE_STORAGE_ACCOUNT_NAME}.dfs.core.windows.net": "SAS", + f"spark.hadoop.fs.azure.sas.token.provider.type.{settings.AZURE_STORAGE_ACCOUNT_NAME}.dfs.core.windows.net": "org.apache.hadoop.fs.azurebfs.sas.FixedSASTokenProvider", + f"spark.hadoop.fs.azure.sas.fixed.token.{settings.AZURE_STORAGE_ACCOUNT_NAME}.dfs.core.windows.net": settings.AZURE_SAS_TOKEN.lstrip( + "?" + ), + # WASBS driver credentials (wasbs://) — retained so that existing Delta + # tables whose locations were registered in the Hive Metastore as + # wasbs:// paths remain accessible. To be removed once all table locations + # in the metastore have been migrated to abfss://. + f"spark.hadoop.fs.azure.sas.{settings.AZURE_BLOB_CONTAINER_NAME}.{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net": settings.AZURE_SAS_TOKEN, + } + ) if settings.IN_PRODUCTION: spark_common_config.update( diff --git a/spark/Dockerfile b/spark/Dockerfile index f2f7865d0..0cb84dc1c 100644 --- a/spark/Dockerfile +++ b/spark/Dockerfile @@ -17,7 +17,7 @@ USER root WORKDIR /tmp RUN apt-get update && \ - apt-get install -y curl wget gdal-bin libgdal-dev libgeos-dev g++ && \ + apt-get install -y curl wget gdal-bin libgdal-dev libgeos-dev g++ unzip default-jdk && \ apt-get clean COPY --from=deps /tmp/requirements.txt /tmp/requirements.txt @@ -26,13 +26,25 @@ RUN pip install --default-timeout=1000 --no-cache-dir -r requirements.txt WORKDIR /opt/bitnami/spark/jars -RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ +RUN wget --tries=3 \ + https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ https://repo1.maven.org/maven2/com/microsoft/azure/azure-storage/8.6.6/azure-storage-8.6.6.jar \ https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util/9.4.51.v20230217/jetty-util-9.4.51.v20230217.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util-ajax/11.0.14/jetty-util-ajax-11.0.14.jar \ https://repo1.maven.org/maven2/io/delta/delta-spark_2.12/3.0.0/delta-spark_2.12-3.0.0.jar \ - https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar + https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar \ + && unzip -t hadoop-azure-3.3.4.jar > /dev/null + +# Compile FixedSASTokenProvider (not included in the standard hadoop-azure JAR) +COPY spark/java/FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java +RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ + cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ + javac -cp "/opt/bitnami/spark/jars/*" \ + -d /tmp/sas/classes \ + /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ + jar cf /opt/bitnami/spark/jars/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ + rm -rf /tmp/sas USER 1001 diff --git a/spark/java/FixedSASTokenProvider.java b/spark/java/FixedSASTokenProvider.java new file mode 100644 index 000000000..9c4598b11 --- /dev/null +++ b/spark/java/FixedSASTokenProvider.java @@ -0,0 +1,34 @@ +package org.apache.hadoop.fs.azurebfs.sas; + +import java.io.IOException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider; + +/** + * A SASTokenProvider implementation that returns a fixed SAS token read from + * the Hadoop configuration key: + * fs.azure.sas.fixed.token. + * + * This class is not included in the standard Apache hadoop-azure JAR and must + * be compiled and packaged separately. See docs/abfss-migration.md. + */ +public class FixedSASTokenProvider implements SASTokenProvider { + + private String sasToken; + + @Override + public void initialize(Configuration conf, String accountName) throws IOException { + sasToken = conf.get("fs.azure.sas.fixed.token." + accountName); + if (sasToken == null || sasToken.isEmpty()) { + throw new IOException( + "No SAS token configured for account: " + accountName + + ". Set fs.azure.sas.fixed.token." + accountName + ); + } + } + + @Override + public String getSASToken(String account, String fileSystem, String path, String operation) { + return sasToken; + } +} diff --git a/spark/prod.Dockerfile b/spark/prod.Dockerfile index 36eade700..4bc8cecf3 100644 --- a/spark/prod.Dockerfile +++ b/spark/prod.Dockerfile @@ -18,7 +18,7 @@ USER root WORKDIR /tmp RUN apt-get update && \ - apt-get install -y curl wget gdal-bin libgdal-dev libgeos-dev g++ && \ + apt-get install -y curl wget gdal-bin libgdal-dev libgeos-dev g++ unzip default-jdk && \ apt-get clean COPY --from=deps /tmp/requirements.txt /tmp/requirements.txt @@ -27,13 +27,25 @@ RUN pip install --no-cache-dir -r requirements.txt WORKDIR /opt/bitnami/spark/jars -RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ +RUN wget --tries=3 \ + https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/hadoop-azure-3.3.4.jar \ https://repo1.maven.org/maven2/com/microsoft/azure/azure-storage/8.6.6/azure-storage-8.6.6.jar \ https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util/9.4.51.v20230217/jetty-util-9.4.51.v20230217.jar \ https://repo1.maven.org/maven2/org/eclipse/jetty/jetty-util-ajax/11.0.14/jetty-util-ajax-11.0.14.jar \ https://repo1.maven.org/maven2/io/delta/delta-spark_2.12/3.0.0/delta-spark_2.12-3.0.0.jar \ - https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar + https://repo1.maven.org/maven2/io/delta/delta-storage/3.0.0/delta-storage-3.0.0.jar \ + && unzip -t hadoop-azure-3.3.4.jar > /dev/null + +# Compile FixedSASTokenProvider (not included in the standard hadoop-azure JAR) +COPY spark/java/FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java +RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ + cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ + javac -cp "/opt/bitnami/spark/jars/*" \ + -d /tmp/sas/classes \ + /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ + jar cf /opt/bitnami/spark/jars/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ + rm -rf /tmp/sas USER 1001 From 7e2abbe2e2e0df237aed88521dc3403d6883d504 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 8 Apr 2026 13:44:08 +0200 Subject: [PATCH 11/26] fix: update asset key check for multi asets (#447) --- dagster/src/utils/sentry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagster/src/utils/sentry.py b/dagster/src/utils/sentry.py index d3402c9f6..c88380dd7 100644 --- a/dagster/src/utils/sentry.py +++ b/dagster/src/utils/sentry.py @@ -51,7 +51,7 @@ def log_op_context(context: OpExecutionContext) -> None: "run_id": context.run_id, "run_tags": context.run_tags, "retry_number": context.retry_number, - "asset_key": context.asset_keys_for_node, + "asset_key": getattr(context, "asset_keys", None), }, ) From a9da8cbaaa584952addc8b08e56e7d218114c062 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 8 Apr 2026 17:47:57 +0300 Subject: [PATCH 12/26] fix: make updates to geolocation_staging idempotent --- dagster/src/internal/common_assets/staging.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dagster/src/internal/common_assets/staging.py b/dagster/src/internal/common_assets/staging.py index 130b76df3..d1cfcad57 100644 --- a/dagster/src/internal/common_assets/staging.py +++ b/dagster/src/internal/common_assets/staging.py @@ -311,6 +311,14 @@ def _write_pending_records(self, pending: sql.DataFrame) -> None: replace=True, partition_by=["upload_id"], ) + else: + upload_id = self.config.filename_components.id + DeltaTable.forName(self.spark, self.staging_table_name).delete( + f.col("upload_id") == upload_id + ) + self.context.log.info( + f"Deleted existing pending_changes rows for upload_id={upload_id}" + ) # Cast pending columns to the expected schema types before writing schema_type_map = {field.name: field.dataType for field in pending_schema} From fca9e95628255c08807439b1c0b1497df48e91de Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Wed, 8 Apr 2026 19:02:37 +0300 Subject: [PATCH 13/26] feat: add SAS token support for hive with abfss --- hive/metastore-site.template.xml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/hive/metastore-site.template.xml b/hive/metastore-site.template.xml index 24a7fbe50..477c3b592 100644 --- a/hive/metastore-site.template.xml +++ b/hive/metastore-site.template.xml @@ -7,8 +7,16 @@ - fs.azure.account.key.{ENV:STORAGE_ACCOUNT_NAME}.dfs.core.windows.net - {ENV:AZURE_STORAGE_ACCOUNT_KEY} + fs.azure.account.auth.type.{ENV:STORAGE_ACCOUNT_NAME}.dfs.core.windows.net + SAS + + + fs.azure.sas.token.provider.type.{ENV:STORAGE_ACCOUNT_NAME}.dfs.core.windows.net + org.apache.hadoop.fs.azurebfs.sas.FixedSASTokenProvider + + + fs.azure.sas.fixed.token.{ENV:STORAGE_ACCOUNT_NAME}.dfs.core.windows.net + {ENV:AZURE_SAS_TOKEN} metastore.warehouse.dir From 6e110b732b4845b591d1ec3f4f44868a3dac2ba9 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Thu, 9 Apr 2026 15:14:35 +0300 Subject: [PATCH 14/26] fix: add FixedSASTokenProvider JAR to hive --- hive/FixedSASTokenProvider.java | 34 +++++++++++++++++++++++++++++++++ hive/prod.Dockerfile | 12 +++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 hive/FixedSASTokenProvider.java diff --git a/hive/FixedSASTokenProvider.java b/hive/FixedSASTokenProvider.java new file mode 100644 index 000000000..9c4598b11 --- /dev/null +++ b/hive/FixedSASTokenProvider.java @@ -0,0 +1,34 @@ +package org.apache.hadoop.fs.azurebfs.sas; + +import java.io.IOException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider; + +/** + * A SASTokenProvider implementation that returns a fixed SAS token read from + * the Hadoop configuration key: + * fs.azure.sas.fixed.token. + * + * This class is not included in the standard Apache hadoop-azure JAR and must + * be compiled and packaged separately. See docs/abfss-migration.md. + */ +public class FixedSASTokenProvider implements SASTokenProvider { + + private String sasToken; + + @Override + public void initialize(Configuration conf, String accountName) throws IOException { + sasToken = conf.get("fs.azure.sas.fixed.token." + accountName); + if (sasToken == null || sasToken.isEmpty()) { + throw new IOException( + "No SAS token configured for account: " + accountName + + ". Set fs.azure.sas.fixed.token." + accountName + ); + } + } + + @Override + public String getSASToken(String account, String fileSystem, String path, String operation) { + return sasToken; + } +} diff --git a/hive/prod.Dockerfile b/hive/prod.Dockerfile index 9e6ffac16..c450c7d5e 100644 --- a/hive/prod.Dockerfile +++ b/hive/prod.Dockerfile @@ -3,7 +3,7 @@ FROM apache/hive:4.0.0 USER root RUN apt-get update && \ - apt-get install -y curl wget && \ + apt-get install -y curl wget default-jdk-headless && \ apt-get clean WORKDIR /opt/hive/lib @@ -13,6 +13,16 @@ RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/had https://repo1.maven.org/maven2/com/azure/azure-storage-blob/12.24.0/azure-storage-blob-12.24.0.jar \ https://repo1.maven.org/maven2/org/postgresql/postgresql/42.7.3/postgresql-42.7.3.jar +# Compile FixedSASTokenProvider (not included in the standard hadoop-azure JAR) +COPY FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java +RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ + cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ + javac -cp "/opt/hive/lib/*:/opt/hadoop/share/hadoop/common/*:/opt/hadoop/share/hadoop/common/lib/*" \ + -d /tmp/sas/classes \ + /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ + jar cf /opt/hive/lib/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ + rm -rf /tmp/sas + COPY ./hms-entrypoint.sh /opt/hive/bin/hms-entrypoint.sh COPY ./metastore-site.template.xml /opt/hive/tpl/metastore-site.template.xml COPY ./hive-site.template.xml /opt/hive/tpl/hive-site.template.xml From 00ee468f4a0a976f3d9193d760ba978031e40717 Mon Sep 17 00:00:00 2001 From: Brian Musisi Date: Thu, 9 Apr 2026 16:24:34 +0300 Subject: [PATCH 15/26] fix: use correct version for jar build --- hive/prod.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hive/prod.Dockerfile b/hive/prod.Dockerfile index c450c7d5e..63c9ae5aa 100644 --- a/hive/prod.Dockerfile +++ b/hive/prod.Dockerfile @@ -17,7 +17,7 @@ RUN wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-azure/3.3.4/had COPY FixedSASTokenProvider.java /tmp/sas/FixedSASTokenProvider.java RUN mkdir -p /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas /tmp/sas/classes && \ cp /tmp/sas/FixedSASTokenProvider.java /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/ && \ - javac -cp "/opt/hive/lib/*:/opt/hadoop/share/hadoop/common/*:/opt/hadoop/share/hadoop/common/lib/*" \ + javac -source 8 -target 8 -cp "/opt/hive/lib/*:/opt/hadoop/share/hadoop/common/*:/opt/hadoop/share/hadoop/common/lib/*" \ -d /tmp/sas/classes \ /tmp/sas/src/org/apache/hadoop/fs/azurebfs/sas/FixedSASTokenProvider.java && \ jar cf /opt/hive/lib/fixed-sas-token-provider.jar -C /tmp/sas/classes . && \ From 9d6def7e5b4d2e2787f058ccc540ba31daf25244 Mon Sep 17 00:00:00 2001 From: sharky93 Date: Mon, 13 Apr 2026 16:16:14 +0530 Subject: [PATCH 16/26] update timeout for creating the physical table --- dagster/src/resources/superset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagster/src/resources/superset.py b/dagster/src/resources/superset.py index 6107c5c93..34688645b 100644 --- a/dagster/src/resources/superset.py +++ b/dagster/src/resources/superset.py @@ -75,7 +75,7 @@ def fetch_saved_query(): def run_query(query, access_token): - timeout = 600 + timeout = 1800 max_retries = 3 headers = { "Authorization": f"Bearer {access_token}", From e731264279246adb4cec78b5be9d1d676acaf344 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Tue, 14 Apr 2026 11:52:12 +0530 Subject: [PATCH 17/26] fix: merge main branch --- dagster/src/assets/adhoc/master_csv_to_gold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagster/src/assets/adhoc/master_csv_to_gold.py b/dagster/src/assets/adhoc/master_csv_to_gold.py index da5aeef25..5f07d03d8 100644 --- a/dagster/src/assets/adhoc/master_csv_to_gold.py +++ b/dagster/src/assets/adhoc/master_csv_to_gold.py @@ -14,9 +14,6 @@ ) from pyspark.sql.types import NullType, StructType from sqlalchemy import select, update - -from azure.core.exceptions import ResourceNotFoundError -from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset from src.constants import DataTier from src.data_quality_checks.utils import ( aggregate_report_json, @@ -59,6 +56,9 @@ from src.utils.sentry import capture_op_exceptions from src.utils.spark import compute_row_hash, transform_types +from azure.core.exceptions import ResourceNotFoundError +from dagster import OpExecutionContext, Output, PythonObjectDagsterType, asset + @asset(io_manager_key=ResourceKey.ADLS_PASSTHROUGH_IO_MANAGER.value) @capture_op_exceptions From 14331d8e58618cf2d6568314728a47f366c3ae72 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Wed, 15 Apr 2026 13:24:19 +0530 Subject: [PATCH 18/26] fix: schema evolution changes --- dagster/src/assets/migrations/core.py | 24 +++++-- dagster/src/internal/common_assets/staging.py | 20 +++++- .../src/resources/io_managers/adls_delta.py | 59 +++++++--------- dagster/src/spark/config_expectations.py | 2 + dagster/src/spark/transform_functions.py | 32 ++++++--- dagster/src/utils/schema.py | 67 +++++++++++++++++-- 6 files changed, 146 insertions(+), 58 deletions(-) diff --git a/dagster/src/assets/migrations/core.py b/dagster/src/assets/migrations/core.py index 06c18f153..732426040 100644 --- a/dagster/src/assets/migrations/core.py +++ b/dagster/src/assets/migrations/core.py @@ -4,7 +4,10 @@ from models import VALID_PRIMITIVES, Schema from pyspark import sql from pyspark.sql.functions import col, when -from src.utils.delta import execute_query_with_error_handler +from src.utils.delta import ( + create_delta_table, + sync_schema, +) from dagster import OpExecutionContext @@ -42,16 +45,27 @@ def save_schema_delta_table(context: OpExecutionContext, df: sql.DataFrame): full_table_name = f"{schema_name}.{table_name}" columns = Schema.fields - query = ( - DeltaTable.createOrReplace(spark).tableName(full_table_name).addColumns(columns) + create_delta_table( + spark, + schema_name, + table_name, + columns, + context, + if_not_exists=True, + ) + sync_schema( + table_name=full_table_name, + existing_schema=spark.table(full_table_name).schema, + updated_schema=Schema.schema, + spark=spark, + context=context, ) - execute_query_with_error_handler(spark, query, schema_name, table_name, context) spark.catalog.refreshTable(full_table_name) ( DeltaTable.forName(spark, full_table_name) .alias("master") - .merge(df.alias("updates"), "master.name = updates.name") + .merge(df.alias("updates"), "master.id = updates.id") .whenMatchedUpdateAll() .whenNotMatchedInsertAll() .execute() diff --git a/dagster/src/internal/common_assets/staging.py b/dagster/src/internal/common_assets/staging.py index d1cfcad57..5cd00c443 100644 --- a/dagster/src/internal/common_assets/staging.py +++ b/dagster/src/internal/common_assets/staging.py @@ -8,7 +8,13 @@ SparkSession, functions as f, ) -from pyspark.sql.types import ArrayType, StringType, StructField, TimestampType +from pyspark.sql.types import ( + ArrayType, + StringType, + StructField, + StructType, + TimestampType, +) from sqlalchemy import select, update from dagster import OpExecutionContext @@ -22,6 +28,7 @@ check_table_exists, create_delta_table, create_schema, + sync_schema, ) from src.utils.op_config import FileConfig from src.utils.schema import ( @@ -312,6 +319,17 @@ def _write_pending_records(self, pending: sql.DataFrame) -> None: partition_by=["upload_id"], ) else: + # Synchronise staging schema (handles renames/deletions) + existing_schema = self.spark.table(self.staging_table_name).schema + sync_schema( + table_name=self.staging_table_name, + existing_schema=existing_schema, + updated_schema=StructType(pending_schema), + spark=self.spark, + context=self.context, + schema_name=self.schema_name, + ) + upload_id = self.config.filename_components.id DeltaTable.forName(self.spark, self.staging_table_name).delete( f.col("upload_id") == upload_id diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index 35134d11b..0cb5fb829 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -12,7 +12,7 @@ from src.settings import settings from src.spark.transform_functions import add_missing_columns from src.utils.adls import ADLSFileClient -from src.utils.delta import build_deduped_merge_query, execute_query_with_error_handler +from src.utils.delta import build_deduped_merge_query from src.utils.op_config import FileConfig from src.utils.schema import ( construct_full_table_name, @@ -157,10 +157,23 @@ def _create_table_if_not_exists( "delta.logRetentionDuration", constants.school_master_retention_period ) + from src.utils.delta import ( + execute_query_with_error_handler, + persist_column_id_map, + ) + execute_query_with_error_handler( spark, query, schema_name_for_tier, table_name, context ) + if columns_schema_name not in [ + "qos", + "qos_raw", + "qos_availability", + "custom_dataset", + ]: + persist_column_id_map(spark, full_table_name, columns_schema_name) + def _upsert_data( self, data: sql.DataFrame, @@ -191,47 +204,23 @@ def _upsert_data( columns = incoming_schema.fields primary_key = "gigasync_id" else: - from src.utils.delta import apply_renames_and_deletes, persist_column_id_map + from src.utils.delta import sync_schema columns = get_schema_columns(spark, schema_name) primary_key = get_primary_key(spark, schema_name) updated_schema = StructType(columns) - updated_columns = sorted(updated_schema.fieldNames()) - - existing_df = DeltaTable.forName(spark, full_table_name).toDF() - existing_columns = sorted(existing_df.schema.fieldNames()) - - context.log.info(f"incoming schema {data.schema}") - context.log.info(f"existing schema {existing_df.schema}") - - any_renames_deletes = apply_renames_and_deletes( - spark, full_table_name, schema_name, context + existing_schema = DeltaTable.forName(spark, full_table_name).toDF().schema + + sync_schema( + table_name=full_table_name, + existing_schema=existing_schema, + updated_schema=updated_schema, + spark=spark, + context=context, + schema_name=schema_name, ) - # Refresh after rename/delete - if any_renames_deletes: - existing_df = DeltaTable.forName(spark, full_table_name).toDF() - existing_columns = sorted(existing_df.schema.fieldNames()) - - if updated_columns != existing_columns: - context.log.info("Updating schema...") - - empty_data = spark.sparkContext.emptyRDD() - updated_schema_df = spark.createDataFrame( - data=empty_data, schema=updated_schema - ) - - ( - updated_schema_df.write.option("mergeSchema", "true") - .format("delta") - .mode("append") - .saveAsTable(full_table_name) - ) - - # Persist column-ID mapping - persist_column_id_map(spark, full_table_name, schema_name) - update_columns = [c.name for c in columns if c.name != primary_key] master = DeltaTable.forName(spark, full_table_name) query = build_deduped_merge_query( diff --git a/dagster/src/spark/config_expectations.py b/dagster/src/spark/config_expectations.py index 829541924..3c36a7b06 100644 --- a/dagster/src/spark/config_expectations.py +++ b/dagster/src/spark/config_expectations.py @@ -136,6 +136,7 @@ class Config(BaseSettings): ("school_id_giga", "STRING"), ("school_id_govt", "STRING"), ("school_name", "STRING"), + ("official_school_name", "STRING"), ("school_establishment_year", "INT"), ("latitude", "DOUBLE"), ("longitude", "DOUBLE"), @@ -222,6 +223,7 @@ class Config(BaseSettings): "school_id_giga", "school_id_govt", "school_name", + "official_school_name", "longitude", "latitude", "education_level", diff --git a/dagster/src/spark/transform_functions.py b/dagster/src/spark/transform_functions.py index 8d3ebe410..b1ba6836c 100644 --- a/dagster/src/spark/transform_functions.py +++ b/dagster/src/spark/transform_functions.py @@ -55,22 +55,32 @@ def create_school_id_giga(df: sql.DataFrame) -> sql.DataFrame: # "school_id_giga", f.coalesce(f.col("school_id_giga"), f.lit(None)) # ) - school_id_giga_prereqs = [ - "school_id_govt", - "school_name", - "education_level", - "latitude", - "longitude", - ] - for column in school_id_giga_prereqs: - if column not in df.columns: - return df.withColumn("school_id_giga", f.lit(None)) + available_columns = set(df.columns) + if ( + "school_id_govt" not in available_columns + or ( + "school_name" not in available_columns + and "official_school_name" not in available_columns + ) + or "education_level" not in available_columns + or "latitude" not in available_columns + or "longitude" not in available_columns + ): + return df.withColumn("school_id_giga", f.lit(None)) + + # Use official_school_name as fallback for school_name + school_name_col = f.coalesce( + f.col("school_name") if "school_name" in available_columns else f.lit(None), + f.col("official_school_name") + if "official_school_name" in available_columns + else f.lit(None), + ) df = df.withColumn( "identifier_concat", f.concat( f.col("school_id_govt").cast(StringType()), - f.col("school_name").cast(StringType()), + school_name_col.cast(StringType()), f.col("education_level").cast(StringType()), f.col("latitude").cast(StringType()), f.col("longitude").cast(StringType()), diff --git a/dagster/src/utils/schema.py b/dagster/src/utils/schema.py index aaff06033..dd617f794 100644 --- a/dagster/src/utils/schema.py +++ b/dagster/src/utils/schema.py @@ -1,6 +1,7 @@ from delta import DeltaTable from models import Schema from pyspark import sql +from pyspark.errors.exceptions.captured import AnalysisException from pyspark.sql import SparkSession from pyspark.sql.functions import col from pyspark.sql.types import StructField @@ -12,6 +13,20 @@ OutputContext, ) from src.constants import DataTier, constants +from src.spark.config_expectations import config + + +def _get_type_mapping(data_type: str): + """Map a data type string to its corresponding TypeMapping from constants. + + Handles case-insensitivity and common aliases (e.g., 'INT' -> 'integer'). + """ + normalized_type = data_type.lower() + # Handle common aliases from config_expectations + if normalized_type == "int": + normalized_type = "integer" + + return getattr(constants.TYPE_MAPPINGS, normalized_type) def get_schema_name( @@ -22,12 +37,53 @@ def get_schema_name( return context.op_config["metastore_schema"] +def _get_fallback_schema_df(spark: SparkSession, schema_name: str) -> sql.DataFrame: + """Return a fallback schema DataFrame from hardcoded configs if Delta table is missing.""" + if schema_name == "school_geolocation": + columns = config.COLUMNS_EXCEPT_SCHOOL_ID_GEOLOCATION + [ + "school_id_govt", + "school_id_giga", + ] + data_types = dict(config.DATA_TYPES) + + fallback_data = [] + for col_name in columns: + data_type = data_types.get(col_name, "string") + fallback_data.append( + { + "id": col_name, + "name": col_name, + "data_type": data_type, + "is_nullable": True, + "is_important": False, + "is_system_generated": False, + "description": "", + "primary_key": col_name in config.UNIQUE_COLUMNS_GEOLOCATION, + "partition_order": None, + "license": None, + "units": None, + "hint": None, + } + ) + + # Define the schema explicitly to match SchemaModel + schema = Schema.schema + return spark.createDataFrame(fallback_data, schema=schema) + + raise ValueError(f"No fallback schema available for `{schema_name}`") + + def get_schema_table(spark: SparkSession, schema_name: str) -> sql.DataFrame: metaschema_name = Schema.__schema_name__ full_table_name = f"{metaschema_name}.{schema_name}" - # This should be cheap if the migrations.migrate_schema asset is caching the table properly - return DeltaTable.forName(spark, full_table_name).toDF() + try: + # This should be cheap if the migrations.migrate_schema asset is caching the table properly + return DeltaTable.forName(spark, full_table_name).toDF() + except AnalysisException as e: + if "DELTA_TABLE_NOT_FOUND" in str(e): + return _get_fallback_schema_df(spark, schema_name) + raise e def get_schema_columns(spark: SparkSession, schema_name: str) -> list[StructField]: @@ -35,7 +91,7 @@ def get_schema_columns(spark: SparkSession, schema_name: str) -> list[StructFiel return [ StructField( row.name, - getattr(constants.TYPE_MAPPINGS, row.data_type).pyspark(), + _get_type_mapping(row.data_type).pyspark(), row.is_nullable, ) for row in df.collect() @@ -62,7 +118,7 @@ def get_schema_columns_with_id( row.id, StructField( row.name, - getattr(constants.TYPE_MAPPINGS, row.data_type).pyspark(), + _get_type_mapping(row.data_type).pyspark(), row.is_nullable, ), ) @@ -80,8 +136,7 @@ def get_schema_column_descriptions( def get_schema_columns_datahub(spark: SparkSession, schema_name: str) -> list[tuple]: df = get_schema_table(spark, schema_name) return [ - (row.name, getattr(constants.TYPE_MAPPINGS, row.data_type).datahub()) - for row in df.collect() + (row.name, _get_type_mapping(row.data_type).datahub()) for row in df.collect() ] From e39b68b52ce5e076e022a3e4e89f0c2b4f62f9d0 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Wed, 22 Apr 2026 11:46:17 +0530 Subject: [PATCH 19/26] fix: spark error addressed --- dagster/src/resources/io_managers/adls_delta.py | 7 +++++-- dagster/src/utils/delta.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index 0cb5fb829..524915aa2 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -210,7 +210,7 @@ def _upsert_data( primary_key = get_primary_key(spark, schema_name) updated_schema = StructType(columns) - existing_schema = DeltaTable.forName(spark, full_table_name).toDF().schema + existing_schema = spark.table(full_table_name).schema sync_schema( table_name=full_table_name, @@ -222,7 +222,10 @@ def _upsert_data( ) update_columns = [c.name for c in columns if c.name != primary_key] + + spark.catalog.refreshTable(full_table_name) master = DeltaTable.forName(spark, full_table_name) + query = build_deduped_merge_query( master, data, @@ -238,7 +241,7 @@ def _upsert_data( def _overwrite_data( self, - data: sql.dataframe, + data: sql.DataFrame, schema_name: str, full_table_name: str, context: OutputContext = None, diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 50d28ec33..f7fd49df7 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -441,6 +441,9 @@ def apply_renames_and_deletes( spark.sql(stmt) _remove_column_id_props(spark, table_name, deletes) + if renames or deletes: + spark.catalog.refreshTable(table_name) + return bool(renames or deletes) @@ -481,6 +484,7 @@ def apply_datatype_changes( .mode("overwrite") .saveAsTable(table_name) ) + spark.catalog.refreshTable(table_name) def sync_schema( @@ -517,6 +521,7 @@ def sync_schema( # 2. Refresh schemas after rename/delete to get accurate comparison # ------------------------------------------------------------------ if any_renames_deletes: + spark.catalog.refreshTable(table_name) existing_schema = spark.table(table_name).schema # ------------------------------------------------------------------ From abbc539a985257115a93807e7b1e46d6e60af04d Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Thu, 23 Apr 2026 10:46:54 +0530 Subject: [PATCH 20/26] fix: bug fix --- dagster/src/assets/common/assets.py | 6 +- .../src/resources/io_managers/adls_delta.py | 2 + dagster/src/utils/delta.py | 22 ++++---- dagster/src/utils/schema.py | 55 ++----------------- 4 files changed, 22 insertions(+), 63 deletions(-) diff --git a/dagster/src/assets/common/assets.py b/dagster/src/assets/common/assets.py index 214245a96..9d25cf2a5 100644 --- a/dagster/src/assets/common/assets.py +++ b/dagster/src/assets/common/assets.py @@ -600,7 +600,7 @@ def reset_staging_table( raise -def _handle_null_columns(schema_columns, primary_key): +def handle_null_columns(schema_columns, primary_key): """Handle null columns by providing default values based on data type. If the column value is NULL, add a placeholder value if the following @@ -669,7 +669,7 @@ def master( else: new_master = silver - column_actions = _handle_null_columns(schema_columns, primary_key) + column_actions = handle_null_columns(schema_columns, primary_key) new_master = new_master.withColumns(column_actions) new_master = compute_row_hash(new_master) @@ -727,7 +727,7 @@ def reference( else: new_reference = silver - column_actions = _handle_null_columns(schema_columns, primary_key) + column_actions = handle_null_columns(schema_columns, primary_key) new_reference = new_reference.withColumns(column_actions) new_reference = compute_row_hash(new_reference) diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index 524915aa2..5e70bd5f5 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -209,6 +209,8 @@ def _upsert_data( columns = get_schema_columns(spark, schema_name) primary_key = get_primary_key(spark, schema_name) + data = data.localCheckpoint() + updated_schema = StructType(columns) existing_schema = spark.table(full_table_name).schema diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index f7fd49df7..57fa9a392 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -294,7 +294,7 @@ def build_nullability_queries( return alter_stmts -def _enable_column_mapping(spark: SparkSession, table_name: str) -> None: +def enable_column_mapping(spark: SparkSession, table_name: str) -> None: """Enable column mapping mode on an existing Delta table if not already enabled.""" spark.sql( f"ALTER TABLE {table_name} SET TBLPROPERTIES (" @@ -305,7 +305,7 @@ def _enable_column_mapping(spark: SparkSession, table_name: str) -> None: ) -def _get_stored_column_id_map(spark: SparkSession, table_name: str) -> dict[str, str]: +def get_stored_column_id_map(spark: SparkSession, table_name: str) -> dict[str, str]: """Retrieve the column-name → schema-CSV-ID mapping stored in table properties. Returns ``{column_name: csv_id}`` or an empty dict if no mapping has been @@ -322,7 +322,7 @@ def _get_stored_column_id_map(spark: SparkSession, table_name: str) -> dict[str, return result -def _store_column_id_map( +def store_column_id_map( spark: SparkSession, table_name: str, column_id_map: dict[str, str], @@ -337,7 +337,7 @@ def _store_column_id_map( spark.sql(f"ALTER TABLE {table_name} SET TBLPROPERTIES ({props})") -def _remove_column_id_props( +def remove_column_id_props( spark: SparkSession, table_name: str, column_names: list[str], @@ -349,7 +349,7 @@ def _remove_column_id_props( spark.sql(f"ALTER TABLE {table_name} UNSET TBLPROPERTIES IF EXISTS ({props})") -def _detect_renames_and_deletes( +def detect_renames_and_deletes( existing_id_map: dict[str, str], updated_id_map: dict[str, str], ) -> tuple[dict[str, str], list[str]]: @@ -403,13 +403,13 @@ def apply_renames_and_deletes( columns_with_id = get_schema_columns_with_id(spark, schema_name) updated_id_map = {field.name: csv_id for csv_id, field in columns_with_id} - existing_id_map = _get_stored_column_id_map(spark, table_name) + existing_id_map = get_stored_column_id_map(spark, table_name) renames: dict[str, str] = {} deletes: list[str] = [] if existing_id_map: - renames, deletes = _detect_renames_and_deletes(existing_id_map, updated_id_map) + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) context.log.info(f"Detected renames: {renames}") context.log.info(f"Detected deletes: {deletes}") else: @@ -421,7 +421,7 @@ def apply_renames_and_deletes( context.log.info( "Enabling column mapping on table for rename/delete support..." ) - _enable_column_mapping(spark, table_name) + enable_column_mapping(spark, table_name) if renames: context.log.info(f"Renaming columns: {renames}") @@ -431,7 +431,7 @@ def apply_renames_and_deletes( ) context.log.info(f"Executing: {stmt}") spark.sql(stmt) - _remove_column_id_props(spark, table_name, list(renames.keys())) + remove_column_id_props(spark, table_name, list(renames.keys())) if deletes: context.log.info(f"Dropping columns: {deletes}") @@ -439,7 +439,7 @@ def apply_renames_and_deletes( stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" context.log.info(f"Executing: {stmt}") spark.sql(stmt) - _remove_column_id_props(spark, table_name, deletes) + remove_column_id_props(spark, table_name, deletes) if renames or deletes: spark.catalog.refreshTable(table_name) @@ -455,7 +455,7 @@ def persist_column_id_map( columns_with_id = get_schema_columns_with_id(spark, schema_name) new_id_map = {field.name: csv_id for csv_id, field in columns_with_id} - _store_column_id_map(spark, table_name, new_id_map) + store_column_id_map(spark, table_name, new_id_map) def apply_datatype_changes( diff --git a/dagster/src/utils/schema.py b/dagster/src/utils/schema.py index dd617f794..1633a8934 100644 --- a/dagster/src/utils/schema.py +++ b/dagster/src/utils/schema.py @@ -1,7 +1,6 @@ from delta import DeltaTable from models import Schema from pyspark import sql -from pyspark.errors.exceptions.captured import AnalysisException from pyspark.sql import SparkSession from pyspark.sql.functions import col from pyspark.sql.types import StructField @@ -13,10 +12,9 @@ OutputContext, ) from src.constants import DataTier, constants -from src.spark.config_expectations import config -def _get_type_mapping(data_type: str): +def get_type_mapping(data_type: str): """Map a data type string to its corresponding TypeMapping from constants. Handles case-insensitivity and common aliases (e.g., 'INT' -> 'integer'). @@ -37,53 +35,12 @@ def get_schema_name( return context.op_config["metastore_schema"] -def _get_fallback_schema_df(spark: SparkSession, schema_name: str) -> sql.DataFrame: - """Return a fallback schema DataFrame from hardcoded configs if Delta table is missing.""" - if schema_name == "school_geolocation": - columns = config.COLUMNS_EXCEPT_SCHOOL_ID_GEOLOCATION + [ - "school_id_govt", - "school_id_giga", - ] - data_types = dict(config.DATA_TYPES) - - fallback_data = [] - for col_name in columns: - data_type = data_types.get(col_name, "string") - fallback_data.append( - { - "id": col_name, - "name": col_name, - "data_type": data_type, - "is_nullable": True, - "is_important": False, - "is_system_generated": False, - "description": "", - "primary_key": col_name in config.UNIQUE_COLUMNS_GEOLOCATION, - "partition_order": None, - "license": None, - "units": None, - "hint": None, - } - ) - - # Define the schema explicitly to match SchemaModel - schema = Schema.schema - return spark.createDataFrame(fallback_data, schema=schema) - - raise ValueError(f"No fallback schema available for `{schema_name}`") - - def get_schema_table(spark: SparkSession, schema_name: str) -> sql.DataFrame: metaschema_name = Schema.__schema_name__ full_table_name = f"{metaschema_name}.{schema_name}" - try: - # This should be cheap if the migrations.migrate_schema asset is caching the table properly - return DeltaTable.forName(spark, full_table_name).toDF() - except AnalysisException as e: - if "DELTA_TABLE_NOT_FOUND" in str(e): - return _get_fallback_schema_df(spark, schema_name) - raise e + # This should be cheap if the migrations.migrate_schema asset is caching the table properly + return DeltaTable.forName(spark, full_table_name).toDF() def get_schema_columns(spark: SparkSession, schema_name: str) -> list[StructField]: @@ -91,7 +48,7 @@ def get_schema_columns(spark: SparkSession, schema_name: str) -> list[StructFiel return [ StructField( row.name, - _get_type_mapping(row.data_type).pyspark(), + get_type_mapping(row.data_type).pyspark(), row.is_nullable, ) for row in df.collect() @@ -118,7 +75,7 @@ def get_schema_columns_with_id( row.id, StructField( row.name, - _get_type_mapping(row.data_type).pyspark(), + get_type_mapping(row.data_type).pyspark(), row.is_nullable, ), ) @@ -136,7 +93,7 @@ def get_schema_column_descriptions( def get_schema_columns_datahub(spark: SparkSession, schema_name: str) -> list[tuple]: df = get_schema_table(spark, schema_name) return [ - (row.name, _get_type_mapping(row.data_type).datahub()) for row in df.collect() + (row.name, get_type_mapping(row.data_type).datahub()) for row in df.collect() ] From 1cefbe84d2d2cfaf11f3b37ad3d997410d3a37a7 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Thu, 23 Apr 2026 15:56:10 +0530 Subject: [PATCH 21/26] fix: bugfix --- dagster/src/resources/io_managers/adls_delta.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index 5e70bd5f5..cbce5eaba 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -209,7 +209,9 @@ def _upsert_data( columns = get_schema_columns(spark, schema_name) primary_key = get_primary_key(spark, schema_name) - data = data.localCheckpoint() + # Break the lazy lineage to avoid DELTA_SCHEMA_CHANGE_SINCE_ANALYSIS + # if sync_schema renames columns or changes types. + data = spark.createDataFrame(data.rdd, data.schema).cache() updated_schema = StructType(columns) existing_schema = spark.table(full_table_name).schema From 8d30924fb1db7f689d6c1566e6c4ac79f71f72cf Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Fri, 24 Apr 2026 13:25:30 +0530 Subject: [PATCH 22/26] fix: bugfix --- .../src/assets/adhoc/master_csv_to_gold.py | 67 ------ dagster/src/assets/migrations/core.py | 4 + dagster/src/internal/common_assets/staging.py | 9 +- .../src/resources/io_managers/adls_delta.py | 7 +- dagster/src/utils/delta.py | 104 ++++++-- dagster/tests/utils/test_delta_sync_schema.py | 227 +++++++++++++++++- 6 files changed, 309 insertions(+), 109 deletions(-) diff --git a/dagster/src/assets/adhoc/master_csv_to_gold.py b/dagster/src/assets/adhoc/master_csv_to_gold.py index 5f07d03d8..3d0d2f480 100644 --- a/dagster/src/assets/adhoc/master_csv_to_gold.py +++ b/dagster/src/assets/adhoc/master_csv_to_gold.py @@ -42,7 +42,6 @@ check_table_exists, create_delta_table, create_schema, - sync_schema, ) from src.utils.logger import ContextLoggerWithLoguruFallback from src.utils.metadata import get_output_metadata, get_table_preview @@ -608,39 +607,6 @@ def adhoc__publish_master_to_gold( ) gold = compute_row_hash(gold) - table_exists = check_table_exists( - spark=spark.spark_session, - schema_name="school_master", - table_name=config.country_code.lower(), - data_tier=DataTier.GOLD, - ) - - if table_exists: - table_name = f"{config.metastore_schema}.{config.country_code}" - updated_schema = StructType( - get_schema_columns( - spark=spark.spark_session, schema_name=config.metastore_schema - ) - ) - - context.log.info(f"Existing table name: {table_name}") - - spark.spark_session.catalog.refreshTable(table_name) - existing_df = DeltaTable.forName( - sparkSession=spark.spark_session, tableOrViewName=table_name - ).toDF() - - existing_schema = existing_df.schema - - sync_schema( - table_name=table_name, - existing_schema=existing_schema, - updated_schema=updated_schema, - spark=spark.spark_session, - context=context, - schema_name=config.metastore_schema, - ) - schema_reference = get_schema_columns_datahub( spark.spark_session, config.metastore_schema, @@ -691,39 +657,6 @@ def adhoc__publish_reference_to_gold( ) gold = compute_row_hash(gold) - table_exists = check_table_exists( - spark=spark.spark_session, - schema_name="school_reference", - table_name=config.country_code.lower(), - data_tier=DataTier.GOLD, - ) - - if table_exists: - table_name = f"{config.metastore_schema}.{config.country_code}" - updated_schema = StructType( - get_schema_columns( - spark=spark.spark_session, schema_name=config.metastore_schema - ) - ) - - context.log.info(f"Existing table name: {table_name}") - - spark.spark_session.catalog.refreshTable(table_name) - existing_df = DeltaTable.forName( - sparkSession=spark.spark_session, - tableOrViewName=table_name, - ).toDF() - existing_schema = existing_df.schema - - sync_schema( - table_name=table_name, - existing_schema=existing_schema, - updated_schema=updated_schema, - spark=spark.spark_session, - context=context, - schema_name=config.metastore_schema, - ) - schema_reference = get_schema_columns_datahub( spark.spark_session, config.metastore_schema, diff --git a/dagster/src/assets/migrations/core.py b/dagster/src/assets/migrations/core.py index 732426040..745421456 100644 --- a/dagster/src/assets/migrations/core.py +++ b/dagster/src/assets/migrations/core.py @@ -6,6 +6,7 @@ from pyspark.sql.functions import col, when from src.utils.delta import ( create_delta_table, + persist_column_id_map, sync_schema, ) @@ -70,3 +71,6 @@ def save_schema_delta_table(context: OpExecutionContext, df: sql.DataFrame): .whenNotMatchedInsertAll() .execute() ) + + # Persist column-ID mapping after merge succeeds + persist_column_id_map(spark, full_table_name, table_name) diff --git a/dagster/src/internal/common_assets/staging.py b/dagster/src/internal/common_assets/staging.py index e9e7ba5b6..9d32ce9b1 100644 --- a/dagster/src/internal/common_assets/staging.py +++ b/dagster/src/internal/common_assets/staging.py @@ -10,12 +10,10 @@ ) from pyspark.sql.types import ( ArrayType, - StringType, - StructField, - StructType, LongType, StringType, StructField, + StructType, TimestampType, ) from sqlalchemy import select, update @@ -31,6 +29,7 @@ check_table_exists, create_delta_table, create_schema, + persist_column_id_map, sync_schema, ) from src.utils.op_config import FileConfig @@ -359,6 +358,10 @@ def _write_pending_records(self, pending: sql.DataFrame) -> None: .saveAsTable(self.staging_table_name) ) + # Persist column-ID mapping after staging data is written + + persist_column_id_map(self.spark, self.staging_table_name, self.schema_name) + def _update_approval_request_status(self) -> None: """Enable the ApprovalRequest if any actionable (non-UNCHANGED) rows exist.""" actionable = ( diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index cbce5eaba..a5d44b0dc 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -12,7 +12,7 @@ from src.settings import settings from src.spark.transform_functions import add_missing_columns from src.utils.adls import ADLSFileClient -from src.utils.delta import build_deduped_merge_query +from src.utils.delta import build_deduped_merge_query, persist_column_id_map from src.utils.op_config import FileConfig from src.utils.schema import ( construct_full_table_name, @@ -243,6 +243,11 @@ def _upsert_data( if query is not None: query.execute() + # Persist column-ID mapping after merge succeeds + # This ensures mapping is only updated when data is successfully merged + if not is_qos: + persist_column_id_map(spark, full_table_name, schema_name) + def _overwrite_data( self, data: sql.DataFrame, diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 57fa9a392..dce04fc10 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -389,6 +389,69 @@ def detect_renames_and_deletes( return renames, deletes +def initialize_column_id_map( + spark: SparkSession, + table_name: str, + updated_id_map: dict[str, str], + context: OpExecutionContext, +) -> tuple[dict[str, str], dict[str, str], list[str]]: + """Initialize column ID mapping from table schema when no stored mapping exists. + + Returns tuple of (initialized_id_map, renames, deletes). + """ + context.log.info( + "No stored column-ID mapping found; initialising mapping from current table schema." + ) + existing_id_map: dict[str, str] = {} + # Initialize mapping from current table columns so deletions can be detected + # This handles the case where columns were dropped from CSV but table still has them + current_columns = spark.table(table_name).schema.fieldNames() + for col_name in current_columns: + if col_name not in updated_id_map: + # Column exists in table but not in reference schema - will be detected as delete + # Use a deterministic ID based on column name for tracking + existing_id_map[col_name] = f"table_{col_name}" + + renames: dict[str, str] = {} + deletes: list[str] = [] + if existing_id_map: + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + context.log.info(f"Detected renames after init: {renames}") + context.log.info(f"Detected deletes after init: {deletes}") + + return existing_id_map, renames, deletes + + +def execute_renames( + spark: SparkSession, + table_name: str, + renames: dict[str, str], + context: OpExecutionContext, +) -> None: + """Execute column rename SQL statements.""" + context.log.info(f"Renaming columns: {renames}") + for old_name, new_name in renames.items(): + stmt = f"ALTER TABLE {table_name} RENAME COLUMN `{old_name}` TO `{new_name}`" + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + remove_column_id_props(spark, table_name, list(renames.keys())) + + +def execute_deletes( + spark: SparkSession, + table_name: str, + deletes: list[str], + context: OpExecutionContext, +) -> None: + """Execute column drop SQL statements.""" + context.log.info(f"Dropping columns: {deletes}") + for col_name in deletes: + stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + remove_column_id_props(spark, table_name, deletes) + + def apply_renames_and_deletes( spark: SparkSession, table_name: str, @@ -413,8 +476,8 @@ def apply_renames_and_deletes( context.log.info(f"Detected renames: {renames}") context.log.info(f"Detected deletes: {deletes}") else: - context.log.info( - "No stored column-ID mapping found; initialising mapping from current reference schema." + existing_id_map, renames, deletes = initialize_column_id_map( + spark, table_name, updated_id_map, context ) if renames or deletes: @@ -424,22 +487,10 @@ def apply_renames_and_deletes( enable_column_mapping(spark, table_name) if renames: - context.log.info(f"Renaming columns: {renames}") - for old_name, new_name in renames.items(): - stmt = ( - f"ALTER TABLE {table_name} RENAME COLUMN `{old_name}` TO `{new_name}`" - ) - context.log.info(f"Executing: {stmt}") - spark.sql(stmt) - remove_column_id_props(spark, table_name, list(renames.keys())) + execute_renames(spark, table_name, renames, context) if deletes: - context.log.info(f"Dropping columns: {deletes}") - for col_name in deletes: - stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" - context.log.info(f"Executing: {stmt}") - spark.sql(stmt) - remove_column_id_props(spark, table_name, deletes) + execute_deletes(spark, table_name, deletes, context) if renames or deletes: spark.catalog.refreshTable(table_name) @@ -540,16 +591,16 @@ def sync_schema( updated_columns_set = {field.name for field in updated_schema} added_columns = updated_columns_set - existing_columns + # Recalculate removed_columns after renames/deletes were applied + # This ensures we correctly identify columns that should be dropped removed_columns = existing_columns - updated_columns_set - has_schema_changed = len(added_columns) + len(removed_columns) > 0 - changed_datatypes = get_changed_datatypes( context=context, existing_schema=existing_schema, updated_schema=updated_schema ) apply_datatype_changes(spark, table_name, changed_datatypes, context) - if has_schema_changed: + if added_columns: context.log.info(f"Adding schema columns {added_columns}") empty_dataframe_with_updated_schema = spark.createDataFrame( @@ -561,6 +612,15 @@ def sync_schema( .mode("append") .saveAsTable(table_name) ) + + # Drop columns that are no longer in the updated schema + if removed_columns: + context.log.info(f"Dropping columns not in updated schema: {removed_columns}") + for col_name in removed_columns: + stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + context.log.info(f"has_nullability_changed {has_nullability_changed}") if has_nullability_changed: @@ -577,9 +637,3 @@ def sync_schema( continue else: raise - - # ------------------------------------------------------------------ - # 4. Persist column-ID mapping for future rename/delete detection - # ------------------------------------------------------------------ - if schema_name is not None: - persist_column_id_map(spark, table_name, schema_name) diff --git a/dagster/tests/utils/test_delta_sync_schema.py b/dagster/tests/utils/test_delta_sync_schema.py index 404e01361..f677fecff 100644 --- a/dagster/tests/utils/test_delta_sync_schema.py +++ b/dagster/tests/utils/test_delta_sync_schema.py @@ -1,29 +1,29 @@ """Tests for the Delta Lake column rename/delete detection helpers in src.utils.delta.""" -from src.utils.delta import _detect_renames_and_deletes +from src.utils.delta import detect_renames_and_deletes class TestDetectRenamesAndDeletes: - """Unit tests for _detect_renames_and_deletes.""" + """Unit tests for detect_renames_and_deletes.""" def test_no_changes(self): existing = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} updated = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {} assert deletes == [] def test_column_renamed(self): existing = {"old_name": "id-1", "col_b": "id-2"} updated = {"new_name": "id-1", "col_b": "id-2"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {"old_name": "new_name"} assert deletes == [] def test_column_deleted(self): existing = {"col_a": "id-1", "col_b": "id-2", "col_c": "id-3"} updated = {"col_a": "id-1", "col_b": "id-2"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {} assert deletes == ["col_c"] @@ -31,7 +31,7 @@ def test_column_added_only(self): """Adding a column (ID in updated but not existing) should not trigger renames or deletes.""" existing = {"col_a": "id-1"} updated = {"col_a": "id-1", "col_new": "id-new"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {} assert deletes == [] @@ -45,7 +45,7 @@ def test_rename_and_delete_combined(self): "new_name": "id-1", "col_b": "id-2", } - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {"old_name": "new_name"} assert deletes == ["col_to_drop"] @@ -60,38 +60,239 @@ def test_rename_delete_and_add(self): "col_b": "id-2", "col_new": "id-4", } - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {"old_name": "new_name"} assert deletes == ["col_drop"] def test_multiple_renames(self): existing = {"a": "id-1", "b": "id-2", "c": "id-3"} updated = {"x": "id-1", "y": "id-2", "c": "id-3"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {"a": "x", "b": "y"} assert deletes == [] def test_multiple_deletes(self): existing = {"a": "id-1", "b": "id-2", "c": "id-3"} updated = {"a": "id-1"} - renames, deletes = _detect_renames_and_deletes(existing, updated) + renames, deletes = detect_renames_and_deletes(existing, updated) assert renames == {} assert sorted(deletes) == ["b", "c"] def test_empty_existing(self): """If existing is empty, there should be no changes.""" - renames, deletes = _detect_renames_and_deletes({}, {"a": "id-1"}) + renames, deletes = detect_renames_and_deletes({}, {"a": "id-1"}) assert renames == {} assert deletes == [] def test_empty_updated_deletes_all(self): """If updated is empty, all existing columns should be deleted.""" existing = {"a": "id-1", "b": "id-2"} - renames, deletes = _detect_renames_and_deletes(existing, {}) + renames, deletes = detect_renames_and_deletes(existing, {}) assert renames == {} assert sorted(deletes) == ["a", "b"] def test_both_empty(self): - renames, deletes = _detect_renames_and_deletes({}, {}) + renames, deletes = detect_renames_and_deletes({}, {}) assert renames == {} assert deletes == [] + + +class TestSyncSchemaRemovedColumns: + """Unit tests for sync_schema removed_columns calculation logic.""" + + def test_removed_columns_detection(self): + """Test that removed_columns is correctly calculated after renames.""" + existing_columns = {"col_a", "col_b", "col_c"} + updated_columns_set = {"col_a", "col_x"} + removed_columns = existing_columns - updated_columns_set + assert removed_columns == {"col_b", "col_c"} + + def test_removed_columns_after_rename_applied(self): + """Test that after rename is applied, only orphaned columns are removed.""" + existing_after_rename = {"col_a", "col_x", "col_c"} + updated_columns_set = {"col_a", "col_x"} + removed_columns = existing_after_rename - updated_columns_set + assert removed_columns == {"col_c"} + + def test_removed_columns_empty_when_no_deletions(self): + """Test that removed_columns is empty when no columns are deleted.""" + existing_columns = {"col_a", "col_b", "col_c"} + updated_columns_set = {"col_a", "col_b", "col_c"} + removed_columns = existing_columns - updated_columns_set + assert removed_columns == set() + + +class TestApplyRenamesAndDeletesInitialization: + """Unit tests for apply_renames_and_deletes mapping initialization.""" + + def test_empty_mapping_initialization_logic(self): + """Test the logic for initializing column ID mapping when empty.""" + current_table_columns = ["col_a", "col_b", "col_to_delete"] + updated_id_map = {"col_a": "id-1", "col_b": "id-2"} + existing_id_map = {} + for col_name in current_table_columns: + if col_name not in updated_id_map: + existing_id_map[col_name] = f"table_{col_name}" + assert "col_to_delete" in existing_id_map + assert existing_id_map["col_to_delete"] == "table_col_to_delete" + assert "col_a" not in existing_id_map + assert "col_b" not in existing_id_map + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + assert deletes == ["col_to_delete"] + assert renames == {} + + +class TestMultipleOperations: + """Test handling of multiple simultaneous add, rename, and delete operations.""" + + def test_multiple_renames_simultaneous(self): + """Test that multiple columns can be renamed at once.""" + existing = { + "old_col_a": "id-1", + "old_col_b": "id-2", + "old_col_c": "id-3", + "unchanged": "id-4", + } + updated = { + "new_col_a": "id-1", + "new_col_b": "id-2", + "new_col_c": "id-3", + "unchanged": "id-4", + } + renames, deletes = detect_renames_and_deletes(existing, updated) + assert renames == { + "old_col_a": "new_col_a", + "old_col_b": "new_col_b", + "old_col_c": "new_col_c", + } + assert deletes == [] + + def test_multiple_deletes_simultaneous(self): + """Test that multiple columns can be deleted at once.""" + existing = { + "col_a": "id-1", + "col_to_drop_1": "id-2", + "col_b": "id-3", + "col_to_drop_2": "id-4", + "col_to_drop_3": "id-5", + } + updated = { + "col_a": "id-1", + "col_b": "id-3", + } + renames, deletes = detect_renames_and_deletes(existing, updated) + assert renames == {} + assert sorted(deletes) == ["col_to_drop_1", "col_to_drop_2", "col_to_drop_3"] + + def test_multiple_renames_and_deletes_simultaneous(self): + """Test multiple renames and deletes happening together.""" + existing = { + "old_a": "id-1", + "to_drop_1": "id-2", + "old_b": "id-3", + "to_drop_2": "id-4", + "unchanged": "id-5", + } + updated = { + "new_a": "id-1", + "new_b": "id-3", + "unchanged": "id-5", + } + renames, deletes = detect_renames_and_deletes(existing, updated) + assert renames == {"old_a": "new_a", "old_b": "new_b"} + assert sorted(deletes) == ["to_drop_1", "to_drop_2"] + + def test_full_schema_evolution_add_rename_delete(self): + """Test complete schema evolution: adds, renames, and deletes simultaneously.""" + existing = { + "school_id": "id-1", + "old_funding_type": "id-2", + "num_teachers_female": "id-3", + "num_teachers_male": "id-4", + "old_tablet_count": "id-5", + } + updated = { + "school_id": "id-1", + "school_funding_source": "id-2", + "num_tablets_used": "id-5", + } + renames, deletes = detect_renames_and_deletes(existing, updated) + assert renames == { + "old_funding_type": "school_funding_source", + "old_tablet_count": "num_tablets_used", + } + assert sorted(deletes) == ["num_teachers_female", "num_teachers_male"] + + def test_complex_multi_country_scenario(self): + """Test scenario matching the user's Gambia case with multiple changes.""" + existing = { + "school_id_giga": "csv-id-001", + "school_name": "csv-id-002", + "school_funding_type": "csv-id-010", + "num_tablets": "csv-id-015", + "num_teachers_female": "csv-id-020", + "num_teachers_male": "csv-id-021", + "latitude": "csv-id-030", + "longitude": "csv-id-031", + } + updated = { + "school_id_giga": "csv-id-001", + "school_name": "csv-id-002", + "school_funding_source": "csv-id-010", + "num_tablets_used": "csv-id-015", + "latitude": "csv-id-030", + "longitude": "csv-id-031", + } + renames, deletes = detect_renames_and_deletes(existing, updated) + assert renames == { + "school_funding_type": "school_funding_source", + "num_tablets": "num_tablets_used", + } + assert sorted(deletes) == ["num_teachers_female", "num_teachers_male"] + + def test_multiple_adds_simultaneous(self): + """Test that multiple columns can be added at once.""" + existing_columns = {"col_a", "col_b", "col_c"} + updated_columns_set = { + "col_a", + "col_b", + "col_c", + "new_col_x", + "new_col_y", + "new_col_z", + } + added_columns = updated_columns_set - existing_columns + assert added_columns == {"new_col_x", "new_col_y", "new_col_z"} + + def test_complete_workflow_add_rename_delete_together(self): + """Test the complete workflow: adds, renames, and deletes all together.""" + existing_id_map = { + "school_id": "uuid-001", + "old_name_a": "uuid-002", + "old_name_b": "uuid-003", + "to_delete_1": "uuid-100", + "to_delete_2": "uuid-101", + } + updated_id_map = { + "school_id": "uuid-001", + "new_name_a": "uuid-002", + "new_name_b": "uuid-003", + "new_col_x": "uuid-200", + "new_col_y": "uuid-201", + "new_col_z": "uuid-202", + } + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + assert renames == {"old_name_a": "new_name_a", "old_name_b": "new_name_b"} + assert sorted(deletes) == ["to_delete_1", "to_delete_2"] + existing_after_renames = { + "school_id", + "new_name_a", + "new_name_b", + "to_delete_1", + "to_delete_2", + } + updated_columns_set = set(updated_id_map.keys()) + added_columns = updated_columns_set - existing_after_renames + removed_columns = existing_after_renames - updated_columns_set + assert added_columns == {"new_col_x", "new_col_y", "new_col_z"} + assert removed_columns == {"to_delete_1", "to_delete_2"} From df0d08bd181b27929da5cb0502fa8bf565e77df6 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Mon, 27 Apr 2026 19:55:10 +0530 Subject: [PATCH 23/26] fix: rename and delete fixed --- dagster/src/assets/migrations/core.py | 1 + .../src/resources/io_managers/adls_delta.py | 22 +- dagster/src/utils/delta.py | 113 ++++++-- dagster/tests/utils/test_delta_sync_schema.py | 259 +++++++++++++++++- 4 files changed, 351 insertions(+), 44 deletions(-) diff --git a/dagster/src/assets/migrations/core.py b/dagster/src/assets/migrations/core.py index 745421456..323cd4dac 100644 --- a/dagster/src/assets/migrations/core.py +++ b/dagster/src/assets/migrations/core.py @@ -69,6 +69,7 @@ def save_schema_delta_table(context: OpExecutionContext, df: sql.DataFrame): .merge(df.alias("updates"), "master.id = updates.id") .whenMatchedUpdateAll() .whenNotMatchedInsertAll() + .whenNotMatchedBySourceDelete() .execute() ) diff --git a/dagster/src/resources/io_managers/adls_delta.py b/dagster/src/resources/io_managers/adls_delta.py index a5d44b0dc..8110b39b5 100644 --- a/dagster/src/resources/io_managers/adls_delta.py +++ b/dagster/src/resources/io_managers/adls_delta.py @@ -12,7 +12,11 @@ from src.settings import settings from src.spark.transform_functions import add_missing_columns from src.utils.adls import ADLSFileClient -from src.utils.delta import build_deduped_merge_query, persist_column_id_map +from src.utils.delta import ( + build_deduped_merge_query, + execute_query_with_error_handler, + persist_column_id_map, +) from src.utils.op_config import FileConfig from src.utils.schema import ( construct_full_table_name, @@ -157,23 +161,10 @@ def _create_table_if_not_exists( "delta.logRetentionDuration", constants.school_master_retention_period ) - from src.utils.delta import ( - execute_query_with_error_handler, - persist_column_id_map, - ) - execute_query_with_error_handler( spark, query, schema_name_for_tier, table_name, context ) - if columns_schema_name not in [ - "qos", - "qos_raw", - "qos_availability", - "custom_dataset", - ]: - persist_column_id_map(spark, full_table_name, columns_schema_name) - def _upsert_data( self, data: sql.DataFrame, @@ -214,7 +205,8 @@ def _upsert_data( data = spark.createDataFrame(data.rdd, data.schema).cache() updated_schema = StructType(columns) - existing_schema = spark.table(full_table_name).schema + spark.catalog.refreshTable(full_table_name) + existing_schema = DeltaTable.forName(spark, full_table_name).toDF().schema sync_schema( table_name=full_table_name, diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index dce04fc10..fcc0b8fad 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -311,6 +311,7 @@ def get_stored_column_id_map(spark: SparkSession, table_name: str) -> dict[str, Returns ``{column_name: csv_id}`` or an empty dict if no mapping has been stored yet (e.g. tables created before this feature was added). """ + spark.catalog.refreshTable(table_name) detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0] properties: dict = detail["properties"] if detail["properties"] else {} result = {} @@ -327,9 +328,24 @@ def store_column_id_map( table_name: str, column_id_map: dict[str, str], ) -> None: - """Persist the column-name → schema-CSV-ID mapping as Delta table properties.""" + """Persist the column-name → schema-CSV-ID mapping as Delta table properties. + + Also removes any stale ``giga.columnId.*`` properties for columns that are + not present in the new mapping. This prevents accumulation of old column + name props across renames, which would otherwise cause future rename + detection to misbehave (e.g. multiple props pointing to the same UUID). + """ if not column_id_map: return + + # Remove stale props for columns that no longer exist in the new mapping + current_props = get_stored_column_id_map(spark, table_name) + stale_columns = [ + col_name for col_name in current_props if col_name not in column_id_map + ] + if stale_columns: + remove_column_id_props(spark, table_name, stale_columns) + props = ", ".join( f"'giga.columnId.{col_name}' = '{csv_id}'" for col_name, csv_id in column_id_map.items() @@ -405,7 +421,7 @@ def initialize_column_id_map( existing_id_map: dict[str, str] = {} # Initialize mapping from current table columns so deletions can be detected # This handles the case where columns were dropped from CSV but table still has them - current_columns = spark.table(table_name).schema.fieldNames() + current_columns = DeltaTable.forName(spark, table_name).toDF().schema.fieldNames() for col_name in current_columns: if col_name not in updated_id_map: # Column exists in table but not in reference schema - will be detected as delete @@ -461,6 +477,13 @@ def apply_renames_and_deletes( """Detect and apply column renames and deletes to a Delta table based on the reference schema. Returns True if any schema change occurred (rename or delete). + + IMPORTANT: This function ALWAYS supplements ``existing_id_map`` with any + table columns that lack stored UUIDs. This prevents the situation where a + partially-populated ``column_id_map`` causes some columns to be silently + ignored by rename/delete detection (which previously led to those columns + being dropped by the fallback path in :func:`sync_schema`, causing data + loss). """ from src.utils.schema import get_schema_columns_with_id @@ -468,17 +491,37 @@ def apply_renames_and_deletes( updated_id_map = {field.name: csv_id for csv_id, field in columns_with_id} existing_id_map = get_stored_column_id_map(spark, table_name) - renames: dict[str, str] = {} - deletes: list[str] = [] - - if existing_id_map: - renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) - context.log.info(f"Detected renames: {renames}") - context.log.info(f"Detected deletes: {deletes}") - else: - existing_id_map, renames, deletes = initialize_column_id_map( - spark, table_name, updated_id_map, context + # If the table has NO stored UUIDs at all (pre-existing table that was + # created before UUID-based tracking was introduced), bootstrapping is + # needed first. We must NOT treat all columns as deletes — that would + # cause data loss by dropping every column. Instead, persist the current + # mapping now and skip rename/delete detection for this run. + if not existing_id_map: + context.log.info( + f"No stored column-ID mapping found for {table_name}. " + "Bootstrapping UUID props from current schema. " + "Rename/delete detection will be active from the next run onwards." ) + persist_column_id_map(spark, table_name, schema_name) + return False + + # Supplement existing_id_map with table columns that lack stored UUIDs. + # This handles columns added via mergeSchema after the initial bootstrap + # (e.g. ADD operations that ran before persist_column_id_map was called). + # Tag them with a synthetic ID so they are treated as deletes if they are + # no longer in the reference schema. + current_columns = DeltaTable.forName(spark, table_name).toDF().schema.fieldNames() + for col_name in current_columns: + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + context.log.info( + f"Column '{col_name}' exists in table but has no stored UUID; " + f"tagging with synthetic ID for delete detection." + ) + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + context.log.info(f"Detected renames: {renames}") + context.log.info(f"Detected deletes: {deletes}") if renames or deletes: context.log.info( @@ -538,6 +581,44 @@ def apply_datatype_changes( spark.catalog.refreshTable(table_name) +def handle_removed_columns( + spark: SparkSession, + table_name: str, + removed_columns: set[str], + schema_name: str | None, + context: OpExecutionContext, +) -> None: + """Safely handle columns that exist in the table but not in the updated schema. + + When ``schema_name`` is provided, :func:`apply_renames_and_deletes` is the + authoritative path for both renames and deletes (UUID-based detection). + If columns end up here despite that, it likely means rename detection + failed for them. Log a clear warning instead of silently dropping to + prevent data loss. + + When ``schema_name`` is None (legacy callers, schema-tables migration), + fall back to dropping by name as before. + """ + if not removed_columns: + return + + if schema_name is not None: + context.log.warning( + f"Columns exist in table but not in updated schema: " + f"{removed_columns}. These were NOT handled by " + f"apply_renames_and_deletes - leaving them in place to avoid " + f"unintended data loss. If you intend to drop them, remove the " + f"column from the schema CSV (with its UUID) and re-run." + ) + return + + context.log.info(f"Dropping columns not in updated schema: {removed_columns}") + for col_name in removed_columns: + stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + + def sync_schema( table_name: str, existing_schema: StructType, @@ -613,13 +694,7 @@ def sync_schema( .saveAsTable(table_name) ) - # Drop columns that are no longer in the updated schema - if removed_columns: - context.log.info(f"Dropping columns not in updated schema: {removed_columns}") - for col_name in removed_columns: - stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" - context.log.info(f"Executing: {stmt}") - spark.sql(stmt) + handle_removed_columns(spark, table_name, removed_columns, schema_name, context) context.log.info(f"has_nullability_changed {has_nullability_changed}") diff --git a/dagster/tests/utils/test_delta_sync_schema.py b/dagster/tests/utils/test_delta_sync_schema.py index f677fecff..ec00919e1 100644 --- a/dagster/tests/utils/test_delta_sync_schema.py +++ b/dagster/tests/utils/test_delta_sync_schema.py @@ -125,20 +125,35 @@ def test_removed_columns_empty_when_no_deletions(self): class TestApplyRenamesAndDeletesInitialization: """Unit tests for apply_renames_and_deletes mapping initialization.""" - def test_empty_mapping_initialization_logic(self): - """Test the logic for initializing column ID mapping when empty.""" - current_table_columns = ["col_a", "col_b", "col_to_delete"] - updated_id_map = {"col_a": "id-1", "col_b": "id-2"} + def test_empty_mapping_skips_detection(self): + """When existing_id_map is completely empty (pre-existing table with no stored + UUIDs), rename/delete detection must be skipped entirely. + + Previously the code tagged all columns with synthetic IDs and treated them as + deletes, which caused data loss on pre-existing tables. The fix: when + existing_id_map is empty, return early after bootstrapping — do NOT drop anything. + """ + # Simulate the guard added to apply_renames_and_deletes: + # if not existing_id_map: bootstrap and return False (no changes) existing_id_map = {} + assert not existing_id_map # guard condition: skip when completely empty + + def test_partial_mapping_still_detects_deletes(self): + """When existing_id_map is PARTIALLY populated (some columns have stored UUIDs), + columns that lack a UUID are given synthetic IDs and treated as deletes if they + are absent from the updated schema. + """ + existing_id_map = {"col_a": "id-1", "col_b": "id-2"} + updated_id_map = {"col_a": "id-1", "col_b": "id-2"} + current_table_columns = ["col_a", "col_b", "orphan_col"] + + # Supplement with synthetic ID for orphan column for col_name in current_table_columns: - if col_name not in updated_id_map: + if col_name not in existing_id_map: existing_id_map[col_name] = f"table_{col_name}" - assert "col_to_delete" in existing_id_map - assert existing_id_map["col_to_delete"] == "table_col_to_delete" - assert "col_a" not in existing_id_map - assert "col_b" not in existing_id_map + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) - assert deletes == ["col_to_delete"] + assert deletes == ["orphan_col"] assert renames == {} @@ -296,3 +311,227 @@ def test_complete_workflow_add_rename_delete_together(self): removed_columns = existing_after_renames - updated_columns_set assert added_columns == {"new_col_x", "new_col_y", "new_col_z"} assert removed_columns == {"to_delete_1", "to_delete_2"} + + +class TestPartialColumnIdMap: + """Regression tests for the partial column_id_map bug. + + The bug: when ``existing_id_map`` was partially populated (some columns + had stored UUIDs, others did not), columns without UUIDs were silently + ignored by detect_renames_and_deletes. The secondary path in sync_schema + then dropped them by name, causing data loss when the user intended a + rename. + + The fix: ALWAYS supplement existing_id_map with synthetic ``table_*`` IDs + for any table column that lacks a stored UUID. This ensures the column + is at least handled by the explicit delete path (with a clear log + message) instead of being silently dropped. + """ + + def test_simulates_managers_bug_report(self): + """Reproduces the exact bug from the manager's logs. + + Setup: + - silver table has columns: school_funding_source, num_tablets_used + - column_id_map has only num_tablets_used UUID stored + (school_funding_source UUID is missing for some reason) + - User updates schema CSV: school_funding_source -> school_funding_type, + num_tablets_used -> num_tablets + + Before the fix: only num_tablets_used was detected as rename; + school_funding_source was silently dropped. + + After the fix: num_tablets_used is renamed (UUID match), + school_funding_source is detected as DELETE (synthetic ID, no match) + and dropped via the explicit path (with clear log). No silent data + loss. + """ + # Simulate stored column_id_map (partial - missing school_funding_source) + stored_id_map = { + "num_tablets_used": "uuid-tablets", + } + # Schema CSV has new names + updated_id_map = { + "school_funding_type": "uuid-funding", + "num_tablets": "uuid-tablets", + } + + # Simulate the new logic: supplement existing_id_map with table columns + current_table_columns = ["school_funding_source", "num_tablets_used"] + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + # Now detect renames and deletes + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + # num_tablets_used -> num_tablets is detected as rename via UUID + assert renames == {"num_tablets_used": "num_tablets"} + # school_funding_source is detected as DELETE (synthetic ID, no UUID match) + # This is now CONSISTENT - no silent drop in sync_schema secondary path + assert deletes == ["school_funding_source"] + + def test_partial_map_with_all_renames_intent(self): + """If user wants to rename a column without stored UUID, it becomes a delete. + + This is expected behavior - we cannot detect renames without UUID + matching. The user must ensure UUIDs are preserved across renames. + Better than silent data loss. + """ + stored_id_map = {"col_a": "uuid-a"} + updated_id_map = {"col_a_renamed": "uuid-a", "col_b_new": "uuid-b"} + current_table_columns = ["col_a", "col_b_old"] + + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + # col_a -> col_a_renamed: detected via UUID match + assert renames == {"col_a": "col_a_renamed"} + # col_b_old has no UUID -> detected as delete (synthetic ID never matches) + assert deletes == ["col_b_old"] + + def test_full_map_rename_works_correctly(self): + """When all columns have UUIDs stored, rename detection is perfect. + + This is the happy path - column_id_map is complete and accurate. + """ + stored_id_map = { + "school_funding_source": "uuid-funding", + "num_tablets_used": "uuid-tablets", + } + updated_id_map = { + "school_funding_type": "uuid-funding", + "num_tablets": "uuid-tablets", + } + current_table_columns = ["school_funding_source", "num_tablets_used"] + + # With full map, supplementing adds nothing extra + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + # Both renames detected correctly + assert renames == { + "school_funding_source": "school_funding_type", + "num_tablets_used": "num_tablets", + } + assert deletes == [] + + def test_orphan_column_not_silently_dropped(self): + """Orphan columns (in table but not in CSV) are detected as explicit deletes. + + Previously they would be silently dropped by sync_schema secondary path. + Now they show up in the deletes list with clear logging. + """ + stored_id_map = {"col_a": "uuid-a"} + updated_id_map = {"col_a": "uuid-a"} + # orphan_col is in table but neither in stored map nor in CSV + current_table_columns = ["col_a", "orphan_col"] + + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + assert renames == {} + assert deletes == ["orphan_col"] + + +class TestStoreColumnIdMapStaleCleanup: + """Tests for the store_column_id_map stale entry cleanup logic. + + Bug: store_column_id_map only ADDed/UPDATEd props but never REMOVEd + stale entries. After multiple renames, old column name props would + accumulate in table properties, eventually causing rename detection + to misbehave (e.g. multiple props pointing to the same UUID). + + Fix: store_column_id_map now removes any giga.columnId.* props for + columns not in the new mapping. + """ + + def test_stale_columns_identified(self): + """Test the logic for identifying stale columns to remove.""" + current_props = { + "old_name_a": "uuid-a", + "old_name_b": "uuid-b", + "unchanged": "uuid-c", + } + new_map = { + "new_name_a": "uuid-a", # renamed from old_name_a + "new_name_b": "uuid-b", # renamed from old_name_b + "unchanged": "uuid-c", + } + stale = [name for name in current_props if name not in new_map] + assert sorted(stale) == ["old_name_a", "old_name_b"] + + def test_no_stale_columns(self): + """If new map matches current props, no cleanup needed.""" + current_props = {"col_a": "uuid-a", "col_b": "uuid-b"} + new_map = {"col_a": "uuid-a", "col_b": "uuid-b"} + stale = [name for name in current_props if name not in new_map] + assert stale == [] + + def test_removed_columns_in_stale(self): + """Deleted columns appear in stale list.""" + current_props = {"col_a": "uuid-a", "col_to_delete": "uuid-b"} + new_map = {"col_a": "uuid-a"} + stale = [name for name in current_props if name not in new_map] + assert stale == ["col_to_delete"] + + +class TestSyncSchemaSafeDropPath: + """Tests for the sync_schema secondary drop path safety fix. + + Bug: When schema_name was provided and apply_renames_and_deletes + missed a column (e.g. due to incomplete column_id_map), the secondary + path in sync_schema would silently drop the column by name comparison, + causing data loss. + + Fix: When schema_name is provided, the secondary path now logs a + WARNING and leaves the column in place. apply_renames_and_deletes + is the authoritative path for both renames and deletes. + + Note: These tests verify the LOGIC of the new safe drop path; they + don't run sync_schema directly (which requires Spark). + """ + + def test_safe_drop_path_with_schema_name(self): + """When schema_name is provided, leftover columns should NOT be dropped.""" + schema_name = "school_geolocation" + removed_columns = {"school_funding_source"} + + # Simulate the new safe-drop logic + should_drop = [] + if removed_columns: + if schema_name is not None: + # New behavior: warn but don't drop + pass + else: + # Legacy behavior: drop by name + should_drop = list(removed_columns) + + assert should_drop == [] + + def test_safe_drop_path_without_schema_name(self): + """Legacy callers (no schema_name) keep old drop-by-name behavior.""" + schema_name = None + removed_columns = {"col_to_drop"} + + should_drop = [] + if removed_columns: + if schema_name is not None: + pass + else: + should_drop = list(removed_columns) + + assert should_drop == ["col_to_drop"] From a52d0ab844985a708413531fbce0c32dc746e0dc Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Tue, 12 May 2026 16:43:39 +0530 Subject: [PATCH 24/26] fix: partition key bug fix --- .../src/assets/school_geolocation/assets.py | 15 +- dagster/src/utils/delta.py | 374 ++++++++-- dagster/tests/utils/test_delta_sync_schema.py | 646 +++++++++++++++++- 3 files changed, 983 insertions(+), 52 deletions(-) diff --git a/dagster/src/assets/school_geolocation/assets.py b/dagster/src/assets/school_geolocation/assets.py index 655938b12..a7b2cfb08 100644 --- a/dagster/src/assets/school_geolocation/assets.py +++ b/dagster/src/assets/school_geolocation/assets.py @@ -128,10 +128,19 @@ def geolocation_metadata( context.log.info("Create spark dataframe") metadata_df = s.createDataFrame(metadata_df) - table_columns = get_schema_columns(s, "school_geolocation_metadata") table_name = "school_geolocation_metadata" table_schema_name = "pipeline_tables" + context.log.info("Get schema columns for metadata table") + try: + table_columns = get_schema_columns(s, "school_geolocation_metadata") + except Exception: + context.log.warning( + "Schema table schemas.school_geolocation_metadata not found; " + "using DataFrame schema for metadata table creation" + ) + table_columns = list(metadata_df.schema.fields) + context.log.info("Create the schema and table if they do not exist") metadata_df = add_missing_columns(metadata_df, table_columns) metadata_df = metadata_df.select(*StructType(table_columns).fieldNames()) @@ -327,7 +336,9 @@ def geolocation_data_quality_results( ) dq_results_schema_name = f"{schema_name}_dq_results" - table_name = f"{id}_{country_code}_{current_timestamp}" + # Replace hyphens with underscores so the identifier is valid in Spark SQL + safe_id = id.replace("-", "_") + table_name = f"{safe_id}_{country_code}_{current_timestamp}" schema_columns = [ StructField(field.name, field.dataType, nullable=True) diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index fcc0b8fad..419323262 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -1,3 +1,5 @@ +import uuid + from delta.tables import DeltaMergeBuilder, DeltaTable, DeltaTableBuilder from icecream import ic from pyspark import sql @@ -365,6 +367,39 @@ def remove_column_id_props( spark.sql(f"ALTER TABLE {table_name} UNSET TBLPROPERTIES IF EXISTS ({props})") +_PK_PROPERTY_KEY = "giga.pkColumnIds" + + +def get_stored_pk_uuids(spark: SparkSession, table_name: str) -> set[str]: + """Return the set of UUIDs marked as primary keys on this table. + + Returns an empty set if the property is absent (e.g. tables created + before PK persistence was added). + """ + spark.catalog.refreshTable(table_name) + detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0] + properties: dict = detail["properties"] if detail["properties"] else {} + raw = properties.get(_PK_PROPERTY_KEY) + if not raw: + return set() + return {part.strip() for part in raw.split(",") if part.strip()} + + +def store_pk_uuids( + spark: SparkSession, + table_name: str, + pk_uuids: set[str], +) -> None: + """Persist (or update) the set of primary-key UUIDs for this table.""" + if not pk_uuids: + return + joined = ",".join(sorted(pk_uuids)) + spark.sql( + f"ALTER TABLE {table_name} SET TBLPROPERTIES " + f"('{_PK_PROPERTY_KEY}' = '{joined}')" + ) + + def detect_renames_and_deletes( existing_id_map: dict[str, str], updated_id_map: dict[str, str], @@ -460,66 +495,173 @@ def execute_deletes( context: OpExecutionContext, ) -> None: """Execute column drop SQL statements.""" - context.log.info(f"Dropping columns: {deletes}") - for col_name in deletes: + detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0] + partition_columns: set[str] = set(detail["partitionColumns"] or []) + + skipped = [c for c in deletes if c in partition_columns] + to_drop = [c for c in deletes if c not in partition_columns] + + if skipped: + context.log.info( + f"Skipping delete for partition column(s) {skipped} on {table_name} " + f"— Delta Lake does not allow dropping partition columns." + ) + + context.log.info(f"Dropping columns: {to_drop}") + for col_name in to_drop: stmt = f"ALTER TABLE {table_name} DROP COLUMN `{col_name}`" context.log.info(f"Executing: {stmt}") spark.sql(stmt) - remove_column_id_props(spark, table_name, deletes) + remove_column_id_props(spark, table_name, to_drop) -def apply_renames_and_deletes( +def build_excluded_columns( + partition_columns: set[str], + updated_schema: StructType | None, + updated_id_map: dict[str, str], +) -> set[str]: + """Return columns that should be excluded from rename/delete detection.""" + caller_managed = ( + set(updated_schema.fieldNames()) if updated_schema is not None else set() + ) + excluded = set(partition_columns) + if caller_managed: + excluded |= {c for c in caller_managed if c not in updated_id_map} + return excluded + + +def bootstrap_orphan_columns( spark: SparkSession, table_name: str, - schema_name: str, + existing_id_map: dict[str, str], + updated_id_map: dict[str, str], + excluded: set[str], context: OpExecutionContext, -) -> bool: - """Detect and apply column renames and deletes to a Delta table based on the reference schema. +) -> None: + """Supplement existing_id_map for columns present in the table but lacking stored UUIDs.""" + current_columns = DeltaTable.forName(spark, table_name).toDF().schema.fieldNames() + for col_name in current_columns: + if col_name in excluded: + continue + if col_name not in existing_id_map: + if col_name in updated_id_map: + existing_id_map[col_name] = updated_id_map[col_name] + context.log.info( + f"Column '{col_name}' exists in table but has no stored UUID; " + f"matched to schema UUID '{updated_id_map[col_name]}'." + ) + else: + existing_id_map[col_name] = f"table_{col_name}" + context.log.info( + f"Column '{col_name}' exists in table but has no stored UUID " + f"and is not in the updated schema; " + f"tagging with synthetic ID for delete detection." + ) - Returns True if any schema change occurred (rename or delete). - IMPORTANT: This function ALWAYS supplements ``existing_id_map`` with any - table columns that lack stored UUIDs. This prevents the situation where a - partially-populated ``column_id_map`` causes some columns to be silently - ignored by rename/delete detection (which previously led to those columns - being dropped by the fallback path in :func:`sync_schema`, causing data - loss). +def filter_pk_changes( + existing_id_map: dict[str, str], + updated_id_map: dict[str, str], + renames: dict[str, str], + deletes: list[str], + primary_key_columns: set[str], + persisted_pk_uuids: set[str], + context: OpExecutionContext, +) -> tuple[dict[str, str], list[str], set[str], dict[str, str]]: + """Filter renames and deletes that touch primary-key columns. + + Returns (filtered_renames, filtered_deletes, blocked_renames, renames_original). """ - from src.utils.schema import get_schema_columns_with_id + pk_uuids = { + updated_id_map[name] for name in primary_key_columns if name in updated_id_map + } + pk_uuids |= persisted_pk_uuids + renames_original = dict(renames) + + blocked_renames: set[str] = set() + if pk_uuids: + blocked_renames = { + old_name + for old_name, new_name in renames.items() + if existing_id_map.get(old_name) in pk_uuids + } + if blocked_renames: + context.log.warning( + f"Blocked rename of primary key columns: {blocked_renames}" + ) + for old_name in blocked_renames: + del renames[old_name] - columns_with_id = get_schema_columns_with_id(spark, schema_name) - updated_id_map = {field.name: csv_id for csv_id, field in columns_with_id} - existing_id_map = get_stored_column_id_map(spark, table_name) + blocked_deletes = [ + col_name + for col_name in deletes + if existing_id_map.get(col_name) in pk_uuids + ] + if blocked_deletes: + context.log.warning( + f"Blocked delete of primary key columns: {blocked_deletes}" + ) + for col_name in blocked_deletes: + deletes.remove(col_name) - # If the table has NO stored UUIDs at all (pre-existing table that was - # created before UUID-based tracking was introduced), bootstrapping is - # needed first. We must NOT treat all columns as deletes — that would - # cause data loss by dropping every column. Instead, persist the current - # mapping now and skip rename/delete detection for this run. - if not existing_id_map: - context.log.info( - f"No stored column-ID mapping found for {table_name}. " - "Bootstrapping UUID props from current schema. " - "Rename/delete detection will be active from the next run onwards." - ) - persist_column_id_map(spark, table_name, schema_name) - return False + return renames, deletes, blocked_renames, renames_original - # Supplement existing_id_map with table columns that lack stored UUIDs. - # This handles columns added via mergeSchema after the initial bootstrap - # (e.g. ADD operations that ran before persist_column_id_map was called). - # Tag them with a synthetic ID so they are treated as deletes if they are - # no longer in the reference schema. - current_columns = DeltaTable.forName(spark, table_name).toDF().schema.fieldNames() - for col_name in current_columns: - if col_name not in existing_id_map: - existing_id_map[col_name] = f"table_{col_name}" - context.log.info( - f"Column '{col_name}' exists in table but has no stored UUID; " - f"tagging with synthetic ID for delete detection." - ) + +def detect_and_filter_changes( + spark: SparkSession, + table_name: str, + existing_id_map: dict[str, str], + updated_id_map: dict[str, str], + updated_schema: StructType | None, + partition_columns: set[str], + primary_key_columns: set[str], + persisted_pk_uuids: set[str], + context: OpExecutionContext, +) -> tuple[dict[str, str], list[str], set[str], dict[str, str]]: + """Detect renames/deletes, exclude partition/caller-managed cols, and filter PK changes.""" + excluded = build_excluded_columns(partition_columns, updated_schema, updated_id_map) + + if excluded: + for col_name in list(existing_id_map.keys()): + if col_name in excluded: + reason = ( + "partition column" + if col_name in partition_columns + else "caller-managed column" + ) + context.log.info( + f"Excluding '{col_name}' from rename/delete detection ({reason})." + ) + del existing_id_map[col_name] + for col_name in list(updated_id_map.keys()): + if col_name in excluded: + del updated_id_map[col_name] + + bootstrap_orphan_columns( + spark, table_name, existing_id_map, updated_id_map, excluded, context + ) renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + return filter_pk_changes( + existing_id_map, + updated_id_map, + renames, + deletes, + primary_key_columns, + persisted_pk_uuids, + context, + ) + + +def execute_renames_and_deletes( + spark: SparkSession, + table_name: str, + renames: dict[str, str], + deletes: list[str], + context: OpExecutionContext, +) -> None: + """Enable column mapping and execute any renames or deletes.""" context.log.info(f"Detected renames: {renames}") context.log.info(f"Detected deletes: {deletes}") @@ -538,7 +680,79 @@ def apply_renames_and_deletes( if renames or deletes: spark.catalog.refreshTable(table_name) - return bool(renames or deletes) + +def apply_renames_and_deletes( + spark: SparkSession, + table_name: str, + schema_name: str, + context: OpExecutionContext, + updated_schema: StructType | None = None, +) -> tuple[bool, set[str]]: + """Detect and apply column renames and deletes to a Delta table based on the reference schema.""" + + from src.utils.schema import get_schema_columns_with_id + + columns_with_id = get_schema_columns_with_id(spark, schema_name) + updated_id_map = {field.name: csv_id for csv_id, field in columns_with_id} + existing_id_map = get_stored_column_id_map(spark, table_name) + + if not existing_id_map: + context.log.info( + f"No stored column-ID mapping found for {table_name}. " + "Bootstrapping UUID props from current schema. " + "Rename/delete detection will be active from the next run onwards." + ) + persist_column_id_map(spark, table_name, schema_name) + return False, set() + + detail = spark.sql(f"DESCRIBE DETAIL {table_name}").collect()[0] + partition_columns: set[str] = set(detail["partitionColumns"] or []) + + # Discover primary key columns from the reference schema. + # The schemas Delta table has a ``primary_key`` boolean column. + primary_key_columns: set[str] = set() + try: + pk_rows = spark.sql( + f"SELECT name FROM schemas.{schema_name} WHERE primary_key = true" # nosec B608 + ).collect() + primary_key_columns = {row["name"] for row in pk_rows} + if primary_key_columns: + context.log.info( + f"Primary key columns detected: {sorted(primary_key_columns)}" + ) + except Exception: + # Schema table may not exist or lack the primary_key column + pass + + persisted_pk_uuids: set[str] = get_stored_pk_uuids(spark, table_name) + if persisted_pk_uuids: + context.log.info(f"Persisted PK UUIDs on table: {sorted(persisted_pk_uuids)}") + + renames, deletes, blocked_renames, renames_original = detect_and_filter_changes( + spark, + table_name, + existing_id_map, + updated_id_map, + updated_schema, + partition_columns, + primary_key_columns, + persisted_pk_uuids, + context, + ) + + execute_renames_and_deletes(spark, table_name, renames, deletes, context) + + # Collect the new names of blocked renames so that sync_schema can + # avoid adding them as new columns (since the rename was blocked). + blocked_new_names: set[str] = set() + if blocked_renames: + blocked_new_names = { + new_name + for old_name, new_name in renames_original.items() + if old_name in blocked_renames + } + + return bool(renames or deletes), blocked_new_names def persist_column_id_map( @@ -548,9 +762,58 @@ def persist_column_id_map( from src.utils.schema import get_schema_columns_with_id columns_with_id = get_schema_columns_with_id(spark, schema_name) - new_id_map = {field.name: csv_id for csv_id, field in columns_with_id} + # uuid -> schema_column_name + schema_uuid_to_name = {csv_id: field.name for csv_id, field in columns_with_id} + # schema_column_name -> uuid + schema_name_to_uuid = {field.name: csv_id for csv_id, field in columns_with_id} + + # Previous stored map (may have old names for blocked renames) + previous_map = get_stored_column_id_map(spark, table_name) + + spark.catalog.refreshTable(table_name) + table_columns = spark.table(table_name).columns + + new_id_map: dict[str, str] = {} + for col_name in table_columns: + # 1. If the column name exactly matches a schema column → use schema UUID + if col_name in schema_name_to_uuid: + new_id_map[col_name] = schema_name_to_uuid[col_name] + continue + + # 2. If the column has a stored UUID that maps to a schema column + # (rename was blocked, old table name with valid UUID) + prev_uuid = previous_map.get(col_name) + if prev_uuid and prev_uuid in schema_uuid_to_name: + new_id_map[col_name] = prev_uuid + continue + + # 3. Otherwise keep previous UUID if any, or generate new one + if prev_uuid: + new_id_map[col_name] = prev_uuid + else: + new_id_map[col_name] = str(uuid.uuid4()) + store_column_id_map(spark, table_name, new_id_map) + # Persist PK UUIDs for protection across schema removals. + # We accumulate (union) so a PK once registered stays protected. + try: + pk_rows = spark.sql( + f"SELECT name FROM schemas.{schema_name} WHERE primary_key = true" # nosec B608 + ).collect() + current_pk_uuids = { + schema_name_to_uuid[row["name"]] + for row in pk_rows + if row["name"] in schema_name_to_uuid + } + except Exception: + current_pk_uuids = set() + + previous_pk_uuids = get_stored_pk_uuids(spark, table_name) + merged_pk_uuids = previous_pk_uuids | current_pk_uuids + if merged_pk_uuids and merged_pk_uuids != previous_pk_uuids: + store_pk_uuids(spark, table_name, merged_pk_uuids) + def apply_datatype_changes( spark: SparkSession, @@ -644,9 +907,10 @@ def sync_schema( # 1. Detect and apply renames & deletes # ------------------------------------------------------------------ any_renames_deletes = False + blocked_new_names: set[str] = set() if schema_name is not None: - any_renames_deletes = apply_renames_and_deletes( - spark, table_name, schema_name, context + any_renames_deletes, blocked_new_names = apply_renames_and_deletes( + spark, table_name, schema_name, context, updated_schema=updated_schema ) # ------------------------------------------------------------------ @@ -656,6 +920,20 @@ def sync_schema( spark.catalog.refreshTable(table_name) existing_schema = spark.table(table_name).schema + # ------------------------------------------------------------------ + # 2a. Remove blocked-new names from updated_schema so they are not + # incorrectly added as new columns (e.g. a primary key rename + # that was blocked should not create a duplicate column). + # ------------------------------------------------------------------ + if blocked_new_names: + context.log.info( + f"Blocked rename targets excluded from add logic: {blocked_new_names}" + ) + filtered_fields = [ + f for f in updated_schema.fields if f.name not in blocked_new_names + ] + updated_schema = StructType(filtered_fields) + # ------------------------------------------------------------------ # 3. Detect added columns & datatype changes (existing logic) # ------------------------------------------------------------------ diff --git a/dagster/tests/utils/test_delta_sync_schema.py b/dagster/tests/utils/test_delta_sync_schema.py index ec00919e1..a314b83a7 100644 --- a/dagster/tests/utils/test_delta_sync_schema.py +++ b/dagster/tests/utils/test_delta_sync_schema.py @@ -142,21 +142,163 @@ def test_partial_mapping_still_detects_deletes(self): """When existing_id_map is PARTIALLY populated (some columns have stored UUIDs), columns that lack a UUID are given synthetic IDs and treated as deletes if they are absent from the updated schema. + + Columns whose name matches a key in updated_id_map get the REAL UUID instead, + preventing false deletes for unchanged business columns. """ existing_id_map = {"col_a": "id-1", "col_b": "id-2"} updated_id_map = {"col_a": "id-1", "col_b": "id-2"} current_table_columns = ["col_a", "col_b", "orphan_col"] - # Supplement with synthetic ID for orphan column + # Supplement: use real UUID when name matches, synthetic for orphans for col_name in current_table_columns: if col_name not in existing_id_map: - existing_id_map[col_name] = f"table_{col_name}" + if col_name in updated_id_map: + existing_id_map[col_name] = updated_id_map[col_name] + else: + existing_id_map[col_name] = f"table_{col_name}" renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) assert deletes == ["orphan_col"] assert renames == {} +class TestFalseDeletePrevention: + """Regression tests for the drop-and-re-add bug. + + Bug: When existing_id_map was partially populated, columns that existed + in BOTH the table AND the schema CSV (unchanged columns) were tagged with + synthetic IDs. The synthetic IDs never matched the real schema UUIDs, + causing them to be falsely detected as deletes, dropped, and then + immediately re-added by the sync_schema add-columns logic — losing all + data in those columns. + + Fix: When supplementing existing_id_map, if a column name matches a key + in updated_id_map, use the real schema UUID instead of a synthetic one. + """ + + @staticmethod + def _supplement_logic( + stored_id_map: dict, + updated_id_map: dict, + current_table_columns: list, + excluded: set | None = None, + ): + """Mirror the FIXED production supplementing logic.""" + existing = dict(stored_id_map) + excluded = excluded or set() + for col_name in current_table_columns: + if col_name in excluded: + continue + if col_name not in existing: + if col_name in updated_id_map: + existing[col_name] = updated_id_map[col_name] + else: + existing[col_name] = f"table_{col_name}" + return existing + + def test_unchanged_columns_not_falsely_deleted(self): + """Reproduces the exact user-reported bug. + + Table has: school_id_giga, longitude, school_name, num_students, latitude + Schema CSV has the same columns (with UUIDs). + Only school_id_giga has a stored UUID; the rest don't. + Expected: no renames, no deletes. + """ + stored_id_map = {"school_id_giga": "uuid-sig"} # partial: only one stored UUID + updated_id_map = { + "school_id_giga": "uuid-sig", + "longitude": "uuid-lon", + "school_name": "uuid-sn", + "num_students": "uuid-ns", + "latitude": "uuid-lat", + } + current_table_columns = [ + "school_id_giga", + "longitude", + "school_name", + "num_students", + "latitude", + ] + + existing = self._supplement_logic( + stored_id_map, updated_id_map, current_table_columns + ) + + renames, deletes = detect_renames_and_deletes(existing, updated_id_map) + # All columns should be recognized as unchanged (real UUIDs match) + assert renames == {} + assert deletes == [] + + def test_mixed_unchanged_and_orphan(self): + """Columns in both table and CSV are preserved; orphans are deleted.""" + stored_id_map = {"col_a": "uuid-a"} + updated_id_map = { + "col_a": "uuid-a", + "col_b": "uuid-b", # in table but no stored UUID + "col_c": "uuid-c", # in table but no stored UUID + } + current_table_columns = ["col_a", "col_b", "col_c", "orphan_col"] + + existing = self._supplement_logic( + stored_id_map, updated_id_map, current_table_columns + ) + + renames, deletes = detect_renames_and_deletes(existing, updated_id_map) + assert renames == {} + # Only orphan_col should be deleted (not col_b or col_c) + assert deletes == ["orphan_col"] + + def test_rename_still_works_with_fix(self): + """Renames are still detected when one column has a stored UUID + and the schema CSV has a new name for that UUID.""" + stored_id_map = {"old_name": "uuid-1", "col_b": "uuid-2"} + updated_id_map = { + "new_name": "uuid-1", # rename via UUID + "col_b": "uuid-2", + "col_c": "uuid-c", # in table but no stored UUID + } + current_table_columns = ["old_name", "col_b", "col_c"] + + existing = self._supplement_logic( + stored_id_map, updated_id_map, current_table_columns + ) + + renames, deletes = detect_renames_and_deletes(existing, updated_id_map) + assert renames == {"old_name": "new_name"} + assert deletes == [] + + def test_old_synthetic_logic_would_cause_false_deletes(self): + """Demonstrates that the OLD logic (always synthetic) causes false deletes.""" + stored_id_map = {"school_id_giga": "uuid-sig"} + updated_id_map = { + "school_id_giga": "uuid-sig", + "longitude": "uuid-lon", + "school_name": "uuid-sn", + } + current_table_columns = ["school_id_giga", "longitude", "school_name"] + + # OLD logic: always use synthetic ID + existing_buggy = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_buggy: + existing_buggy[col_name] = f"table_{col_name}" # BUG + + _, deletes_buggy = detect_renames_and_deletes(existing_buggy, updated_id_map) + # OLD: longitude and school_name are falsely detected as deletes + assert "longitude" in deletes_buggy + assert "school_name" in deletes_buggy + + # NEW logic: use real UUID when name matches + existing_fixed = self._supplement_logic( + stored_id_map, updated_id_map, current_table_columns + ) + + _, deletes_fixed = detect_renames_and_deletes(existing_fixed, updated_id_map) + # NEW: no false deletes + assert deletes_fixed == [] + + class TestMultipleOperations: """Test handling of multiple simultaneous add, rename, and delete operations.""" @@ -489,6 +631,357 @@ def test_removed_columns_in_stale(self): assert stale == ["col_to_delete"] +class TestPartitionColumnExclusion: + """Regression tests for DELTA_UNSUPPORTED_DROP_PARTITION_COLUMN bug. + + Staging tables have technical partition columns (upload_id, etc.) that are + never present in the business schema CSV. Before the fix these were tagged + with synthetic IDs and detected as deletes, causing: + ALTER TABLE ... DROP COLUMN `upload_id` + which Delta Lake rejects with DELTA_UNSUPPORTED_DROP_PARTITION_COLUMN. + + The fix: skip partition columns when supplementing existing_id_map. + """ + + def test_partition_columns_excluded_from_synthetic_tagging(self): + """Partition columns must not appear in the deletes list.""" + stored_id_map = {"electricity_type": "uuid-elec", "school_name": "uuid-name"} + updated_id_map = { + "electricity_type_test": "uuid-elec", + "school_name": "uuid-name", + } + partition_columns = {"upload_id"} + + # Simulate staging table columns including partition column + current_table_columns = [ + "electricity_type", + "school_name", + "upload_id", # partition column — must be skipped + "change_type", # technical column without UUID + ] + + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name in partition_columns: + continue # THE FIX: skip partition columns + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + assert "upload_id" not in deletes + assert renames == {"electricity_type": "electricity_type_test"} + assert deletes == ["change_type"] + + def test_partition_columns_not_in_deletes_without_fix(self): + """Demonstrates the bug: without the fix, upload_id would be in deletes.""" + stored_id_map = {"electricity_type": "uuid-elec"} + updated_id_map = {"electricity_type_test": "uuid-elec"} + current_table_columns = ["electricity_type", "upload_id"] + + # WITHOUT the fix (no partition exclusion) + existing_id_map_buggy = dict(stored_id_map) + for col_name in current_table_columns: + if col_name not in existing_id_map_buggy: + existing_id_map_buggy[col_name] = f"table_{col_name}" + + _, deletes_buggy = detect_renames_and_deletes( + existing_id_map_buggy, updated_id_map + ) + assert "upload_id" in deletes_buggy # proves the bug existed + + # WITH the fix (partition columns skipped) + partition_columns = {"upload_id"} + existing_id_map_fixed = dict(stored_id_map) + for col_name in current_table_columns: + if col_name in partition_columns: + continue + if col_name not in existing_id_map_fixed: + existing_id_map_fixed[col_name] = f"table_{col_name}" + + _, deletes_fixed = detect_renames_and_deletes( + existing_id_map_fixed, updated_id_map + ) + assert "upload_id" not in deletes_fixed # fix works + + def test_rollback_rename_with_partition_columns(self): + """The exact rollback scenario: electricity_type_test -> electricity_type. + + Staging table has partition column upload_id plus technical columns. + The rollback rename must succeed without attempting to drop upload_id. + """ + stored_id_map = { + "electricity_type_test": "uuid-elec", + "school_name": "uuid-name", + "change_type": "uuid-change", + } + updated_id_map = { + "electricity_type": "uuid-elec", # rollback rename + "school_name": "uuid-name", + "change_type": "uuid-change", + } + partition_columns = {"upload_id"} + current_table_columns = [ + "electricity_type_test", + "school_name", + "change_type", + "upload_id", # partition — must be excluded + "uploaded_columns", + "status", + ] + + existing_id_map = dict(stored_id_map) + for col_name in current_table_columns: + if col_name in partition_columns: + continue + if col_name not in existing_id_map: + existing_id_map[col_name] = f"table_{col_name}" + + renames, deletes = detect_renames_and_deletes(existing_id_map, updated_id_map) + + assert renames == {"electricity_type_test": "electricity_type"} + assert "upload_id" not in deletes + # uploaded_columns and status are technical orphans — correctly detected as deletes + # but NOT partition columns so Delta won't reject them (they can be dropped or ignored + # by execute_deletes' safety guard) + assert "upload_id" not in renames.keys() + assert "upload_id" not in renames.values() + + +class TestPartitionAndCallerManagedExclusion: + """Comprehensive coverage for the exclusion logic in apply_renames_and_deletes. + + These tests model the in-memory logic without touching Spark. Each test + simulates the EXACT logic in apply_renames_and_deletes after my fixes: + excluded = partition_columns | (caller_managed - business_csv) + # Strip excluded from existing_id_map and updated_id_map + # Skip excluded when supplementing synthetic IDs + """ + + @staticmethod + def _apply_logic( + stored_id_map: dict, + business_csv_map: dict, + current_table_columns: list, + partition_columns: set, + caller_managed_columns: set | None = None, + ): + """Mirror the production logic in apply_renames_and_deletes.""" + existing = dict(stored_id_map) + updated = dict(business_csv_map) + excluded = set(partition_columns) + if caller_managed_columns: + excluded |= {c for c in caller_managed_columns if c not in updated} + + # Strip excluded from both maps + for c in list(existing.keys()): + if c in excluded: + del existing[c] + for c in list(updated.keys()): + if c in excluded: + del updated[c] + + # Supplement with synthetic IDs (skip excluded) + for c in current_table_columns: + if c in excluded: + continue + if c not in existing: + existing[c] = f"table_{c}" + + return detect_renames_and_deletes(existing, updated) + + def test_multi_partition_table_all_protected(self): + """All partition columns (not just upload_id) must be protected.""" + # caller_managed_columns reflects the FUTURE (updated) schema + renames, deletes = self._apply_logic( + stored_id_map={"electricity_type": "u-elec"}, + business_csv_map={"electricity_type_test": "u-elec"}, + current_table_columns=[ + "electricity_type", + "year", + "month", + "country_code", + ], + partition_columns={"year", "month", "country_code"}, + caller_managed_columns={ + "electricity_type_test", # new name (from updated_schema) + "year", + "month", + "country_code", + }, + ) + assert renames == {"electricity_type": "electricity_type_test"} + assert "year" not in deletes + assert "month" not in deletes + assert "country_code" not in deletes + assert deletes == [] + + def test_partition_col_with_stored_uuid_not_dropped(self): + """Even if a partition column has a stored UUID (loophole A), it is excluded.""" + renames, deletes = self._apply_logic( + stored_id_map={ + "electricity_type": "u-elec", + "upload_id": "u-bogus-stored-uuid", # should never be acted on + }, + business_csv_map={"electricity_type": "u-elec"}, + current_table_columns=["electricity_type", "upload_id"], + partition_columns={"upload_id"}, + caller_managed_columns={"electricity_type", "upload_id"}, + ) + assert renames == {} + assert "upload_id" not in deletes + assert deletes == [] + + def test_partition_col_in_business_csv_is_ignored(self): + """If a partition column is mistakenly in the business CSV, exclusion still wins.""" + renames, deletes = self._apply_logic( + stored_id_map={"electricity_type": "u-elec"}, + business_csv_map={ + "electricity_type": "u-elec", + "upload_id": "u-bogus-csv", # mistakenly added to business CSV + }, + current_table_columns=["electricity_type", "upload_id"], + partition_columns={"upload_id"}, + caller_managed_columns={"electricity_type", "upload_id"}, + ) + assert renames == {} + assert deletes == [] + + def test_caller_managed_with_stored_uuid_not_dropped(self): + """Loophole A for technical columns: stored UUID for tech col must be ignored.""" + renames, deletes = self._apply_logic( + stored_id_map={ + "school_name": "u-sn", + "change_type": "u-bogus-tech", # tech col should never have UUID + }, + business_csv_map={"school_name": "u-sn"}, + current_table_columns=["school_name", "change_type", "status"], + partition_columns=set(), + caller_managed_columns={"school_name", "change_type", "status"}, + ) + assert renames == {} + assert "change_type" not in deletes + assert "status" not in deletes + assert deletes == [] + + def test_master_table_no_partitions_no_change(self): + """Master tables have no partitions and no tech cols — must work as before.""" + renames, deletes = self._apply_logic( + stored_id_map={ + "school_name": "u-sn", + "old_funding": "u-fund", + "to_drop": "u-drop", + }, + business_csv_map={ + "school_name": "u-sn", + "new_funding": "u-fund", + }, + current_table_columns=["school_name", "old_funding", "to_drop"], + partition_columns=set(), # no partitions + caller_managed_columns=None, # no updated_schema passed + ) + assert renames == {"old_funding": "new_funding"} + assert deletes == ["to_drop"] + + def test_orphan_business_col_still_detected(self): + """Business columns NOT in updated_schema NOR partitions are still subject to delete.""" + renames, deletes = self._apply_logic( + stored_id_map={"keep_me": "u-keep"}, + business_csv_map={"keep_me": "u-keep"}, + # orphan_biz is in the table but NOT in updated_schema (was supposed to be deleted) + current_table_columns=["keep_me", "orphan_biz", "upload_id"], + partition_columns={"upload_id"}, + caller_managed_columns={"keep_me", "upload_id"}, # orphan_biz NOT here + ) + # orphan_biz is not partition, not caller-managed → tagged synthetic → delete + assert renames == {} + assert deletes == ["orphan_biz"] + + def test_rollback_rename_full_scenario(self): + """End-to-end model of the manager's rollback scenario.""" + partition_cols = {"upload_id"} + tech_cols = { + "change_type", + "uploaded_columns", + "status", + "change_id", + "created_at", + "processed_at", + "approval_request_log_id", + "master_version", + } + # All 28 staging columns + current = ["school_id_giga", "school_name", "electricity_type_test"] + current += list(tech_cols) + list(partition_cols) + + # Stored UUIDs (only business) + stored = { + "school_id_giga": "u1", + "school_name": "u2", + "electricity_type_test": "u-elec", # current name in table + } + # New CSV: rollback rename + csv = { + "school_id_giga": "u1", + "school_name": "u2", + "electricity_type": "u-elec", # rollback to old name + } + # Updated schema (full pending) = business + tech + partition + updated_schema_names = set(csv.keys()) | tech_cols | partition_cols + + renames, deletes = self._apply_logic( + stored_id_map=stored, + business_csv_map=csv, + current_table_columns=current, + partition_columns=partition_cols, + caller_managed_columns=updated_schema_names, + ) + + assert renames == {"electricity_type_test": "electricity_type"} + assert deletes == [] + + def test_combined_multi_operations_with_partitions(self): + """Multi rename + multi delete + multi add, all with partition cols.""" + partition_cols = {"upload_id"} + tech_cols = {"change_type", "status"} + current = [ + "keep_a", + "old_b", + "old_c", # business + "to_drop_x", + "to_drop_y", # business deletes + "upload_id", # partition + "change_type", + "status", # tech + ] + stored = { + "keep_a": "u1", + "old_b": "u2", + "old_c": "u3", + "to_drop_x": "u4", + "to_drop_y": "u5", + } + csv = { + "keep_a": "u1", + "new_b": "u2", + "new_c": "u3", # 2 renames + "added_p": "u6", + "added_q": "u7", # 2 adds (handled elsewhere) + } + updated_schema_names = set(csv.keys()) | tech_cols | partition_cols + + renames, deletes = self._apply_logic( + stored_id_map=stored, + business_csv_map=csv, + current_table_columns=current, + partition_columns=partition_cols, + caller_managed_columns=updated_schema_names, + ) + + assert renames == {"old_b": "new_b", "old_c": "new_c"} + assert sorted(deletes) == ["to_drop_x", "to_drop_y"] + + class TestSyncSchemaSafeDropPath: """Tests for the sync_schema secondary drop path safety fix. @@ -535,3 +1028,152 @@ def test_safe_drop_path_without_schema_name(self): should_drop = list(removed_columns) assert should_drop == ["col_to_drop"] + + +class TestPrimaryKeyProtection: + """Unit tests for primary-key column protection in rename/delete detection. + + A primary key column must NEVER be renamed or deleted, even if the + schema CSV is updated with a different name for the same UUID. + """ + + def _apply_with_pk( + self, + stored_id_map: dict[str, str], + business_csv_map: dict[str, str], + pk_names: set[str], + persisted_pk_uuids: set[str] | None = None, + ) -> tuple[dict[str, str], list[str], set[str]]: + """Mimic the detection + PK-filtering logic from apply_renames_and_deletes. + + ``persisted_pk_uuids`` mimics the ``giga.pkColumnIds`` table property + that persists PK UUIDs even when the CSV removes the column. + """ + existing_id_map = dict(stored_id_map) + updated_id_map = dict(business_csv_map) + persisted_pk_uuids = set(persisted_pk_uuids or ()) + + # detect + id_to_name = {v: k for k, v in updated_id_map.items()} + renames: dict[str, str] = {} + deletes: list[str] = [] + for name, uid in existing_id_map.items(): + if uid in id_to_name and id_to_name[uid] != name: + renames[name] = id_to_name[uid] + elif uid not in id_to_name: + deletes.append(name) + + # PK filtering (same logic as production code). + # PK UUIDs come from BOTH the CSV-declared PK names and any persisted + # PK UUIDs (the latter mimics ``giga.pkColumnIds`` on the data table). + pk_uuids = {updated_id_map[n] for n in pk_names if n in updated_id_map} + pk_uuids |= persisted_pk_uuids + renames_original = dict(renames) + if pk_uuids: + blocked_renames = { + old_name + for old_name, new_name in renames.items() + if existing_id_map.get(old_name) in pk_uuids + } + for old_name in blocked_renames: + del renames[old_name] + blocked_deletes = [ + col_name + for col_name in deletes + if existing_id_map.get(col_name) in pk_uuids + ] + for col_name in blocked_deletes: + deletes.remove(col_name) + + # blocked new names to exclude from add logic + blocked_new_names = set() + if pk_uuids: + blocked_new_names = { + new_name + for old_name, new_name in renames_original.items() + if old_name + in { + o + for o, _ in renames_original.items() + if existing_id_map.get(o) in pk_uuids + } + } + + return renames, deletes, blocked_new_names + + def test_pk_rename_is_blocked(self): + """A primary key column rename must be blocked.""" + stored = {"school_id_giga": "u1", "latitude": "u2"} + csv = {"school_id_giga_renamed": "u1", "latitude": "u2"} + pk_names = {"school_id_giga_renamed"} + + renames, deletes, blocked = self._apply_with_pk(stored, csv, pk_names) + assert renames == {} + assert deletes == [] + assert blocked == {"school_id_giga_renamed"} + + def test_pk_delete_is_blocked(self): + """A primary key column delete must be blocked. + + Scenario: the schema CSV no longer contains the PK column at all, + so ``primary_key_columns`` (derived from the CSV) is empty. The + PK UUID is still recorded on the data table via + ``giga.pkColumnIds``; that persisted set must keep the column + protected from being dropped. + """ + stored = {"school_id_giga": "u1", "latitude": "u2"} + csv = {"latitude": "u2"} # school_id_giga entirely removed from CSV + pk_names: set[str] = set() # CSV has no PK markers + persisted_pks = {"u1"} # but UUID u1 was historically the PK + + renames, deletes, blocked = self._apply_with_pk( + stored, csv, pk_names, persisted_pks + ) + assert renames == {} + # PK delete must be blocked — school_id_giga stays + assert deletes == [] + assert blocked == set() + + def test_pk_delete_without_persisted_uuid_is_allowed(self): + """Without persisted PK UUIDs, the legacy behaviour is preserved. + + This documents the prior (pre-fix) behaviour for tables that have + not yet had ``giga.pkColumnIds`` written — the delete proceeds. + """ + stored = {"school_id_giga": "u1", "latitude": "u2"} + csv = {"latitude": "u2"} + pk_names: set[str] = set() + + renames, deletes, blocked = self._apply_with_pk( + stored, csv, pk_names, persisted_pk_uuids=set() + ) + assert renames == {} + # No PK protection available → delete proceeds + assert deletes == ["school_id_giga"] + assert blocked == set() + + def test_non_pk_rename_is_allowed(self): + """Non-PK column renames should still proceed normally.""" + stored = {"school_id_giga": "u1", "latitude": "u2"} + csv = {"school_id_giga": "u1", "lat_renamed": "u2"} + pk_names = {"school_id_giga"} + + renames, deletes, blocked = self._apply_with_pk(stored, csv, pk_names) + assert renames == {"latitude": "lat_renamed"} + assert deletes == [] + assert blocked == set() + + def test_pk_and_regular_rename_together(self): + """Mixed scenario: PK rename blocked, regular rename allowed.""" + stored = {"school_id_giga": "u1", "latitude": "u2", "old_col": "u3"} + csv = { + "school_id_giga_renamed": "u1", + "lat_renamed": "u2", + "new_col": "u3", + } + pk_names = {"school_id_giga_renamed"} + + renames, deletes, blocked = self._apply_with_pk(stored, csv, pk_names) + assert renames == {"latitude": "lat_renamed", "old_col": "new_col"} + assert deletes == [] + assert blocked == {"school_id_giga_renamed"} From 9d7986d5ecfde169e2393c0447661b0ff869c548 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Wed, 13 May 2026 10:24:18 +0530 Subject: [PATCH 25/26] fix: catch exception --- dagster/src/utils/delta.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 419323262..05d6d52d3 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -988,5 +988,10 @@ def sync_schema( except AnalysisException as exc: if "DELTA_CONSTRAINT_ALREADY_EXISTS" in str(exc): continue + elif "DELTA_NEW_CHECK_CONSTRAINT_VIOLATION" in str(exc): + context.log.warning( + f"Skipping NOT NULL constraint because existing data has nulls: {exc}" + ) + continue else: raise From becc0b7d50af9ffc7b6aaccfc72003bff79657a2 Mon Sep 17 00:00:00 2001 From: Bidhan Mondal Date: Thu, 14 May 2026 13:24:35 +0530 Subject: [PATCH 26/26] fix: added one time script for column id mapping --- .../migrations/bootstrap_column_id_maps.py | 191 ++++++++++++++++++ dagster/src/utils/delta.py | 11 +- 2 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 dagster/src/assets/migrations/bootstrap_column_id_maps.py diff --git a/dagster/src/assets/migrations/bootstrap_column_id_maps.py b/dagster/src/assets/migrations/bootstrap_column_id_maps.py new file mode 100644 index 000000000..8a3e7f663 --- /dev/null +++ b/dagster/src/assets/migrations/bootstrap_column_id_maps.py @@ -0,0 +1,191 @@ +"""One-time migration: bootstrap giga.columnId.* table properties for all existing +Delta tables that predate the UUID-based rename/delete detection feature. + +For tables where the schema CSV has already been renamed before this feature was +deployed, the script also detects the mismatch and renames the physical column to +match the current schema before persisting the mapping — eliminating the need for +manual per-table ALTER TABLE statements. + +Run once per environment after deploying the rename/delete detection feature. +Safe to re-run: tables that already have the mapping stored are skipped. +""" + +from dagster_pyspark import PySparkResource +from pyspark.sql import SparkSession +from src.constants import DataTier +from src.utils.delta import ( + enable_column_mapping, + get_stored_column_id_map, + persist_column_id_map, +) +from src.utils.schema import ( + construct_schema_name_for_tier, + get_schema_columns_with_id, +) + +from dagster import OpExecutionContext, asset + + +def _get_all_country_tables( + spark: SparkSession, + tier_schema: str, +) -> list[str]: + """Return all fully-qualified table names in a given schema.""" + if not spark.catalog.databaseExists(tier_schema): + return [] + rows = spark.sql(f"SHOW TABLES IN `{tier_schema}`").collect() + return [f"{tier_schema}.{row['tableName']}" for row in rows] + + +def _resolve_schema_name(dataset_type: str) -> str: + """Return the schemas-namespace table name for a dataset type, e.g. 'school_geolocation'.""" + return dataset_type + + +def bootstrap_table( + spark: SparkSession, + context: OpExecutionContext, + full_table_name: str, + schema_name: str, + dry_run: bool = False, +) -> None: + """Bootstrap or repair the column-ID mapping for a single Delta table. + + Steps: + 1. If the table already has a complete mapping, skip it. + 2. Build updated_id_map from the schema CSV (name -> uuid). + 3. For physical columns that have no stored UUID and no direct name match in + the schema, try to match them to unmatched schema columns by comparing + the set of columns that differ between table and schema. + If the mismatch is unambiguous (1 old name : 1 new name), rename the + physical column and proceed. If ambiguous, log a warning and skip. + 4. Call persist_column_id_map to store the final mapping. + """ + if not spark.catalog.tableExists(full_table_name): + context.log.info(f"Table {full_table_name} does not exist — skipping.") + return + + existing_map = get_stored_column_id_map(spark, full_table_name) + + try: + columns_with_id = get_schema_columns_with_id(spark, schema_name) + except Exception as exc: + context.log.warning( + f"Could not load schema '{schema_name}' for {full_table_name}: {exc} — skipping." + ) + return + + updated_id_map: dict[str, str] = { + field.name: csv_id for csv_id, field in columns_with_id + } + + spark.catalog.refreshTable(full_table_name) + physical_columns: list[str] = spark.table(full_table_name).columns + + # Columns in the table that have no stored UUID and no direct name match in schema + orphan_table_cols = [ + c for c in physical_columns if c not in existing_map and c not in updated_id_map + ] + # Schema columns that are not physically present in the table + missing_schema_cols = [c for c in updated_id_map if c not in physical_columns] + + if orphan_table_cols: + context.log.info( + f"{full_table_name}: orphan table columns (no UUID, no schema match): " + f"{orphan_table_cols}" + ) + context.log.info( + f"{full_table_name}: missing schema columns (in schema but not in table): " + f"{missing_schema_cols}" + ) + + if orphan_table_cols and missing_schema_cols: + if len(orphan_table_cols) == len(missing_schema_cols): + # Unambiguous 1-to-1 mismatch — safe to rename + renames = dict(zip(orphan_table_cols, missing_schema_cols, strict=False)) + context.log.info(f"{full_table_name}: detected probable renames: {renames}") + if not dry_run: + enable_column_mapping(spark, full_table_name) + for old_name, new_name in renames.items(): + stmt = ( + f"ALTER TABLE {full_table_name} " + f"RENAME COLUMN `{old_name}` TO `{new_name}`" + ) + context.log.info(f"Executing: {stmt}") + spark.sql(stmt) + spark.catalog.refreshTable(full_table_name) + else: + context.log.info( + f"[DRY RUN] Would rename columns in {full_table_name}: {renames}" + ) + else: + # Ambiguous — multiple columns differ; cannot safely auto-rename + context.log.warning( + f"{full_table_name}: ambiguous column mismatch " + f"(orphan table cols: {orphan_table_cols}, " + f"missing schema cols: {missing_schema_cols}). " + f"Cannot auto-rename. Manual intervention required." + ) + # Still persist what we can so the rest of the mapping is stored + elif not orphan_table_cols and not missing_schema_cols: + context.log.info( + f"{full_table_name}: all physical columns match the schema by name." + ) + + if not dry_run: + persist_column_id_map(spark, full_table_name, schema_name) + context.log.info(f"{full_table_name}: column-ID mapping persisted.") + else: + context.log.info( + f"[DRY RUN] Would persist column-ID mapping for {full_table_name}." + ) + + +@asset +def bootstrap_column_id_maps( + context: OpExecutionContext, + spark: PySparkResource, +) -> None: + """One-time asset: bootstrap giga.columnId.* props for all existing Delta tables. + + Set DRY_RUN = True to log what would happen without making any changes. + """ + DRY_RUN = False + + s: SparkSession = spark.spark_session + + # Map of (dataset_type, tier) -> schema_name used for rename/delete detection + # Add more entries here if other dataset types are introduced. + dataset_configs: list[tuple[str, DataTier | None]] = [ + ("school_geolocation", DataTier.SILVER), + ("school_geolocation", DataTier.STAGING), + ("school_geolocation", None), # master: schema = "school_master" + ] + + for dataset_type, tier in dataset_configs: + if tier is None: + tier_schema = "school_master" + schema_name = dataset_type + else: + tier_schema = construct_schema_name_for_tier( + f"school_{dataset_type.replace('school_', '')}", tier + ) + schema_name = dataset_type + + context.log.info( + f"Processing schema '{tier_schema}' with schema_name='{schema_name}' ..." + ) + table_names = _get_all_country_tables(s, tier_schema) + + if not table_names: + context.log.info(f"No tables found in '{tier_schema}' — skipping.") + continue + + for full_table_name in table_names: + bootstrap_table( + spark=s, + context=context, + full_table_name=full_table_name, + schema_name=schema_name, + dry_run=DRY_RUN, + ) diff --git a/dagster/src/utils/delta.py b/dagster/src/utils/delta.py index 05d6d52d3..09d52e7a9 100644 --- a/dagster/src/utils/delta.py +++ b/dagster/src/utils/delta.py @@ -552,10 +552,15 @@ def bootstrap_orphan_columns( ) else: existing_id_map[col_name] = f"table_{col_name}" - context.log.info( + context.log.warning( f"Column '{col_name}' exists in table but has no stored UUID " - f"and is not in the updated schema; " - f"tagging with synthetic ID for delete detection." + f"and is not in the updated schema. " + f"This can happen when a schema rename occurred before the column-ID " + f"mapping was ever persisted for this table. " + f"The column will NOT be auto-renamed. " + f"To resolve: manually run " + f"`ALTER TABLE {table_name} RENAME COLUMN `{col_name}` TO ` " + f"and then re-run persist_column_id_map for this table." )