Skip to content
48 changes: 47 additions & 1 deletion src/country_workspace/contrib/hope/push/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from country_workspace.contrib.hope.exceptions import HopePushError
from country_workspace.exceptions import RemoteError, RemoteUnavailableError
from country_workspace.models import AsyncJob, Rdp
from country_workspace.models import AsyncJob, Household, Individual, Rdp
from country_workspace.workspaces.models import CountryIndividual

from .config import CreateRdpConfig, PushWorkflowConfig
from .policy import ActionCheck, get_rdp_policy
Expand Down Expand Up @@ -60,6 +61,50 @@ def _deduplication_snapshot(status: DedupClientStatus | None) -> dict[str, Any]:
}


def archive_removed_unique_values(rdp: Rdp, is_master_detail: bool) -> None:
"""Persist unique-field values for records about to be marked as removed after a successful push."""
program = rdp.program
owner = selection_owner_for_rdp(rdp=rdp)

if is_master_detail:
hh_field = program.get_unique_field_for(Household)
ind_field = program.get_unique_field_for(Individual)
if not (hh_field or ind_field):
return
hh_pks = list(owner.households.filter(removed=False).values_list("pk", flat=True))
if not hh_pks:
return
if hh_field:
hh_values = owner.households.filter(pk__in=hh_pks).values_list(f"flex_fields__{hh_field}", flat=True)
program.add_removed_unique_values_for(Household, hh_values.iterator())
if ind_field:
ind_values = CountryIndividual.objects.filter(household_id__in=hh_pks, removed=False).values_list(
f"flex_fields__{ind_field}", flat=True
)
program.add_removed_unique_values_for(Individual, ind_values.iterator())
return

if ind_field := program.get_unique_field_for(Individual):
ind_values = owner.individuals.filter(removed=False).values_list(f"flex_fields__{ind_field}", flat=True)
program.add_removed_unique_values_for(Individual, ind_values.iterator())


def steps(processor: PushProcessor, config: PushWorkflowConfig) -> Iterator[Callable[[], None]]:
"""Yield the ordered workflow callables; each step appends errors to processor.total."""
pks = config["pks"]

yield processor.preflight
yield processor.rdi_create
if config["master_detail"]:
yield from (
partial(processor.run_with, qs_individuals_by_household_pks(pks), processor.rdi_push_individuals),
partial(processor.run_with, qs_households(pks=pks), processor.rdi_push_households),
)
else:
yield partial(processor.run_with, qs_individuals_by_pks(pks), processor.rdi_push_people)
yield processor.rdi_complete


def _save_current_deduplication_snapshot(*, rdp: Rdp, key: str) -> None:
status = get_rdp_policy(rdp).deduplication_status(rdp)
snapshot = _deduplication_snapshot(status)
Expand Down Expand Up @@ -318,6 +363,7 @@ def push_existing_rdp_core(job: AsyncJob) -> dict[str, Any]:

with transaction.atomic():
locked = lock_rdp_for_update(pk=rdp.pk)
archive_removed_unique_values(locked, config["master_detail"])
_mark_rdp_beneficiaries_removed(locked, config["master_detail"])
set_rdp_push_status(
rdp=locked,
Expand Down
70 changes: 70 additions & 0 deletions src/country_workspace/models/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,73 @@ def apply_default_fields(self, m: type[Validable] | Validable, data: dict[str, A
if field_name not in data or data[field_name] is None:
data[field_name] = default_value
return data

def has_any_data(self) -> bool:
if not self.pk:
return False
from country_workspace.models import Batch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not in the beginning of module?


return Batch.objects.filter(program_id=self.pk).exists()

def get_unique_field_for(self, m: type[Validable] | Validable) -> str | None:
scope = self._scope_for(m).value
unique_fields = (self.system_fields or {}).get("unique_fields") or {}
value = unique_fields.get(scope)
return value if isinstance(value, str) and value.strip() else None

def save_unique_field_for(self, m: type[Validable] | Validable, field_name: str | None) -> None:
scope = self._scope_for(m).value
normalized = (field_name or "").strip() or None

system_fields = dict(self.system_fields or {})
unique_fields = dict(system_fields.get("unique_fields") or {})
removed_unique_values = dict(system_fields.get("removed_unique_values") or {})
scope_removed_values = dict(removed_unique_values.get(scope) or {})

if normalized is None:
unique_fields.pop(scope, None)
else:
unique_fields[scope] = normalized
scope_removed_values.setdefault(normalized, [])
removed_unique_values[scope] = scope_removed_values

system_fields["unique_fields"] = unique_fields
system_fields["removed_unique_values"] = removed_unique_values
self.system_fields = system_fields
self.save(update_fields=["system_fields"])

def get_removed_unique_values_for(self, m: type[Validable] | Validable) -> list[str]:
"""Return archived unique values for configured scope+field."""
if not (field_name := self.get_unique_field_for(m)):
return []

scope = self._scope_for(m).value
removed_unique_values = (self.system_fields or {}).get("removed_unique_values") or {}
scope_values = removed_unique_values.get(scope) or {}
values = scope_values.get(field_name) or []
if not isinstance(values, list):
return []
return [str(value) for value in values if value is not None and str(value).strip()]

def add_removed_unique_values_for(self, m: type[Validable] | Validable, values: Iterable[Any]) -> None:
"""Merge removed values for configured unique field in the given scope."""
if not (field_name := self.get_unique_field_for(m)):
return
scope = self._scope_for(m).value

normalized_values = {str(value).strip() for value in values if value is not None and str(value).strip()}
if not normalized_values:
return

system_fields = dict(self.system_fields or {})
removed_unique_values = dict(system_fields.get("removed_unique_values") or {})
scope_values = dict(removed_unique_values.get(scope) or {})
existing_values = scope_values.get(field_name) or []
existing_set = {str(value).strip() for value in existing_values if value is not None and str(value).strip()}

scope_values[field_name] = sorted(existing_set | normalized_values)
removed_unique_values[scope] = scope_values
system_fields["removed_unique_values"] = removed_unique_values

self.system_fields = system_fields
self.save(update_fields=["system_fields"])
124 changes: 115 additions & 9 deletions src/country_workspace/workspaces/admin/cleaners/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,86 @@
from constance import config
from django.db.models import Model, QuerySet, Prefetch
from django.db.models.query import prefetch_related_objects
from django.utils import timezone

from country_workspace.context import batch_ctx
from country_workspace.models import AsyncJob, Household, Individual, Program
from country_workspace.state import state
from country_workspace.utils.imports import validate_alien_fields

logger = logging.getLogger(__name__)
UNIQUE_VALIDATION_ERROR = "Value must be unique within the programme."
ARCHIVED_UNIQUE_VALIDATION_ERROR = "Value must be unique and cannot match previously pushed records."


def _normalize_unique_value(value: object) -> str | None:
normalized = str(value).strip() if value is not None else ""
return normalized or None


def _append_unique_error(obj: Model, field_name: str, message: str) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each call to _append_unique_error issues an individual UPDATE query to db. Can we use bulk_update instead?

errors = dict(getattr(obj, "errors", {}) or {})
current = errors.get(field_name) or []
if not isinstance(current, list):
current = [str(current)]
if message in current:
return
current.append(message)
errors[field_name] = current
obj.errors = errors
obj.last_checked = timezone.now()
obj.save(update_fields=["errors", "last_checked"])


def _append_household_member_invalid_error(obj: Model) -> None:
errors = dict(getattr(obj, "errors", {}) or {})
details = errors.get("dct") or []
if not isinstance(details, list):
details = [str(details)]
marker = "Some members did not validate"
if marker in details:
return
details.append(marker)
errors["dct"] = details
obj.errors = errors
obj.last_checked = timezone.now()
obj.save(update_fields=["errors", "last_checked"])


class UniqueValidationState:
def __init__(self, *, field_name: str, archived_values: set[str]) -> None:
self.field_name = field_name
self.archived_values = archived_values
self.seen_by_value: dict[str, Model] = {}

def validate(self, obj: Model) -> set[int]:
invalid_pks: set[int] = set()
flex_fields = getattr(obj, "flex_fields", {}) or {}
value = _normalize_unique_value(flex_fields.get(self.field_name))
if not value:
return invalid_pks

if value in self.archived_values:
_append_unique_error(obj, self.field_name, ARCHIVED_UNIQUE_VALIDATION_ERROR)
invalid_pks.add(obj.pk)
return invalid_pks

if previous := self.seen_by_value.get(value):
_append_unique_error(previous, self.field_name, UNIQUE_VALIDATION_ERROR)
_append_unique_error(obj, self.field_name, UNIQUE_VALIDATION_ERROR)
invalid_pks.add(previous.pk)
invalid_pks.add(obj.pk)
return invalid_pks

self.seen_by_value[value] = obj
return invalid_pks


def _build_unique_state(program: Program, model: type[Model]) -> UniqueValidationState | None:
if not (field_name := program.get_unique_field_for(model)):
return None
archived_values = {value for value in program.get_removed_unique_values_for(model) if value}
return UniqueValidationState(field_name=field_name, archived_values=archived_values)


def validate_queryset(queryset: QuerySet[Model], chunk_size: int = 2000, **kwargs: Any) -> dict[str, int]:
Expand All @@ -27,7 +100,9 @@ def validate_queryset(queryset: QuerySet[Model], chunk_size: int = 2000, **kwarg
return {"valid": valid, "invalid": invalid}

with state.set(tenant=first.country_office, program=first.program):
unique_state = _build_unique_state(first.program, queryset.model)
if issubclass(queryset.model, Household):
individual_unique_state = _build_unique_state(first.program, Individual)
# Reverse-FK prefetch for Household.members; include forward FKs for Individuals
prefetch_members = Prefetch(
"members",
Expand All @@ -39,11 +114,17 @@ def validate_queryset(queryset: QuerySet[Model], chunk_size: int = 2000, **kwarg
for chunk in batched(it, chunk_size):
# Populate members for all objects in this batch (no N+1 on members access).
prefetch_related_objects(chunk, prefetch_members)
dv, di = _validate_and_count(chunk)
dv, di = _validate_and_count(
chunk,
unique_state=unique_state,
member_unique_state=individual_unique_state,
)
valid, invalid = valid + dv, invalid + di
else: # Individual
# Just stream.
dv, di = _validate_and_count(queryset.iterator(chunk_size=chunk_size)) # stream rows from DB
dv, di = _validate_and_count(
queryset.iterator(chunk_size=chunk_size), unique_state=unique_state
) # stream rows from DB
valid, invalid = valid + dv, invalid + di

except Exception as e: # pragma: no cover
Expand All @@ -53,21 +134,46 @@ def validate_queryset(queryset: QuerySet[Model], chunk_size: int = 2000, **kwarg
return {"valid": valid, "invalid": invalid}


def _validate_and_count(objs: Iterable[Model]) -> tuple[int, int]:
valid = invalid = 0
def _validate_and_count( # noqa: C901
objs: Iterable[Model],
unique_state: UniqueValidationState | None = None,
member_unique_state: UniqueValidationState | None = None,
) -> tuple[int, int]:
total = 0
invalid_pks: set[int] = set()
member_household_by_member_pk: dict[int, int] = {}
aliens_checked = False

for obj in objs:
total += 1
if not aliens_checked:
validate_alien_fields(obj)
aliens_checked = True

with batch_ctx(obj.batch_id):
if obj.validate_with_checker():
valid += 1
else:
invalid += 1

if not obj.validate_with_checker():
invalid_pks.add(obj.pk)
if unique_state:
invalid_pks |= unique_state.validate(obj)
if member_unique_state and isinstance(obj, Household):
member_invalid = False
for member in obj.members.all():
member_household_by_member_pk[member.pk] = obj.pk
invalid_member_pks = member_unique_state.validate(member)
if not invalid_member_pks:
continue

member_invalid = True
for member_pk in invalid_member_pks:
if household_pk := member_household_by_member_pk.get(member_pk):
invalid_pks.add(household_pk)

if member_invalid:
invalid_pks.add(obj.pk)
_append_household_member_invalid_error(obj)

invalid = len(invalid_pks)
valid = total - invalid
return valid, invalid


Expand Down
Loading