-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsemantic.py
More file actions
204 lines (181 loc) · 6.93 KB
/
semantic.py
File metadata and controls
204 lines (181 loc) · 6.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# app/semantic.py
from __future__ import annotations
import re
import unicodedata
from typing import Any, Dict, List, Literal, Optional
from sqlalchemy.orm import Session
from sqlalchemy import text
from hashing import content_hash
from embedding import embed
from similarity import cosine_similarity
from db import decode_embedding
from config import EMBEDDINGS_MODEL, DUPLICATE_THRESHOLD, NEAR_DUPLICATE_THRESHOLD
# The three labels that classify() can return for a similarity score.
Classification = Literal["duplicate", "near_duplicate", "new"]
def classify(sim):
    """Map a similarity score onto a classification label.

    The stricter threshold is tested first: a score at or above
    DUPLICATE_THRESHOLD yields "duplicate", otherwise a score at or above
    NEAR_DUPLICATE_THRESHOLD yields "near_duplicate", and anything lower
    yields "new".
    """
    ladder = (
        ("duplicate", DUPLICATE_THRESHOLD),
        ("near_duplicate", NEAR_DUPLICATE_THRESHOLD),
    )
    for label, floor in ladder:
        if sim >= floor:
            return label
    return "new"
def normalize_claim_text(t):
    """Return a canonical form of a claim's text.

    Steps, in order: Unicode NFC normalization, removal of leading and
    trailing whitespace, lowercasing, and collapsing every run of
    whitespace into a single space.
    """
    composed = unicodedata.normalize("NFC", t)
    trimmed_lower = composed.strip().lower()
    return re.sub(r"\s+", " ", trimmed_lower)
def ensure_claim(db, claim_text):
    """Insert claim if not present. Returns claim_id.

    Lookup order:
      1. a ``claim`` row whose ``content_hash`` equals the hash of the raw
         text;
      2. a ``claim`` row whose ``content_hash`` equals the hash of the
         normalized text (see ``normalize_claim_text``); only tried when
         that hash differs from the raw one.

    When neither lookup hits, a ``claim`` row (raw text, raw hash) and its
    ``claim_embedding`` row are inserted and the session is committed.

    Args:
        db: database session; must provide ``execute``, ``commit`` and
            ``rollback``.
        claim_text: text of the claim.

    Returns:
        The integer ``claim_id`` of the existing or newly inserted claim.

    Raises:
        Any exception raised by ``embed`` or by the database. If one of the
        INSERT statements or the commit fails, the session is rolled back
        before the exception is re-raised.
    """
    lookup = text("SELECT claim_id FROM claim WHERE content_hash=:h")

    raw_hash = content_hash(claim_text)
    row = db.execute(lookup, {"h": raw_hash}).fetchone()
    if row:
        return int(row[0])

    norm_hash = content_hash(normalize_claim_text(claim_text))
    if norm_hash != raw_hash:
        row = db.execute(lookup, {"h": norm_hash}).fetchone()
        if row:
            return int(row[0])

    # Compute the embedding BEFORE touching the database: if embed() fails,
    # no claim row is left pending in the session where a later commit
    # could persist it without its embedding.
    vec = embed(claim_text)

    try:
        row = db.execute(
            text(
                "INSERT INTO claim (claim_text, content_hash) "
                "VALUES (:t,:h) RETURNING claim_id"
            ),
            {"t": claim_text, "h": raw_hash},
        ).fetchone()
        cid = int(row[0])
        db.execute(
            text(
                "INSERT INTO claim_embedding "
                "(claim_id, embedding_model, embedding) VALUES (:id,:m,:v)"
            ),
            {"id": cid, "m": EMBEDDINGS_MODEL, "v": vec},
        )
        db.commit()
    except Exception:
        # Keep claim and claim_embedding in step: never leave a half-written
        # pair pending in the session.
        db.rollback()
        raise
    return cid
def get_post_id(db, claim_id):
    """Look up the on-chain post_id stored for a claim.

    Returns the ``post_id`` column of the matching ``claim`` row as an int,
    or None when the claim does not exist or its ``post_id`` is NULL.
    """
    query = text("SELECT post_id FROM claim WHERE claim_id = :id")
    record = db.execute(query, {"id": claim_id}).fetchone()
    if not record or record[0] is None:
        return None
    return int(record[0])
def compute_one(db, claim_text, top_k=5):
    """Classify a claim by its similarity to the other stored claims.

    The claim is first registered through ``ensure_claim`` (which may insert
    and commit it). Up to ``top_k`` most similar *other* claims are then
    collected, and the highest similarity is passed to ``classify``.

    The ranking is attempted in SQL using the ``<=>`` distance operator. If
    that attempt raises, similarities are instead computed in Python over
    every other stored embedding with ``decode_embedding`` and
    ``cosine_similarity``.

    Args:
        db: database session.
        claim_text: text of the claim to evaluate.
        top_k: maximum number of similar claims to return.

    Returns:
        A dict with keys ``hash``, ``claim_id``, ``post_id``,
        ``classification``, ``max_similarity`` and ``similar`` (a list of
        ``{"claim_id", "text", "similarity"}`` dicts, most similar first).
    """
    import logging

    cid = ensure_claim(db, claim_text)
    post_id = get_post_id(db, cid)

    try:
        similar = _similar_via_sql(db, cid, top_k)
    except Exception:
        # Deliberate best-effort: any failure of the SQL ranking falls back
        # to the in-Python computation. Log it so the cause is not lost.
        # NOTE(review): on backends that abort the transaction after a
        # failed statement (e.g. PostgreSQL), the fallback queries would
        # fail too unless the session is rolled back or the attempt runs in
        # a savepoint. Confirm which backend this fallback targets.
        logging.getLogger(__name__).debug(
            "SQL similarity query failed; falling back to Python",
            exc_info=True,
        )
        similar = _similar_via_python(db, cid, top_k)

    max_sim = float(similar[0]["similarity"]) if similar else 0.0
    return {
        "hash": content_hash(claim_text),
        "claim_id": cid,
        "post_id": post_id,
        "classification": classify(max_sim),
        "max_similarity": max_sim,
        "similar": similar,
    }


def _similar_via_sql(db, cid, top_k):
    """Rank other claims against claim ``cid`` inside the database."""
    rows = db.execute(
        text(
            "WITH q AS ("
            "  SELECT embedding FROM claim_embedding WHERE claim_id=:id"
            ") "
            "SELECT c.claim_id, c.claim_text, "
            "       (1.0 - (e.embedding <=> q.embedding)) AS similarity "
            "FROM claim c "
            "JOIN claim_embedding e USING (claim_id) "
            "CROSS JOIN q "
            "WHERE c.claim_id != :id "
            "ORDER BY (e.embedding <=> q.embedding) ASC "
            "LIMIT :k"
        ),
        {"id": cid, "k": top_k},
    ).fetchall()
    return [
        {"claim_id": int(r[0]), "text": str(r[1]), "similarity": float(r[2])}
        for r in rows
    ]


def _similar_via_python(db, cid, top_k):
    """Rank other claims against claim ``cid`` by cosine similarity in Python."""
    q = db.execute(
        text("SELECT embedding FROM claim_embedding WHERE claim_id=:id"),
        {"id": cid},
    ).fetchone()
    # A claim may have no embedding row; treat that like an empty vector
    # instead of failing on q[0].
    qvec = (decode_embedding(db, q[0]) if q else None) or []

    rows = db.execute(
        text(
            "SELECT c.claim_id, c.claim_text, e.embedding "
            "FROM claim c "
            "JOIN claim_embedding e USING (claim_id) "
            "WHERE c.claim_id != :id"
        ),
        {"id": cid},
    ).fetchall()

    scored = []
    for ocid, txt, emb_val in rows:
        avec = decode_embedding(db, emb_val) or []
        # Either vector missing/empty -> similarity is defined as 0.0.
        sim = cosine_similarity(qvec, avec) if qvec and avec else 0.0
        scored.append(
            {"claim_id": int(ocid), "text": str(txt), "similarity": float(sim)}
        )
    scored.sort(key=lambda item: item["similarity"], reverse=True)
    return scored[:top_k]
# Minimum embedding similarity an on-chain claim must reach to be returned by
# find_best_onchain_match / find_all_onchain_matches.
OVERLAY_THRESHOLD = 0.82
def find_best_onchain_match(db, sentence_text, exclude_post_ids=None):
    """Find the best on-chain claim matching a sentence.

    Two strategies are tried in order:
      1. Exact text match (case-insensitive, trimmed) against claims that
         have a ``post_id``; a hit is reported with similarity 1.0.
      2. Embedding similarity: the sentence is registered via
         ``ensure_claim`` (which may insert and commit it as a new claim),
         then the 5 nearest claims with a ``post_id`` are scanned and the
         first one that is not excluded and reaches OVERLAY_THRESHOLD is
         returned.

    Args:
        db: database session.
        sentence_text: sentence to match; empty or whitespace-only input
            yields None.
        exclude_post_ids: optional collection of post_ids that must not be
            returned.

    Returns:
        A dict with keys ``claim_id``, ``claim_text``, ``post_id`` and
        ``similarity``, or None when nothing qualifies or the embedding
        strategy fails.
    """
    import logging

    if not sentence_text or not sentence_text.strip():
        return None
    exclude_post_ids = exclude_post_ids or set()

    # Strategy 1: exact text match
    row = db.execute(
        text("SELECT claim_id, claim_text, post_id FROM claim "
             "WHERE post_id IS NOT NULL AND LOWER(TRIM(claim_text)) = LOWER(TRIM(:t)) LIMIT 1"),
        {"t": sentence_text},
    ).fetchone()
    if row and int(row[2]) not in exclude_post_ids:
        return {"claim_id": int(row[0]), "claim_text": str(row[1]),
                "post_id": int(row[2]), "similarity": 1.0}

    # Strategy 2: embedding similarity
    try:
        cid = ensure_claim(db, sentence_text)
        rows = db.execute(text(
            "WITH q AS (SELECT embedding FROM claim_embedding WHERE claim_id = :id) "
            "SELECT c.claim_id, c.claim_text, c.post_id, "
            "       (1.0 - (e.embedding <=> q.embedding)) AS similarity "
            "FROM claim c JOIN claim_embedding e USING (claim_id) CROSS JOIN q "
            "WHERE c.claim_id != :id AND c.post_id IS NOT NULL "
            "ORDER BY (e.embedding <=> q.embedding) ASC LIMIT 5"
        ), {"id": cid}).fetchall()
        for r in rows:
            pid = int(r[2])
            sim = float(r[3])
            if pid in exclude_post_ids:
                continue
            if sim >= OVERLAY_THRESHOLD:
                return {"claim_id": int(r[0]), "claim_text": str(r[1]),
                        "post_id": pid, "similarity": sim}
    except Exception as e:
        # Best-effort: a failed embedding lookup means "no match". The
        # logger is resolved here because no module-level ``logger`` is
        # defined in this file; referencing one would raise NameError.
        logging.getLogger(__name__).debug("Embedding match failed: %s", e)
    return None
def find_all_onchain_matches(db, sentence_text, top_k=3):
    """Find all on-chain claims matching a sentence above the overlay threshold.

    The sentence is registered via ``ensure_claim`` (which may insert and
    commit it as a new claim). The ``top_k`` nearest claims that have a
    ``post_id`` are fetched, and those whose similarity reaches
    OVERLAY_THRESHOLD are returned. ``top_k`` limits the candidates
    *before* the threshold filter, so fewer than ``top_k`` results may come
    back.

    Args:
        db: database session.
        sentence_text: sentence to match; empty or whitespace-only input
            yields an empty list.
        top_k: maximum number of nearest candidates to consider.

    Returns:
        A list of dicts with keys ``claim_id``, ``claim_text``, ``post_id``
        and ``similarity``, nearest first; an empty list when nothing
        qualifies or the lookup fails.
    """
    import logging

    if not sentence_text or not sentence_text.strip():
        return []
    try:
        cid = ensure_claim(db, sentence_text)
        rows = db.execute(text(
            "WITH q AS (SELECT embedding FROM claim_embedding WHERE claim_id = :id) "
            "SELECT c.claim_id, c.claim_text, c.post_id, "
            "       (1.0 - (e.embedding <=> q.embedding)) AS similarity "
            "FROM claim c JOIN claim_embedding e USING (claim_id) CROSS JOIN q "
            "WHERE c.claim_id != :id AND c.post_id IS NOT NULL "
            "ORDER BY (e.embedding <=> q.embedding) ASC LIMIT :k"
        ), {"id": cid, "k": top_k}).fetchall()
        return [{"claim_id": int(r[0]), "claim_text": str(r[1]),
                 "post_id": int(r[2]), "similarity": float(r[3])}
                for r in rows if float(r[3]) >= OVERLAY_THRESHOLD]
    except Exception as e:
        # Best-effort: a failed lookup means "no matches". The logger is
        # resolved here because no module-level ``logger`` is defined in
        # this file; referencing one would raise NameError.
        logging.getLogger(__name__).debug(
            "find_all_onchain_matches failed: %s", e
        )
        return []