Skip to content
42 changes: 36 additions & 6 deletions psydac/api/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,40 @@ def discretize_space(V, domain_h, *, degree=None, multiplicity=None, knots=None,
assert len(ncells) == len(periodic) == len(degree_i) == len(multiplicity_i) == len(min_coords) == len(max_coords)

if knots is None:
# Create uniform grid
grids = [np.linspace(xmin, xmax, num=ne + 1)
for xmin, xmax, ne in zip(min_coords, max_coords, ncells)]
# Check if grid is provided in domain_h
if hasattr(domain_h, 'grid') and domain_h.grid is not None:
# Use provided grid of breakpoints
grids = domain_h.grid

# Safety checks for grid consistency
if not isinstance(grids, (list, tuple)):
raise TypeError(f"Grid must be a list or tuple, got {type(grids)}")

if len(grids) != len(ncells):
raise ValueError(f"Grid dimensions ({len(grids)}) must match domain dimensions ({len(ncells)})")

for dim, (grid, nc, xmin, xmax) in enumerate(zip(grids, ncells, min_coords, max_coords)):
grid = np.asarray(grid)

# Check grid length vs ncells consistency
expected_grid_length = nc + 1
if len(grid) != expected_grid_length:
raise ValueError(f"Dimension {dim}: grid length ({len(grid)}) must be ncells+1 ({expected_grid_length})")

# Check if grid is sorted
if not np.all(np.diff(grid) > 0):
raise ValueError(f"Dimension {dim}: grid points must be strictly increasing")

# Check domain boundaries (with tolerance for numerical precision)
tol = 1e-12
if abs(grid[0] - xmin) > tol:
raise ValueError(f"Dimension {dim}: grid start ({grid[0]}) must match domain minimum ({xmin})")
if abs(grid[-1] - xmax) > tol:
raise ValueError(f"Dimension {dim}: grid end ({grid[-1]}) must match domain maximum ({xmax})")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that these checks should be made directly in the constructor of the Geometry class, not here. Here instead you should check that if the knots are passed, these are consistent with domain_h.ncells and domain_h.grid.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed these checks from there and I have put them in geometry.py with a function in the Geometry class:

def _validate_grid_consistency(self, domain, ncells, grid):

else:
# Create uniform grid
grids = [np.linspace(xmin, xmax, num=ne + 1)
for xmin, xmax, ne in zip(min_coords, max_coords, ncells)]

# Create 1D finite element spaces and precompute quadrature data
spaces[i] = [SplineSpace( p, multiplicity=m, grid=grid , periodic=P) for p,m,grid,P in zip(degree_i, multiplicity_i,grids, periodic)]
Expand Down Expand Up @@ -522,8 +553,7 @@ def discretize_space(V, domain_h, *, degree=None, multiplicity=None, knots=None,


#==============================================================================
def discretize_domain(domain, *, filename=None, ncells=None, periodic=None, comm=None, mpi_dims_mask=None):

def discretize_domain(domain, *, filename=None, ncells=None, periodic=None, comm=None, mpi_dims_mask=None, grid=None):
if comm is not None:
# Create a copy of the communicator
comm = comm.Dup()
Expand All @@ -538,7 +568,7 @@ def discretize_domain(domain, *, filename=None, ncells=None, periodic=None, comm
return Geometry(filename=filename, comm=comm)

elif ncells:
return Geometry.from_topological_domain(domain, ncells, periodic=periodic, comm=comm, mpi_dims_mask=mpi_dims_mask)
return Geometry.from_topological_domain(domain, ncells, periodic=periodic, comm=comm, mpi_dims_mask=mpi_dims_mask, grid=grid)

#==============================================================================
def discretize(a, *args, **kwargs):
Expand Down
Loading