-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
55 lines (44 loc) · 1.44 KB
/
main.py
File metadata and controls
55 lines (44 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from z3 import sat, unsat

from src.tensor import Context, Tensor
def main():
    """Run a small symbolic-shape example and report the solver outcome.

    Creates symbolic dimensions and tensors through a ``Context``, combines
    them with ``+``, traces the context, and checks the resulting constraints
    with ``ctx.solver``:

    * ``sat``     -> print the model,
    * ``unsat``   -> print every constraint in the unsat core together with
                     the shapes of the tensors that created it,
    * ``unknown`` -> print the solver's reason (no model or core exists).
    """
    # --- Example Usage ---
    ctx = Context()

    # Create symbolic dimensions
    N = ctx.new_symbol("N")
    C = ctx.new_symbol("C")
    H = ctx.new_symbol("H")
    W = ctx.new_symbol("W")
    bt32 = ctx.shape_from(["B", "T", 32])

    # Create tensors with symbolic shapes
    t1 = ctx.tensor(shape=[N, C, 24, W, 64])
    t2 = ctx.tensor(shape=[N, C, H, W])
    t3 = ctx.tensor(shape=bt32)  # Mismatched shape

    # Build the computation graph. The results are not read below; the
    # additions matter for what they record on ``ctx``. The names are kept
    # bound so the result tensors stay alive until ``main`` returns, in case
    # the context only holds non-owning references (TODO confirm in src.tensor).
    t_out1 = t1 + t2  # noqa: F841
    t_out2 = t1 + t3  # noqa: F841

    ctx.trace()
    solver = ctx.solver
    print("simplified:", ctx.final_constraints)

    # Solve once and reuse the result: every check() call re-runs the solver.
    result = solver.check()
    print(f"\nSolver check: {result}")

    if result == sat:
        print(f"Model: {solver.model()}")
    elif result == unsat:
        _report_unsat_core(ctx, solver)
    else:
        # z3 returned ``unknown``: neither a model nor an unsat core is
        # available, so unsat_core() must not be called here.
        print(f"Reason unknown: {solver.reason_unknown()}")


def _report_unsat_core(ctx, solver):
    """Print each unsat-core constraint and the shapes of its input tensors.

    Args:
        ctx: The ``Context`` whose constraints were checked. It is used to map
            core entries back to constraints (``constraint_map``) and to map
            creator ids to tensors (``tensor_map``).
        solver: The z3 solver whose last ``check()`` returned ``unsat``.
    """
    # NOTE(review): z3 only populates the core for tracked assertions
    # (assert_and_track / check with assumptions); an empty core means the
    # context did not track its constraints.
    for tracked in solver.unsat_core():
        con = ctx.constraint_map(tracked)
        # tensor_map may return None for ids with no associated tensor;
        # compare against None explicitly so tensors are never truth-tested.
        inputs: list[Tensor] = [
            tensor
            for tensor in (ctx.tensor_map(src) for src in con.created_by)
            if tensor is not None
        ]
        inps = "\n".join(f"\t{tensor.shape}" for tensor in inputs)
        # TODO: pinpoint the incompatible axis per input once the shape
        # representation exposes per-dimension comparison.
        print(f"constraint failed: {con}\nby inputs:\n{inps}")
if __name__ == "__main__":
    # Script entry point: run the example only when executed directly,
    # never as a side effect of importing this module.
    main()