Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
# Cofunction.ufl_domains references FormArgument but it isn't picked
# up by Sphinx (see https://github.com/sphinx-doc/sphinx/issues/11225)
('py:class', 'FormArgument'),
# Some complex type hints confuse Sphinx (https://github.com/sphinx-doc/sphinx/issues/14159)
("py:obj", r"typing\.Literal\[.*"),
]

# Dodgy links
Expand Down
18 changes: 15 additions & 3 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# A module implementing strong (Dirichlet) boundary conditions.
import numpy as np

from functools import partial, reduce, cached_property
import itertools

import numpy as np
from mpi4py import MPI

import ufl
from ufl import as_ufl, as_tensor
from finat.ufl import VectorElement
import finat

import pyop2 as op2
from pyop2 import exceptions
from pyop2.mpi import temp_internal_comm
from pyop2.utils import as_tuple

import firedrake
Expand All @@ -19,6 +22,7 @@
from firedrake import slate
from firedrake import solving
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.logging import logger
from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin
from firedrake.petsc import PETSc

Expand Down Expand Up @@ -147,7 +151,7 @@ def hermite_stride(bcnodes):
bcnodes = np.setdiff1d(bcnodes, deriv_ids)
return bcnodes

sub_d = (self.sub_domain, ) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
sub_d = (self.sub_domain,) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
sub_d = [s if isinstance(s, str) else as_tuple(s) for s in sub_d]
bcnodes = []
for s in sub_d:
Expand All @@ -168,7 +172,15 @@ def hermite_stride(bcnodes):
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
bcnodes1 = reduce(np.intersect1d, bcnodes1)
bcnodes.append(bcnodes1)
return np.concatenate(bcnodes)
bcnodes = np.concatenate(bcnodes)

with temp_internal_comm(self._function_space.mesh().comm) as icomm:
num_global_nodes = icomm.reduce(len(bcnodes), MPI.SUM, root=0)
if num_global_nodes == 0 and icomm.rank == 0:
logger.warn(f"Subdomain {self.sub_domain} is empty. This is likely an error. "
"Did you choose the right label?")

return bcnodes

@cached_property
def node_set(self):
Expand Down
198 changes: 190 additions & 8 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
PetscInt nclosure, p, vi, v, fi, i
PetscInt start_v, off
PetscInt *closure = NULL
PetscInt closure_tmp[2*9]
PetscInt c_vertices[4]
PetscInt c_facets[4]
PetscInt g_vertices[4]
Expand All @@ -804,13 +805,13 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
ncells = cEnd - cStart
entity_per_cell = 4 + 4 + 1

CHKERR(PetscMalloc1(2*9, &closure))

cell_closure = np.empty((ncells, entity_per_cell), dtype=IntType)
for c in range(cStart, cEnd):
CHKERR(PetscSectionGetOffset(cell_numbering.sec, c, &cell))
get_transitive_closure(plex.dm, c, PETSC_TRUE, &nclosure, &closure)

# First extract the facets (edges) and the vertices
# from the transitive closure into c_facets and c_vertices.
# Here we assume that DMPlex gives entities in the order:
#
# 8--3--7
Expand All @@ -821,7 +822,65 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
#
# where the starting vertex and order of traversal is arbitrary.
# (We fix that later.)

# If we have a periodic mesh with only a single cell in the periodic
# direction then the closure will look like
#
# 4--1--5
# | |
# 3 0 2 (vertical periodicity)
# | |
# 4--1--5
#
# or
#
# 5--3--5
# | |
# 2 0 2 (horizontal periodicity)
# | |
# 4--1--4
#
# and only have 6 entries instead of 9. For the following to work we have
# to blow this out to a 9 entry array including the repeats.
if nclosure == 4:
raise NotImplementedError("Single-cell periodic quad meshes are "
"not supported")
elif nclosure == 6:
horiz_periodicity, vert_periodicity = _get_periodicity(plex)
(_, horiz_unit_periodic) = horiz_periodicity
(_, vert_unit_periodic) = vert_periodicity
if vert_unit_periodic:
assert not horiz_unit_periodic
closure_tmp[2*0] = closure[2*0]
closure_tmp[2*1] = closure[2*1]
closure_tmp[2*2] = closure[2*2]
closure_tmp[2*3] = closure[2*1]
closure_tmp[2*4] = closure[2*3]
closure_tmp[2*5] = closure[2*4]
closure_tmp[2*6] = closure[2*5]
closure_tmp[2*7] = closure[2*5]
closure_tmp[2*8] = closure[2*4]
else:
assert horiz_unit_periodic
assert not vert_unit_periodic
closure_tmp[2*0] = closure[2*0]
closure_tmp[2*1] = closure[2*1]
closure_tmp[2*2] = closure[2*2]
closure_tmp[2*3] = closure[2*3]
closure_tmp[2*4] = closure[2*2]
closure_tmp[2*5] = closure[2*4]
closure_tmp[2*6] = closure[2*4]
closure_tmp[2*7] = closure[2*5]
closure_tmp[2*8] = closure[2*5]

nclosure = 9
for i in range(9):
closure[2*i] = closure_tmp[2*i]
else:
assert nclosure == 9

# Extract the facets (edges) and the vertices
# from the transitive closure into c_facets and c_vertices.
# For the vertices, we also retrieve the global numbers into g_vertices.
vi = 0
fi = 0
Expand Down Expand Up @@ -923,8 +982,7 @@ def quadrilateral_closure_ordering(PETSc.DM plex,
cell_closure[cell, 4 + 3] = facets[1]
cell_closure[cell, 8] = c

if closure != NULL:
restore_transitive_closure(plex.dm, 0, PETSC_TRUE, &nclosure, &closure)
CHKERR(PetscFree(closure))

return cell_closure

Expand Down Expand Up @@ -1987,7 +2045,7 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
get_depth_stratum(dm.dm, 0, &vStart, &vEnd)
if isinstance(dm, PETSc.DMPlex):
if not dm.getCoordinatesLocalized():
# Use CG coordiantes.
# Use CG coordinates.
dm_sec = dm.getCoordinateSection()
dm_coords = dm.getCoordinatesLocal().array.reshape(shape)
coords = np.empty_like(dm_coords)
Expand All @@ -1998,12 +2056,11 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
for i in range(dim):
coords[offset, i] = dm_coords[dm_offset, i]
else:
# Use DG coordiantes.
# Use DG coordinates.
get_height_stratum(dm.dm, 0, &cStart, &cEnd)
dim = dm.getCoordinateDim()
ndofs, perm, perm_offsets = _get_firedrake_plex_permutation_dg_transitive_closure(dm)
dm_sec = dm.getCellCoordinateSection()
dm_coords = dm.getCellCoordinatesLocal().array.reshape(((cEnd - cStart) * ndofs[0], dim))
dm_coords, dm_sec = _get_expanded_dm_dg_coords(dm, ndofs)
coords = np.empty_like(dm_coords)
for c in range(cStart, cEnd):
CHKERR(PetscSectionGetOffset(global_numbering.sec, c, &offset)) # scalar offset
Expand Down Expand Up @@ -2031,6 +2088,131 @@ def reordered_coords(PETSc.DM dm, PETSc.Section global_numbering, shape, referen
raise ValueError("Only DMPlex and DMSwarm are supported.")
return coords


def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):
cdef:
const PetscReal *L

PETSc.Section dm_sec_expanded

cStart, cEnd = dm.getHeightStratum(0)
dim = dm.getCoordinateDim()
coords_shape = ((cEnd-cStart) * ndofs[0], dim)

if dm.getCellCoordinateSection().getDof(cStart) < ndofs[0] * dim:
# Fewer cell coordinates available, we must be single-cell periodic
if dm.getCellType(cStart) == PETSc.DM.PolytopeType.QUADRILATERAL:
# If we have a periodic mesh with only a single cell in the periodic
# direction then the cell coordinates will be
#
# 1-----2
# | |
# | | (vertical periodicity)
# | |
# 1-----2
#
# or
#
# 2-----2
# | |
# | | (horizontal periodicity)
# | |
# 1-----1
#
# when the standard layout is
#
# 4-----3
# | |
# | |
# | |
# 1-----2
assert ndofs[0] == 4, "Not expecting high order coords here"
dm_coords_orig = dm.getCellCoordinatesLocal().array_r.reshape(((cEnd-cStart) * 2, dim))
dm_coords_expanded = np.empty(coords_shape, dtype=dm_coords_orig.dtype)

# Create a new cell coordinate section
dm_sec_orig = dm.getCellCoordinateSection()
dm_sec_expanded = PETSc.Section().create(comm=dm_sec_orig.comm)
dm_sec_expanded.setChart(*dm_sec_orig.getChart())
dm_sec_expanded.setPermutation(dm_sec_orig.getPermutation())

horiz_periodicity, vert_periodicity = _get_periodicity(dm)
(_, horiz_unit_periodic) = horiz_periodicity
(_, vert_unit_periodic) = vert_periodicity

# Find the domain sizes
CHKERR(DMGetPeriodicity(dm.dm, NULL, NULL, &L))

if horiz_unit_periodic:
if vert_unit_periodic:
raise NotImplementedError("Single-cell periodic quad meshes are "
"not supported")
else:
cell_width = L[0]

for c in range(cStart, cEnd):
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))

dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+0, 0] + cell_width
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0] + cell_width
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+1, 0]
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+0, 1]
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1]
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+1, 1]

else:
assert vert_unit_periodic
cell_height = L[1]

for c in range(cStart, cEnd):
CHKERR(PetscSectionSetDof(dm_sec_expanded.sec, c, 8))

dm_coords_expanded[4*c+0, 0] = dm_coords_orig[2*c+0, 0]
dm_coords_expanded[4*c+1, 0] = dm_coords_orig[2*c+1, 0]
dm_coords_expanded[4*c+2, 0] = dm_coords_orig[2*c+1, 0]
dm_coords_expanded[4*c+3, 0] = dm_coords_orig[2*c+0, 0]
dm_coords_expanded[4*c+0, 1] = dm_coords_orig[2*c+0, 1]
dm_coords_expanded[4*c+1, 1] = dm_coords_orig[2*c+1, 1]
dm_coords_expanded[4*c+2, 1] = dm_coords_orig[2*c+1, 1] + cell_height
dm_coords_expanded[4*c+3, 1] = dm_coords_orig[2*c+0, 1] + cell_height

dm_sec_expanded.setUp()

dm_coords = dm_coords_expanded
dm_sec = dm_sec_expanded

else:
raise NotImplementedError("Single cell periodicity for cell type "
f"{dm.getCellType(cStart)} is not supported")

else:
dm_coords = dm.getCellCoordinatesLocal().array_r.reshape(coords_shape)
dm_sec = dm.getCellCoordinateSection()

return dm_coords, dm_sec


def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
"""Return mesh periodicity information.

This function returns a 2-tuple of bools per dimension where the first entry indicates
whether the mesh is periodic in that dimension, and the second indicates whether the
mesh is single-cell periodic in that dimension.

"""
cdef:
const PetscReal *maxCell, *L

dim = dm.getCoordinateDim()
CHKERR(DMGetPeriodicity(dm.dm, &maxCell, NULL, &L))
return tuple(
(L[d] >= 0, maxCell[d] >= L[d])
for d in range(dim)
)


@cython.boundscheck(False)
@cython.wraparound(False)
def mark_entity_classes(PETSc.DM dm):
Expand Down
4 changes: 4 additions & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ cdef extern from "petscdm.h" nogil:
PetscErrorCode DMSetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt)
PetscErrorCode DMGetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt*)

PetscErrorCode DMGetPeriodicity(PETSc.PetscDM,PetscReal *[], PetscReal *[], PetscReal *[])
PetscErrorCode DMGetSparseLocalize(PETSc.PetscDM,PetscBool *)
PetscErrorCode DMSetSparseLocalize(PETSc.PetscDM,PetscBool)

cdef extern from "petscdmswarm.h" nogil:
PetscErrorCode DMSwarmGetLocalSize(PETSc.PetscDM,PetscInt*)
PetscErrorCode DMSwarmGetCellDM(PETSc.PetscDM, PETSc.PetscDM*)
Expand Down
21 changes: 10 additions & 11 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import firedrake.cython.spatialindex as spatialindex
import firedrake.utils as utils
from firedrake.utils import as_cstr, IntType, RealType
from firedrake.logging import info_red
from firedrake.logging import info_red, logger
from firedrake.parameters import parameters
from firedrake.petsc import PETSc, DEFAULT_PARTITIONER
from firedrake.adjoint_utils import MeshGeometryMixin
Expand Down Expand Up @@ -201,14 +201,6 @@ def __init__(self, mesh, facets, classes, set_, kind, facet_cell, local_facet_nu
self.unique_markers = [] if unique_markers is None else unique_markers
self._subsets = {}

@cached_property
def _null_subset(self):
'''Empty subset for the case in which there are no facets with
a given marker value. This is required because not all
markers need be represented on all processors.'''

return op2.Subset(self.set, [])

@PETSc.Log.EventDecorator()
def measure_set(self, integral_type, subdomain_id,
all_integer_subdomain_ids=None):
Expand Down Expand Up @@ -283,9 +275,16 @@ def subset(self, markers):
marked_points_list.append(self.mesh.topology_dm.getStratumIS(dmcommon.FACE_SETS_LABEL, i).indices)
if marked_points_list:
_, indices, _ = np.intersect1d(self.facets, np.concatenate(marked_points_list), return_indices=True)
return self._subsets.setdefault(markers, op2.Subset(self.set, indices))
else:
return self._subsets.setdefault(markers, self._null_subset)
indices = np.empty(0, dtype=IntType)

with temp_internal_comm(self.mesh.comm) as icomm:
num_global_indices = icomm.reduce(len(indices), MPI.SUM, root=0)
if num_global_indices == 0 and icomm.rank == 0:
logger.warn(f"Subdomain {markers} is empty. This is likely an error. "
"Did you choose the right label?")

return self._subsets.setdefault(markers, op2.Subset(self.set, indices))

def _collect_unmarked_points(self, markers):
"""Collect points that are not marked by markers."""
Expand Down
Loading
Loading