Skip to content

Commit 8d18a85

Browse files
committed
feat: load_scalar as Figure 2.3
1 parent 07576e9 commit 8d18a85

3 files changed

Lines changed: 52 additions & 32 deletions

File tree

tests/test_loader.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,41 @@
11
import pytest
22

3-
from discopy import closed, monoidal
4-
5-
from widip.lang import Box, Ty, Id
3+
from widip.lang import Ty, Id
64
from widip.loader import repl_read
75

8-
@pytest.mark.parametrize(["yaml_text", "expected_box"], [
6+
7+
INPUT_ONLY_SCALAR = Id(Ty("some spaced scalar") >> Ty()).curry(1, left=True)
8+
TAG_AND_VALUE_SCALAR = Id(Ty("tagged") @ Ty("scalar") >> Ty())
9+
TAG_ONLY_SCALAR = Id(Ty("just_tag") >> Ty()).curry(1, left=False)
10+
EMPTY_SCALAR = Id(Ty() >> Ty())
11+
EMPTY_STREAM = Id(Ty())
12+
13+
14+
@pytest.mark.parametrize(["yaml_text", "expected"], [
915
[
1016
"some spaced scalar",
11-
Box("⌜−⌝", Ty("some spaced scalar"), Ty() >> Ty("some spaced scalar"))],
17+
INPUT_ONLY_SCALAR,
18+
],
1219
[
1320
"!tagged scalar",
14-
Box("tagged", Ty("scalar"), Ty("tagged") >> Ty("tagged"))],
21+
TAG_AND_VALUE_SCALAR,
22+
],
1523
[
1624
"!just_tag",
17-
Box("just_tag", Ty(""), Ty("just_tag") >> Ty("just_tag"))],
25+
TAG_ONLY_SCALAR,
26+
],
27+
[
28+
"''",
29+
EMPTY_SCALAR,
30+
],
1831
[
1932
"",
20-
Id(Ty())],
33+
EMPTY_STREAM,
34+
],
2135
])
22-
def test_loader_encoding(yaml_text, expected_box):
23-
assert repl_read(yaml_text) == expected_box
36+
def test_loader_encoding(yaml_text, expected):
37+
assert repl_read(yaml_text) == expected
38+
39+
40+
def test_scalar_and_tag_only_are_different_programs():
41+
assert repl_read("some spaced scalar") != repl_read("!just_tag")

widip/lang.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,33 @@
1-
"""The Run language category"""
1+
"""The Run language category."""
22

3-
from discopy import closed, markov, monoidal, symmetric
3+
from discopy import closed, markov, symmetric
44
from discopy.utils import factory
55

6+
67
class Box(
78
closed.Box,
89
markov.Box,
9-
symmetric.Box
10+
symmetric.Box,
1011
):
1112
""""""
1213

14+
1315
@factory
14-
class Ty(
15-
closed.Ty,
16-
):
16+
class Ty(closed.Ty):
1717
def __rshift__(self, other):
18-
return self.factory(closed.Under(self, other))
18+
return self.factory(closed.Under(other, self))
1919

2020
def __lshift__(self, other):
2121
return self.factory(closed.Over(self, other))
2222

23+
@property
24+
def base(self):
25+
return self.inside[0].base if self.is_exp else None
26+
27+
@property
28+
def exponent(self):
29+
return self.inside[0].exponent if self.is_exp else None
30+
2331

2432
def Id(x=None):
2533
"""Identity diagram over widip.lang.Ty (defaults to Ty())."""

widip/loader.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,16 @@ def _incidences_to_diagram(node: HyperGraph, index):
4747

4848

4949
def load_scalar(node, index, tag):
50+
"""Figure 2.3: If g = {G}, then g ◦ (s × id) = {Gs} uses G but g ◦ (id × t) = {H} does not."""
5051
v = hif_node(node, index)["value"]
51-
if tag and v:
52-
return Box(tag, Ty(v), Ty(tag) >> Ty(tag))
53-
return Box("run", Ty(tag) @ Ty(v), Ty(tag)).curry(left=False)
54-
elif tag:
55-
return Box(tag, Ty(v), Ty(tag) >> Ty(tag))
56-
return Box("run", Ty(tag), Ty(tag)).curry(left=False)
57-
return Box(tag, Ty(), Ty() << Ty(""))
58-
elif v:
59-
return Box("⌜−⌝", Ty(v), Ty() >> Ty(v))
60-
return Box("⌜−⌝", Ty(v), Ty(tag)).curry(0, left=False)
61-
return Box("⌜−⌝", Ty(v), Ty() << Ty(""))
62-
else:
63-
return Box("⌜−⌝", Ty(), Ty() >> Ty(v))
64-
return Box("⌜−⌝", Ty(), Ty(tag)).curry(0, left=False)
52+
X = Ty(tag) if tag else Ty()
53+
A = Ty(v) if v else Ty()
54+
# Differentiate parameter-only vs input-only scalar forms.
55+
if X == Ty() and A != Ty():
56+
return Id(A >> Ty()).curry(1, left=True)
57+
if X != Ty() and A == Ty():
58+
return Id(X >> Ty()).curry(1, left=False)
59+
return Id(X @ A >> Ty())
6560

6661
def load_mapping(node, index, tag):
6762
ob = Id()
@@ -96,7 +91,6 @@ def load_mapping(node, index, tag):
9691
if tag:
9792
ob = (ob @ exps >> Eval(exps >> bases))
9893
box = Box(tag, ob.cod, Ty(tag) >> Ty(tag))
99-
# box = Box("run", Ty(tag) @ ob.cod, Ty(tag)).curry(left=False)
10094
ob = ob >> box
10195
return ob
10296

0 commit comments

Comments
 (0)