Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use hash_graph_store::{
entity::EntityQueryPath,
subgraph::edges::{EdgeDirection, KnowledgeGraphEdgeKind},
entity_type::EntityTypeQueryPath,
subgraph::edges::{EdgeDirection, KnowledgeGraphEdgeKind, SharedEdgeKind},
};
use serde::Deserialize as _;
use tokio_postgres::Row;
Expand Down Expand Up @@ -227,8 +228,11 @@ impl PostgresRecord for Entity {
),

edition_id: compiler.add_selection_path(&EntityQueryPath::EditionId),
type_versioned_urls_id: compiler
.add_selection_path(&EntityQueryPath::TypeVersionedUrls),
type_versioned_urls_id: compiler.add_selection_path(&EntityQueryPath::EntityTypeEdge {
edge_kind: SharedEdgeKind::IsOfType,
path: EntityTypeQueryPath::VersionedUrl,
inheritance_depth: None,
}),
direct_type_count_id: compiler.add_selection_path(&EntityQueryPath::DirectTypeCount),

properties: compiler.add_selection_path(&EntityQueryPath::Properties(None)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use std::collections::HashMap;
use error_stack::{Report, ResultExt as _};
use hash_graph_store::{
entity::{EntityQueryPath, QueryEntitiesParams},
entity_type::EntityTypeQueryPath,
error::QueryError,
subgraph::edges::SharedEdgeKind,
};
use tokio_postgres::Row;
use type_system::{
Expand Down Expand Up @@ -112,7 +114,11 @@ impl EntitySummaryQuery {
.then(|| compiler.add_selection_path(&EntityQueryPath::EditionProvenance(None))),
type_columns: (params.include_type_ids || params.include_type_titles).then(|| {
(
compiler.add_selection_path(&EntityQueryPath::TypeVersionedUrls),
compiler.add_selection_path(&EntityQueryPath::EntityTypeEdge {
edge_kind: SharedEdgeKind::IsOfType,
path: EntityTypeQueryPath::VersionedUrl,
inheritance_depth: None,
}),
compiler.add_selection_path(&EntityQueryPath::DirectTypeCount),
)
}),
Expand Down
228 changes: 215 additions & 13 deletions libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use postgres_types::ToSql;
use tracing::instrument;
use type_system::knowledge::Entity;

use super::expression::{JoinType, TableName, TableReference};
use super::expression::{ColumnReference, JoinType, TableName, TableReference};
use crate::store::postgres::query::{
Alias, Column, Distinctness, EqualityOperator, Expression, Function, Identifier,
PostgresQueryPath, PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _,
Expand Down Expand Up @@ -92,6 +92,8 @@ pub enum SelectCompilerError {
UnsupportedDistanceExpression,
#[display("Cannot add a cursor: {reason}")]
CursorDisallowed { reason: &'static str },
#[display("String operations are not supported on paths backed by materialized array columns")]
UnsupportedTextArrayOperation,
}

impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
Expand Down Expand Up @@ -429,18 +431,8 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
}

Ok(match filter {
Filter::All(filters) => Expression::all(
filters
.iter()
.map(|filter| self.compile_filter(filter))
.collect::<Result<_, _>>()?,
),
Filter::Any(filters) => Expression::any(
filters
.iter()
.map(|filter| self.compile_filter(filter))
.collect::<Result<_, _>>()?,
),
Filter::All(filters) => Expression::all(self.compile_filter_group(filters, true)?),
Filter::Any(filters) => Expression::any(self.compile_filter_group(filters, false)?),
Filter::Not(filter) => self.compile_filter(filter)?.not(),
Filter::Equal(lhs, rhs) => Expression::equal(
self.compile_filter_expression(lhs).0,
Expand Down Expand Up @@ -644,6 +636,8 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
self.compile_filter_expression_list(rhs).0,
),
Filter::StartsWith(lhs, rhs) => {
Self::ensure_scalar_text_operand(lhs)?;
Self::ensure_scalar_text_operand(rhs)?;
let (left_filter, left_parameter) = self.compile_filter_expression(lhs);
let left_filter = if left_parameter == ParameterType::Any {
Expression::Function(Function::JsonExtractText(Box::new(left_filter)))
Expand All @@ -661,6 +655,8 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
Expression::starts_with(left_filter, right_filter)
}
Filter::EndsWith(lhs, rhs) => {
Self::ensure_scalar_text_operand(lhs)?;
Self::ensure_scalar_text_operand(rhs)?;
let (left_filter, left_parameter) = self.compile_filter_expression(lhs);
let left_filter = if left_parameter == ParameterType::Any {
Expression::Function(Function::JsonExtractText(Box::new(left_filter)))
Expand All @@ -678,6 +674,8 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
Expression::ends_with(left_filter, right_filter)
}
Filter::ContainsSegment(lhs, rhs) => {
Self::ensure_scalar_text_operand(lhs)?;
Self::ensure_scalar_text_operand(rhs)?;
let (left_filter, left_parameter) = self.compile_filter_expression(lhs);
let left_filter = if left_parameter == ParameterType::Any {
Expression::Function(Function::JsonExtractText(Box::new(left_filter)))
Expand All @@ -697,6 +695,202 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
})
}

/// Rejects operands on paths terminating in materialized text-array columns.
///
/// Equality filters on such paths compile to array predicates, but string operations
/// have no scalar column to operate on.
fn ensure_scalar_text_operand<'f: 'q>(
operand: &FilterExpression<'f, R>,
) -> Result<(), Report<SelectCompilerError>>
where
R::QueryPath<'f>: PostgresQueryPath,
{
if let FilterExpression::Path { path } = operand {
let (column, json_field) = path.terminating_column();
ensure!(
json_field.is_some() || !Self::is_text_array_column(column),
SelectCompilerError::UnsupportedTextArrayOperation
);
}
Ok(())
}

/// Whether the column holds an array of textual values ([`BaseUrl`] and
/// [`VersionedUrl`] columns transpile to `text[]`).
///
/// [`BaseUrl`]: ParameterType::BaseUrl
/// [`VersionedUrl`]: ParameterType::VersionedUrl
fn is_text_array_column(column: Column) -> bool {
matches!(
column.parameter_type(),
ParameterType::Vector(inner) if matches!(
*inner,
ParameterType::Text | ParameterType::BaseUrl | ParameterType::VersionedUrl
)
)
}

/// Decomposes an equality (`Equal`/`NotEqual`) or membership (`In(parameter, path)`)
/// filter on a path terminating in a materialized text-array column.
///
/// Returns the path, the text parameter, and whether the filter tests for containment
/// (`true`) or its absence (`false`).
fn cached_array_equality<'f: 'q>(
filter: &'p Filter<'f, R>,
) -> Option<(&'p R::QueryPath<'f>, &'p Parameter<'f>, bool)>
where
R::QueryPath<'f>: PostgresQueryPath,
{
let (lhs, rhs, equals) = match filter {
Filter::Equal(lhs, rhs) => (lhs, rhs, true),
Filter::NotEqual(lhs, rhs) => (lhs, rhs, false),
Filter::In(
FilterExpression::Parameter {
parameter: parameter @ Parameter::Text(_),
convert: None,
},
FilterExpressionList::Path { path },
) => {
let (column, json_field) = path.terminating_column();
return (json_field.is_none() && Self::is_text_array_column(column))
.then_some((path, parameter, true));
}
Filter::All(_)
| Filter::Any(_)
| Filter::Not(_)
| Filter::Exists { .. }
| Filter::Greater(..)
| Filter::GreaterOrEqual(..)
| Filter::Less(..)
| Filter::LessOrEqual(..)
| Filter::CosineDistance(..)
| Filter::In(..)
| Filter::StartsWith(..)
| Filter::EndsWith(..)
| Filter::ContainsSegment(..) => return None,
};
match (lhs, rhs) {
(
FilterExpression::Path { path },
FilterExpression::Parameter {
parameter: parameter @ Parameter::Text(_),
convert: None,
},
)
| (
FilterExpression::Parameter {
parameter: parameter @ Parameter::Text(_),
convert: None,
},
FilterExpression::Path { path },
) => {
let (column, json_field) = path.terminating_column();
(json_field.is_none() && Self::is_text_array_column(column))
.then_some((path, parameter, equals))
}
_ => None,
}
}

/// Compiles equality filters on a path backed by a materialized array column into a
/// single array predicate on that column.
///
/// A single parameter compiles to a containment check (`<column> @> ARRAY[$n]::text[]`,
/// negated for inequalities). Multiple parameters gathered from one `All`/`Any` group
/// bundle into one predicate over the whole value set:
///
/// | group | equalities | inequalities |
/// |-------|---------------------|---------------------------|
/// | `All` | `@>` (contains all) | `NOT(&&)` (contains none) |
/// | `Any` | `&&` (contains any) | `NOT(@>)` (misses one) |
///
/// A single array predicate replaces per-value joins through the type tables and lets
/// a GIN index on the materialized column serve the positive forms.
fn compile_cached_array_predicate<'f: 'q>(
&mut self,
column: ColumnReference<'static>,
parameters: &[&'p Parameter<'f>],
equals: bool,
conjunction: bool,
) -> Expression
where
R::QueryPath<'f>: PostgresQueryPath,
{
let column_reference = Expression::ColumnReference(column);
let array = Expression::Function(Function::ArrayLiteral {
elements: parameters
.iter()
.map(|parameter| self.compile_parameter(parameter).0)
.collect(),
element_type: PostgresType::Text,
});
// For a single value `@>` and `&&` coincide; normalize to the containment form.
let conjunction = if parameters.len() == 1 {
equals
} else {
conjunction
};
match (equals, conjunction) {
(true, true) => Expression::array_contains(column_reference, array),
(true, false) => Expression::overlap(column_reference, array),
(false, true) => Expression::overlap(column_reference, array).not(),
(false, false) => Expression::array_contains(column_reference, array).not(),
}
}

/// Compiles the filters of an `All`/`Any` group, bundling equality filters backed by
/// the same materialized array column into a single array predicate.
///
/// Bundles are keyed on the *aliased* column: paths terminating in the same column
/// through different join chains (e.g. an entity's own types vs. a linked entity's
/// types) resolve to different aliases and stay separate predicates.
fn compile_filter_group<'f: 'q>(
&mut self,
filters: &'p [Filter<'f, R>],
conjunction: bool,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to move this to an enum instead to increase readability?

) -> Result<Vec<Expression>, Report<SelectCompilerError>>
where
R::QueryPath<'f>: PostgresQueryPath,
{
struct ArrayPredicateGroup<'c, 'p> {
column: ColumnReference<'static>,
equals: bool,
parameters: Vec<&'c Parameter<'p>>,
}

let mut groups: Vec<ArrayPredicateGroup<'p, 'f>> = Vec::new();
let mut expressions = Vec::new();
for filter in filters {
if let Some((array_path, parameter, equals)) = Self::cached_array_equality(filter) {
let alias = self.add_join_statements(array_path);
let column = array_path.terminating_column().0.aliased(alias);
if let Some(group) = groups
.iter_mut()
.find(|group| group.column == column && group.equals == equals)
{
group.parameters.push(parameter);
} else {
groups.push(ArrayPredicateGroup {
column,
equals,
parameters: vec![parameter],
});
}
} else {
expressions.push(self.compile_filter(filter)?);
}
}
for group in &groups {
expressions.push(self.compile_cached_array_predicate(
group.column.clone(),
&group.parameters,
group.equals,
conjunction,
));
}
Ok(expressions)
}

/// Compiles the `path` to a condition, which is searching for the latest version.
// Warning: This adds a CTE to the statement, which is overwriting the `ontology_ids` table.
// When more CTEs are needed, a test should be added to cover both CTEs in one
Expand Down Expand Up @@ -770,10 +964,18 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
///
/// The following [`Filter`]s will be special cased:
/// - Comparing the `"version"` field on [`Table::OntologyIds`] with `"latest"` for equality.
/// - Equality and membership filters on paths terminating in materialized text-array columns,
/// compiled to array predicates (see [`Self::compile_cached_array_predicate`]).
fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option<Expression>
where
R::QueryPath<'f>: PostgresQueryPath,
{
if let Some((array_path, parameter, equals)) = Self::cached_array_equality(filter) {
let alias = self.add_join_statements(array_path);
let column = array_path.terminating_column().0.aliased(alias);
return Some(self.compile_cached_array_predicate(column, &[parameter], equals, true));
}

match filter {
Filter::Equal(lhs, rhs) | Filter::NotEqual(lhs, rhs) => match (lhs, rhs) {
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ impl PostgresQueryPath for EntityQueryPath<'_> {
| Self::PropertyMetadata(_) => {
vec![Relation::EntityEditions]
}
Self::TypeBaseUrls | Self::TypeVersionedUrls | Self::DirectTypeCount => {
Self::DirectTypeCount
| Self::EntityTypeEdge {
edge_kind: SharedEdgeKind::IsOfType,
path: EntityTypeQueryPath::BaseUrl | EntityTypeQueryPath::VersionedUrl,
inheritance_depth: None,
} => {
vec![Relation::EntityEditionCache]
}
Self::EntityTypeEdge {
Expand Down Expand Up @@ -122,11 +127,19 @@ impl PostgresQueryPath for EntityQueryPath<'_> {
),
Self::Archived => (Column::EntityEditions(EntityEditions::Archived), None),
Self::Embedding => (Column::EntityEmbeddings(EntityEmbeddings::Embedding), None),
Self::TypeBaseUrls => (
Self::EntityTypeEdge {
edge_kind: SharedEdgeKind::IsOfType,
path: EntityTypeQueryPath::BaseUrl,
inheritance_depth: None,
} => (
Column::EntityEditionCache(EntityEditionCache::BaseUrls),
None,
),
Self::TypeVersionedUrls => (
Self::EntityTypeEdge {
edge_kind: SharedEdgeKind::IsOfType,
path: EntityTypeQueryPath::VersionedUrl,
inheritance_depth: None,
} => (
Column::EntityEditionCache(EntityEditionCache::VersionedUrls),
None,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum BinaryOperator {
// --- Domain-specific ---
/// `<lhs> @> <rhs>::TIMESTAMPTZ`
TimeIntervalContainsTimestamp,
/// `<lhs> @> <rhs>`
ArrayContains,
/// `<lhs> && <rhs>`
Overlap,
/// `<lhs> <=> <rhs>`
Expand All @@ -76,7 +78,7 @@ impl BinaryOperator {
Self::BitwiseOr => " | ",
Self::JsonAccess => " -> ",
Self::JsonAccessAsText => " ->> ",
Self::TimeIntervalContainsTimestamp => " @> ",
Self::TimeIntervalContainsTimestamp | Self::ArrayContains => " @> ",
Self::Overlap => " && ",
Self::CosineDistance => " <=> ",
};
Expand All @@ -102,6 +104,7 @@ impl BinaryOperator {
| Self::BitwiseOr
| Self::JsonAccess
| Self::JsonAccessAsText
| Self::ArrayContains
| Self::Overlap
| Self::CosineDistance => Ok(()),
}
Expand Down
Loading
Loading