Skip to content

Commit e21254e

Browse files
authored
RestrictedFunctionSpace: support Fieldsplit, multigrid, and python PC (#4169)
* dmhooks: support RestrictedFunctionSpace * RestrictedFunctionSpace: support geometric multigrid * RestrictedFunctionSpace: support p-multigrid * PC: replace FunctionSpace() constructor with V.reconstruct() to preserve boundary restriction ids.
1 parent 7e2ec2a commit e21254e

18 files changed

Lines changed: 275 additions & 136 deletions

File tree

firedrake/bcs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal
319319
V = V.sub(index)
320320
if g is None:
321321
g = self._original_arg
322+
if isinstance(g, firedrake.Function) and g.function_space() != V:
323+
g = firedrake.Function(V).interpolate(g)
322324
if sub_domain is None:
323325
sub_domain = self.sub_domain
324326
if field is not None:
@@ -744,11 +746,11 @@ def restricted_function_space(V, ids):
744746
return V
745747

746748
assert len(ids) == len(V)
747-
spaces = [Vsub if len(boundary_set) == 0 else
748-
firedrake.RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
749-
for Vsub, boundary_set in zip(V, ids)]
749+
spaces = [V_ if len(boundary_set) == 0 else
750+
firedrake.RestrictedFunctionSpace(V_, boundary_set=boundary_set, name=V_.name)
751+
for V_, boundary_set in zip(V, ids)]
750752

751753
if len(spaces) == 1:
752754
return spaces[0]
753755
else:
754-
return firedrake.MixedFunctionSpace(spaces)
756+
return firedrake.MixedFunctionSpace(spaces, name=V.name)

firedrake/dmhooks.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,13 @@ def get_function_space(dm):
5353
:raises RuntimeError: if no function space was found.
5454
"""
5555
info = dm.getAttr("__fs_info__")
56-
meshref, element, indices, (name, names) = info
56+
meshref, element, indices, (name, names), boundary_sets = info
5757
mesh = meshref()
5858
if mesh is None:
5959
raise RuntimeError("Somehow your mesh was collected, this should never happen")
6060
V = firedrake.FunctionSpace(mesh, element, name=name)
61+
if any(boundary_sets):
62+
V = firedrake.bcs.restricted_function_space(V, boundary_sets)
6163
if len(V) > 1:
6264
for V_, name in zip(V, names):
6365
V_.topological.name = name
@@ -93,8 +95,8 @@ def set_function_space(dm, V):
9395
if len(V) > 1:
9496
names = tuple(V_.name for V_ in V)
9597
element = V.ufl_element()
96-
97-
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
98+
boundary_sets = tuple(V_.boundary_set for V_ in V)
99+
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets)
98100
dm.setAttr("__fs_info__", info)
99101

100102

@@ -457,7 +459,7 @@ def refine(dm, comm):
457459
if hasattr(V, "_fine"):
458460
fdm = V._fine.dm
459461
else:
460-
V._fine = firedrake.FunctionSpace(hierarchy[level + 1], V.ufl_element())
462+
V._fine = V.reconstruct(mesh=hierarchy[level + 1])
461463
fdm = V._fine.dm
462464
V._fine._coarse = V
463465
return fdm

firedrake/functionspaceimpl.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,17 +377,20 @@ def make_function_space(cls, mesh, element, name=None):
377377
new = cls.create(new, mesh)
378378
return new
379379

380-
def reconstruct(self, mesh=None, name=None, **kwargs):
380+
def reconstruct(self, mesh=None, element=None, name=None, **kwargs):
381381
r"""Reconstruct this :class:`.WithGeometryBase` .
382382
383383
:kwarg mesh: the new :func:`~.Mesh` (defaults to same mesh)
384+
:kwarg element: the new :class:`finat.ufl.FiniteElement` (defaults to same element)
384385
:kwarg name: the new name (defaults to None)
385386
:returns: the new function space of the same class as ``self``.
386387
387388
Any extra kwargs are used to reconstruct the finite element.
388389
For details see :meth:`finat.ufl.finiteelement.FiniteElement.reconstruct`.
389390
"""
391+
from firedrake.bcs import restricted_function_space
390392
V_parent = self
393+
391394
# Deal with ProxyFunctionSpace
392395
indices = []
393396
while True:
@@ -402,13 +405,21 @@ def reconstruct(self, mesh=None, name=None, **kwargs):
402405

403406
if mesh is None:
404407
mesh = V_parent.mesh()
408+
if element is None:
409+
element = V_parent.ufl_element()
405410

406-
element = V_parent.ufl_element()
407411
cell = mesh.topology.ufl_cell()
408412
if len(kwargs) > 0 or element.cell != cell:
409413
element = element.reconstruct(cell=cell, **kwargs)
410414

415+
# Reconstruct the parent space
411416
V = type(self).make_function_space(mesh, element, name=name)
417+
418+
# Deal with RestrictedFunctionSpace
419+
boundary_sets = [V_.boundary_set for V_ in V_parent]
420+
if any(boundary_sets):
421+
V = restricted_function_space(V, boundary_sets)
422+
412423
for i in reversed(indices):
413424
V = V.sub(i)
414425
return V
@@ -899,8 +910,7 @@ def __init__(self, function_space, boundary_set=frozenset(), name=None):
899910
function_space.ufl_element(),
900911
label=self._label)
901912
self.function_space = function_space
902-
self.name = name or (function_space.name or "Restricted" + "_"
903-
+ "_".join(sorted(map(str, self.boundary_set))))
913+
self.name = name or function_space.name
904914

905915
def set_shared_data(self):
906916
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)

firedrake/mg/embedded.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,14 @@ def _native_transfer(self, element, op):
8787
return None
8888

8989
def cache(self, V):
90-
key = (V.ufl_element(), V.value_shape)
90+
key = (V.ufl_element(), V.value_shape, V.boundary_set)
9191
try:
9292
return self.caches[key]
9393
except KeyError:
94-
return self.caches.setdefault(key, TransferManager.Cache(*key))
94+
return self.caches.setdefault(key, TransferManager.Cache(*key[:2]))
95+
96+
def cache_key(self, V):
97+
return (V.dim(),)
9598

9699
def V_dof_weights(self, V):
97100
"""Dof weights for averaging projection.
@@ -100,7 +103,7 @@ def V_dof_weights(self, V):
100103
:returns: A PETSc Vec.
101104
"""
102105
cache = self.cache(V)
103-
key = V.dim()
106+
key = self.cache_key(V)
104107
try:
105108
return cache._V_dof_weights[key]
106109
except KeyError:
@@ -125,7 +128,7 @@ def V_DG_mass(self, V, DG):
125128
:returns: A PETSc Mat mapping from V -> DG
126129
"""
127130
cache = self.cache(V)
128-
key = V.dim()
131+
key = self.cache_key(V)
129132
try:
130133
return cache._V_DG_mass[key]
131134
except KeyError:
@@ -156,7 +159,7 @@ def V_approx_inv_mass(self, V, DG):
156159
:returns: A PETSc Mat mapping from V -> DG.
157160
"""
158161
cache = self.cache(V)
159-
key = V.dim()
162+
key = self.cache_key(V)
160163
try:
161164
return cache._V_approx_inv_mass[key]
162165
except KeyError:
@@ -174,7 +177,7 @@ def V_inv_mass_ksp(self, V):
174177
:returns: A PETSc KSP for inverting (V, V).
175178
"""
176179
cache = self.cache(V)
177-
key = V.dim()
180+
key = self.cache_key(V)
178181
try:
179182
return cache._V_inv_mass_ksp[key]
180183
except KeyError:
@@ -196,7 +199,7 @@ def DG_work(self, V):
196199
"""
197200
needs_dual = ufl.duals.is_dual(V)
198201
cache = self.cache(V)
199-
key = (V.dim(), needs_dual)
202+
key = self.cache_key(V) + (needs_dual,)
200203
try:
201204
return cache._DG_work[key]
202205
except KeyError:
@@ -213,7 +216,7 @@ def work_vec(self, V):
213216
:returns: A PETSc Vec for V.
214217
"""
215218
cache = self.cache(V)
216-
key = V.dim()
219+
key = self.cache_key(V)
217220
try:
218221
return cache._work_vec[key]
219222
except KeyError:

firedrake/mg/interface.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,17 @@ def prolong(coarse, fine):
5858
repeat = (fine_level - coarse_level)*refinements_per_level
5959
next_level = coarse_level * refinements_per_level
6060

61-
element = Vc.ufl_element()
6261
meshes = hierarchy._meshes
6362
for j in range(repeat):
6463
next_level += 1
6564
if j == repeat - 1:
6665
next = fine
6766
Vf = fine.function_space()
6867
else:
69-
Vf = firedrake.FunctionSpace(meshes[next_level], element)
68+
Vf = Vc.reconstruct(mesh=meshes[next_level])
7069
next = firedrake.Function(Vf)
7170

72-
coarse_coords = Vc.mesh().coordinates
71+
coarse_coords = get_coordinates(Vc)
7372
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
7473
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
7574
kernel = kernels.prolong_kernel(coarse)
@@ -119,7 +118,6 @@ def restrict(fine_dual, coarse_dual):
119118
repeat = (fine_level - coarse_level)*refinements_per_level
120119
next_level = fine_level * refinements_per_level
121120

122-
element = Vc.ufl_element()
123121
meshes = hierarchy._meshes
124122

125123
for j in range(repeat):
@@ -128,15 +126,15 @@ def restrict(fine_dual, coarse_dual):
128126
coarse_dual.dat.zero()
129127
next = coarse_dual
130128
else:
131-
Vc = firedrake.FunctionSpace(meshes[next_level], element)
132-
next = firedrake.Cofunction(Vc.dual())
129+
Vc = Vf.reconstruct(mesh=meshes[next_level])
130+
next = firedrake.Cofunction(Vc)
133131
Vc = next.function_space()
134132
# XXX: Should be able to figure out locations by pushing forward
135133
# reference cell node locations to physical space.
136134
# x = \sum_i c_i \phi_i(x_hat)
137-
node_locations = utils.physical_node_locations(Vf)
135+
node_locations = utils.physical_node_locations(Vf.dual())
138136

139-
coarse_coords = Vc.mesh().coordinates
137+
coarse_coords = get_coordinates(Vc.dual())
140138
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
141139
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
142140
# Have to do this, because the node set core size is not right for
@@ -195,7 +193,6 @@ def inject(fine, coarse):
195193
repeat = (fine_level - coarse_level)*refinements_per_level
196194
next_level = fine_level * refinements_per_level
197195

198-
element = Vc.ufl_element()
199196
meshes = hierarchy._meshes
200197

201198
for j in range(repeat):
@@ -205,12 +202,12 @@ def inject(fine, coarse):
205202
next = coarse
206203
Vc = next.function_space()
207204
else:
208-
Vc = firedrake.FunctionSpace(meshes[next_level], element)
205+
Vc = Vf.reconstruct(mesh=meshes[next_level])
209206
next = firedrake.Function(Vc)
210207
if not dg:
211208
node_locations = utils.physical_node_locations(Vc)
212209

213-
fine_coords = Vf.mesh().coordinates
210+
fine_coords = get_coordinates(Vf)
214211
coarse_node_to_fine_nodes = utils.coarse_node_to_fine_node_map(Vc, Vf)
215212
coarse_node_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())
216213

@@ -242,3 +239,11 @@ def inject(fine, coarse):
242239
fine = next
243240
Vf = Vc
244241
return coarse
242+
243+
244+
def get_coordinates(V):
245+
coords = V.mesh().coordinates
246+
if V.boundary_set:
247+
W = V.reconstruct(element=coords.function_space().ufl_element())
248+
coords = firedrake.Function(W).interpolate(coords)
249+
return coords

firedrake/mg/ufl_utils.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
195195
for bc in cctx._problem.dirichlet_bcs():
196196
bc.apply(cctx._x)
197197

198-
dm = problem.u.function_space().dm
198+
dm = problem.u_restrict.function_space().dm
199199
if not dm.getAttr("_coarsen_hook"):
200200
# The hook is persistent and cumulative, but also problem-independent.
201201
# Therefore, we are only adding it once.
@@ -209,7 +209,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
209209
F = self(problem.F, self, coefficient_mapping=coefficient_mapping)
210210
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
211211
Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping)
212-
u = coefficient_mapping[problem.u]
212+
u = coefficient_mapping[problem.u_restrict]
213213

214214
fine = problem
215215
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
@@ -283,7 +283,7 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
283283
if isinstance(val, (firedrake.Function, firedrake.Cofunction)):
284284
V = val.function_space()
285285
coarseneddm = V.dm
286-
parentdm = get_parent(context._problem.u.function_space().dm)
286+
parentdm = get_parent(context._problem.u_restrict.function_space().dm)
287287

288288
# Now attach the hook to the parent DM
289289
if get_appctx(coarseneddm) is None:
@@ -303,11 +303,11 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
303303

304304

305305
class Interpolation(object):
306-
def __init__(self, coarse, fine, manager, cbcs=None, fbcs=None):
307-
self.cprimal = coarse
308-
self.fprimal = fine
309-
self.cdual = coarse.riesz_representation(riesz_map="l2")
310-
self.fdual = fine.riesz_representation(riesz_map="l2")
306+
def __init__(self, Vcoarse, Vfine, manager, cbcs=None, fbcs=None):
307+
self.cprimal = firedrake.Function(Vcoarse)
308+
self.fprimal = firedrake.Function(Vfine)
309+
self.cdual = firedrake.Cofunction(Vcoarse.dual())
310+
self.fdual = firedrake.Cofunction(Vfine.dual())
311311
self.cbcs = cbcs or []
312312
self.fbcs = fbcs or []
313313
self.manager = manager
@@ -352,9 +352,9 @@ def multTransposeAdd(self, mat, x, y, w):
352352

353353

354354
class Injection(object):
355-
def __init__(self, cfn, ffn, manager, cbcs=None):
356-
self.cfn = cfn
357-
self.ffn = ffn
355+
def __init__(self, Vcoarse, Vfine, manager, cbcs=None):
356+
self.cfn = firedrake.Function(Vcoarse)
357+
self.ffn = firedrake.Function(Vfine)
358358
self.cbcs = cbcs or []
359359
self.manager = manager
360360

@@ -374,18 +374,15 @@ def create_interpolation(dmc, dmf):
374374

375375
manager = get_transfer_manager(dmf)
376376

377-
V_c = cctx._problem.u.function_space()
378-
V_f = fctx._problem.u.function_space()
377+
V_c = cctx._problem.u_restrict.function_space()
378+
V_f = fctx._problem.u_restrict.function_space()
379379

380380
row_size = V_f.dof_dset.layout_vec.getSizes()
381381
col_size = V_c.dof_dset.layout_vec.getSizes()
382-
383-
cfn = firedrake.Function(V_c)
384-
ffn = firedrake.Function(V_f)
385382
cbcs = tuple(cctx._problem.dirichlet_bcs())
386383
fbcs = tuple(fctx._problem.dirichlet_bcs())
387384

388-
ctx = Interpolation(cfn, ffn, manager, cbcs, fbcs)
385+
ctx = Interpolation(V_c, V_f, manager, cbcs, fbcs)
389386
mat = PETSc.Mat().create(comm=dmc.comm)
390387
mat.setSizes((row_size, col_size))
391388
mat.setType(mat.Type.PYTHON)
@@ -400,16 +397,13 @@ def create_injection(dmc, dmf):
400397

401398
manager = get_transfer_manager(dmf)
402399

403-
V_c = cctx._problem.u.function_space()
404-
V_f = fctx._problem.u.function_space()
400+
V_c = cctx._problem.u_restrict.function_space()
401+
V_f = fctx._problem.u_restrict.function_space()
405402

406403
row_size = V_c.dof_dset.layout_vec.getSizes()
407404
col_size = V_f.dof_dset.layout_vec.getSizes()
408405

409-
cfn = firedrake.Function(V_c)
410-
ffn = firedrake.Function(V_f)
411-
412-
ctx = Injection(cfn, ffn, manager)
406+
ctx = Injection(V_c, V_f, manager)
413407
mat = PETSc.Mat().create(comm=dmc.comm)
414408
mat.setSizes((row_size, col_size))
415409
mat.setType(mat.Type.PYTHON)

0 commit comments

Comments
 (0)