This repository was archived by the owner on Oct 6, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtyped_sim.py
More file actions
324 lines (273 loc) · 8.59 KB
/
typed_sim.py
File metadata and controls
324 lines (273 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from dataclasses import dataclass, field
from fractions import Fraction
from typing import Optional, NewType
# A minimal example to illustrate typechecking.
class EndOfStream(Exception):
    """Raised by ``Stream.next_char`` when no characters are left to read."""
@dataclass
class Stream:
    """A string read one character at a time, with the ability to step back."""

    # Text being read.
    source: str
    # Index into ``source`` of the next character ``next_char`` will return.
    pos: int

    @classmethod
    def from_string(cls, s):
        """Return a stream positioned at the start of ``s``.

        Declared as a classmethod so it works when called on the class
        (``Stream.from_string(s)``) as well as on an instance.
        """
        return cls(s, 0)

    def next_char(self):
        """Return the next character and move past it.

        Raises:
            EndOfStream: if every character has already been read.
        """
        if self.pos >= len(self.source):
            raise EndOfStream()
        self.pos = self.pos + 1
        return self.source[self.pos - 1]

    def unget(self):
        """Step back one character so it is returned again by ``next_char``."""
        # Stepping back before the first character is a caller error.
        assert self.pos > 0
        self.pos = self.pos - 1
# Define the token types.
@dataclass
class Num:
n: int
@dataclass
class Bool:
b: bool
@dataclass
class Keyword:
word: str
@dataclass
class Identifier:
word: str
@dataclass
class Operator:
op: str
# Union of every token class the lexer can produce.
Token = (
    Num
    | Bool
    | Keyword
    | Identifier
    | Operator
)
class EndOfTokens(Exception):
    """Raised by the lexer when the input holds no further tokens."""
# Words reserved by the language; lexed as Keyword tokens.
keywords = ["if", "then", "else", "end", "while", "do", "done"]
# Single-character operators recognised directly from the character stream.
symbolic_operators = ["+", "-", "×", "/", "<", ">", "≤", "≥", "=", "≠"]
# Operators spelled as words; lexed as Operator tokens.
word_operators = ["and", "or", "not", "quot", "rem"]
# Characters skipped between tokens.
whitespace = " \t\n"
def word_to_token(word):
    """Classify an alphabetic word as a keyword, operator, boolean or identifier.

    The checks run in that order, so anything that is not a keyword, a word
    operator or one of ``True`` / ``False`` becomes an Identifier.
    """
    if word in keywords:
        token = Keyword(word)
    elif word in word_operators:
        token = Operator(word)
    elif word in ("True", "False"):
        token = Bool(word == "True")
    else:
        token = Identifier(word)
    return token
class TokenError(Exception):
    """Raised when the next token is not the one that was expected."""
@dataclass
class Lexer:
    """Tokenizer over a character ``Stream`` with one token of lookahead.

    ``peek_token`` looks at the next token without consuming it, ``advance``
    consumes the token that was peeked, and ``match`` does both while checking
    the token against an expected value.  A Lexer is also an iterator over
    the remaining tokens.
    """

    stream: Stream
    # Lookahead slot: the token returned by peek_token() that has not been
    # consumed yet, or None when nothing is buffered.
    save: Token | None = None

    @classmethod
    def from_stream(cls, s):
        """Return a lexer that reads from the character stream ``s``."""
        return cls(s)

    def _read_run(self, first, accept):
        """Return ``first`` followed by the run of characters ``accept`` approves.

        The first rejected character is pushed back onto the stream; reaching
        the end of the stream simply ends the run.
        """
        chars = [first]
        while True:
            try:
                c = self.stream.next_char()
            except EndOfStream:
                break
            if not accept(c):
                self.stream.unget()
                break
            chars.append(c)
        return "".join(chars)

    def next_token(self) -> Token:
        """Read the next token straight from the stream, ignoring the lookahead slot.

        Raises:
            EndOfTokens: if only whitespace (or nothing) is left in the stream.
            TokenError: if the next character cannot start any token.
        """
        try:
            c = self.stream.next_char()
            # Skip whitespace with a loop rather than recursion so a long run
            # of blanks cannot exhaust the call stack.
            while c in whitespace:
                c = self.stream.next_char()
        except EndOfStream:
            raise EndOfTokens from None
        if c in symbolic_operators:
            return Operator(c)
        # isdecimal (not isdigit) so every accepted character is one int() can
        # convert; characters such as "²" are digits but not decimals.
        if c.isdecimal():
            return Num(int(self._read_run(c, str.isdecimal)))
        if c.isalpha():
            return word_to_token(self._read_run(c, str.isalpha))
        # Previously an unknown character fell through and None was returned.
        raise TokenError(f"unexpected character {c!r}")

    def peek_token(self) -> Token:
        """Return the next token without consuming it.

        Repeated calls return the same token until ``advance`` is called.
        """
        if self.save is not None:
            return self.save
        self.save = self.next_token()
        return self.save

    def advance(self):
        """Consume the token most recently returned by ``peek_token``."""
        # A token must have been peeked before it can be consumed.
        assert self.save is not None
        self.save = None

    def match(self, expected):
        """Consume the next token if it equals ``expected``.

        Raises:
            TokenError: if the next token differs from ``expected``.
        """
        found = self.peek_token()
        if found == expected:
            return self.advance()
        raise TokenError(f"expected {expected!r}, found {found!r}")

    def __iter__(self):
        """Return the lexer itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the next token, ending the iteration when the tokens run out."""
        # Hand out a token buffered by peek_token() first so that mixing
        # peeking and iterating never drops a token.
        if self.save is not None:
            token, self.save = self.save, None
            return token
        try:
            return self.next_token()
        except EndOfTokens:
            raise StopIteration from None
@dataclass
class Parser:
lexer: Lexer
def from_lexer(lexer):
return Parser(lexer)
def parse_if(self):
self.lexer.match(Keyword("if"))
c = self.parse_expr()
self.lexer.match(Keyword("then"))
t = self.parse_expr()
self.lexer.match(Keyword("else"))
f = self.parse_expr()
self.lexer.match(Keyword("end"))
return IfElse(c, t, f)
def parse_while(self):
self.lexer.match(Keyword("while"))
c = self.parse_expr()
self.lexer.match(Keyword("do"))
b = self.parse_expr()
self.lexer.match(Keyword("done"))
return While(c, b)
def parse_atom(self):
match self.lexer.peek_token():
case Identifier(name):
self.lexer.advance()
return Variable(name)
case Num(value):
self.lexer.advance()
return NumLiteral(value)
case Bool(value):
self.lexer.advance()
return BoolLiteral(value)
def parse_mult(self):
left = self.parse_atom()
while True:
match self.lexer.peek_token():
case Operator(op) if op in "×/":
self.lexer.advance()
m = self.parse_atom()
left = BinOp(op, left, m)
case _:
break
return left
def parse_add(self):
left = self.parse_mult()
while True:
match self.lexer.peek_token():
case Operator(op) if op in "+-":
self.lexer.advance()
m = self.parse_mult()
left = BinOp(op, left, m)
case _:
break
return left
def parse_cmp(self):
left = self.parse_add()
match self.lexer.peek_token():
case Operator(op) if op in "<>":
self.lexer.advance()
right = self.parse_add()
return BinOp(op, left, right)
return left
def parse_simple(self):
return self.parse_cmp()
def parse_expr(self):
match self.lexer.peek_token():
case Keyword("if"):
return self.parse_if()
case Keyword("while"):
return self.parse_while()
case _:
return self.parse_simple()
# frozen=True makes instances hashable.  A plain @dataclass sets __hash__ to
# None, and dataclasses on Python 3.11+ refuses an unhashable object as a
# field default, which is how this type is used by the literal AST nodes.
@dataclass(frozen=True)
class NumType:
    """The type of numeric expressions; all instances compare equal."""
# frozen=True makes instances hashable.  A plain @dataclass sets __hash__ to
# None, and dataclasses on Python 3.11+ refuses an unhashable object as a
# field default, which is how this type is used by the literal AST nodes.
@dataclass(frozen=True)
class BoolType:
    """The type of boolean expressions; all instances compare equal."""
# Every type an expression can have.
SimType = (
    NumType
    | BoolType
)
@dataclass
class NumLiteral:
    """AST node for a numeric literal, typed as ``NumType`` by default."""

    # Declared as Fraction; the parser in this file passes the int carried by
    # a Num token.
    value: Fraction
    # Built per instance through default_factory: a plain ``NumType()`` default
    # is rejected by dataclasses on Python 3.11+ when NumType is unhashable.
    type: SimType = field(default_factory=NumType)
@dataclass
class BoolLiteral:
    """AST node for a boolean literal, typed as ``BoolType`` by default."""

    # Truth value of the literal.
    value: bool
    # Built per instance through default_factory: a plain ``BoolType()`` default
    # is rejected by dataclasses on Python 3.11+ when BoolType is unhashable.
    type: SimType = field(default_factory=BoolType)
@dataclass
class BinOp:
    """AST node for a binary operation ``left <operator> right``."""

    # Spelling of the operator, e.g. "+" or "<".
    operator: str
    # Operand sub-expressions.
    left: 'AST'
    right: 'AST'
    # Type of the whole expression; None until it is filled in.
    type: Optional[SimType] = None
@dataclass
class IfElse:
    """AST node for ``if <condition> then <iftrue> else <iffalse> end``."""

    # Expression that selects the branch.
    condition: 'AST'
    # Branch expressions.
    iftrue: 'AST'
    iffalse: 'AST'
    # Type of the whole expression; None until it is filled in.
    type: Optional[SimType] = None
@dataclass
class While:
condition: 'AST'
body: 'AST'
@dataclass
class Variable:
name: str
# Union of every node class that can appear in a syntax tree.
AST = (
    NumLiteral
    | BoolLiteral
    | BinOp
    | IfElse
    | While
    | Variable
)
# Distinct name used to annotate the result of typecheck().
TypedAST = NewType('TypedAST', AST)
class TypeError(Exception):
    """Raised by the typechecker for an ill-typed or unsupported program.

    Note: within this module the name shadows the builtin ``TypeError``.
    """
# Since we don't have variables, environment is not needed.
def typecheck(program: AST, env = None) -> TypedAST:
match program:
case NumLiteral() as t: # already typed.
return t
case BoolLiteral() as t: # already typed.
return t
case BinOp(op, left, right) if op in "+×-/":
tleft = typecheck(left)
tright = typecheck(right)
if tleft.type != NumType() or tright.type != NumType():
raise TypeError()
return BinOp(op, left, right, NumType())
case BinOp("<", left, right):
tleft = typecheck(left)
tright = typecheck(right)
if tleft.type != NumType() or tright.type != NumType():
raise TypeError()
return BinOp("<", left, right, BoolType())
case BinOp("=", left, right):
tleft = typecheck(left)
tright = typecheck(right)
if tleft.type != tright.type:
raise TypeError()
return BinOp("=", left, right, BoolType())
case IfElse(c, t, f): # We have to typecheck both branches.
tc = typecheck(c)
if tc.type != BoolType():
raise TypeError()
tt = typecheck(t)
tf = typecheck(f)
if tt.type != tf.type: # Both branches must have the same type.
raise TypeError()
return IfElse(tc, tt, tf, tt.type) # The common type becomes the type of the if-else.
raise TypeError()
def test_typecheck():
    """Check the types inferred for arithmetic and comparison, and one rejection."""
    import pytest

    total = BinOp("+", NumLiteral(2), NumLiteral(3))
    assert typecheck(total).type == NumType()

    comparison = BinOp("<", NumLiteral(2), NumLiteral(3))
    assert typecheck(comparison).type == BoolType()

    # Adding a number to a boolean must be rejected.
    product = BinOp("×", NumLiteral(2), NumLiteral(3))
    boolean = BinOp("<", NumLiteral(2), NumLiteral(3))
    with pytest.raises(TypeError):
        typecheck(BinOp("+", product, boolean))
def test_parse():
    """Print the AST built for a sample if-expression."""
    def parse(string):
        stream = Stream.from_string(string)
        lexer = Lexer.from_stream(stream)
        return Parser.from_lexer(lexer).parse_expr()

    # You should parse, evaluate and see whether the expression produces the expected value in your tests.
    print(parse("if a+b > c×d then a×b - c + d else e×f/g end"))
# test_parse() # Uncomment to see the created ASTs.