Skip to content
27 changes: 20 additions & 7 deletions psydac/api/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,14 @@ def discretize_space(V, domain_h, *, degree=None, multiplicity=None, knots=None,
assert len(ncells) == len(periodic) == len(degree_i) == len(multiplicity_i) == len(min_coords) == len(max_coords)

if knots is None:
# Create uniform grid
grids = [np.linspace(xmin, xmax, num=ne + 1)
for xmin, xmax, ne in zip(min_coords, max_coords, ncells)]
# Check if grid is provided in domain_h
if hasattr(domain_h, 'grid') and domain_h.grid is not None and interior.name in domain_h.grid:
# Use provided grid of breakpoints
grids = domain_h.grid[interior.name]
else:
# Create uniform grid
grids = [np.linspace(xmin, xmax, num=ne + 1)
for xmin, xmax, ne in zip(min_coords, max_coords, ncells)]

# Create 1D finite element spaces and precompute quadrature data
spaces[i] = [SplineSpace( p, multiplicity=m, grid=grid , periodic=P) for p,m,grid,P in zip(degree_i, multiplicity_i,grids, periodic)]
Expand Down Expand Up @@ -522,23 +527,31 @@ def discretize_space(V, domain_h, *, degree=None, multiplicity=None, knots=None,


#==============================================================================
def discretize_domain(domain, *, filename=None, ncells=None, periodic=None, comm=None, mpi_dims_mask=None):
def discretize_domain(domain, *, filename=None, ncells=None, periodic=None, comm=None, mpi_dims_mask=None, grid=None):

if comm is not None:
# Create a copy of the communicator
comm = comm.Dup()

if not (filename or ncells):
raise ValueError("Must provide either 'filename' or 'ncells'")
if not (filename or ncells or grid):
raise ValueError("Must provide either 'filename' or 'ncells' or 'grid'")

elif filename and ncells:
raise ValueError("Cannot provide both 'filename' and 'ncells'")

elif filename and grid:
raise ValueError("Cannot provide both 'filename' and 'grid'")

elif filename:
return Geometry(filename=filename, comm=comm)

elif ncells:
return Geometry.from_topological_domain(domain, ncells, periodic=periodic, comm=comm, mpi_dims_mask=mpi_dims_mask)
# Validate grid parameter if provided - basic validation only
if grid is not None:
if not isinstance(grid, (list, tuple, dict)):
raise TypeError("Grid must be a list, tuple, or dict")

return Geometry.from_topological_domain(domain, ncells, periodic=periodic, comm=comm, mpi_dims_mask=mpi_dims_mask, grid=grid)

#==============================================================================
def discretize(a, *args, **kwargs):
Expand Down
Loading