Skip to content

Commit d4bead1

Browse files
fix
1 parent ec4e852 commit d4bead1

3 files changed

Lines changed: 376 additions & 17 deletions

File tree

pyrefly/lib/lsp/wasm/completion.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,12 @@ impl Transaction<'_> {
11051105
&mut result,
11061106
in_string_literal,
11071107
);
1108+
self.add_dict_value_literal_completions(
1109+
handle,
1110+
mod_module.as_ref(),
1111+
position,
1112+
&mut result,
1113+
);
11081114
let dict_key_claimed = self.add_dict_key_completions(
11091115
handle,
11101116
mod_module.as_ref(),

pyrefly/lib/state/lsp/dict_completions.rs

Lines changed: 308 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,61 @@ impl DictKeyLiteralContext {
7373
}
7474

7575
impl<'a> Transaction<'a> {
76+
fn named_target_type(&self, handle: &Handle, expr: &Expr) -> Option<Type> {
77+
let Expr::Name(name) = expr else {
78+
return None;
79+
};
80+
let short_id = ShortIdentifier::expr_name(name);
81+
let bindings = self.get_bindings(handle)?;
82+
let bound_key = Key::BoundName(short_id);
83+
if bindings.is_valid_key(&bound_key) {
84+
return self.get_type(handle, &bound_key);
85+
}
86+
let def_key = Key::Definition(short_id);
87+
if bindings.is_valid_key(&def_key) {
88+
self.get_type(handle, &def_key)
89+
} else {
90+
None
91+
}
92+
}
93+
94+
fn dict_literal_expected_type(
95+
&self,
96+
handle: &Handle,
97+
module: &ModModule,
98+
dict: &ExprDict,
99+
) -> Option<Type> {
100+
for node in Ast::locate_node(module, dict.range().start()) {
101+
match node {
102+
AnyNodeRef::StmtAnnAssign(assign)
103+
if assign
104+
.value
105+
.as_ref()
106+
.is_some_and(|value| value.range() == dict.range()) =>
107+
{
108+
return self.named_target_type(handle, assign.target.as_ref());
109+
}
110+
AnyNodeRef::StmtAssign(assign)
111+
if assign.value.range() == dict.range() && assign.targets.len() == 1 =>
112+
{
113+
return self.named_target_type(handle, &assign.targets[0]);
114+
}
115+
_ => {}
116+
}
117+
}
118+
None
119+
}
120+
121+
fn dict_literal_contextual_type(
122+
&self,
123+
handle: &Handle,
124+
module: &ModModule,
125+
dict: &ExprDict,
126+
) -> Option<Type> {
127+
self.dict_literal_expected_type(handle, module, dict)
128+
.or_else(|| self.get_type_trace(handle, dict.range()))
129+
}
130+
76131
fn type_contains_typed_dict(ty: &Type) -> bool {
77132
match ty {
78133
Type::TypedDict(_) | Type::PartialTypedDict(_) => true,
@@ -83,6 +138,130 @@ impl<'a> Transaction<'a> {
83138
}
84139
}
85140

141+
fn typed_dict_members(base_type: Type) -> Vec<Type> {
142+
let mut members = Vec::new();
143+
let mut stack = vec![base_type];
144+
while let Some(ty) = stack.pop() {
145+
match ty {
146+
Type::TypedDict(_) | Type::PartialTypedDict(_) => members.push(ty),
147+
Type::Union(box Union {
148+
members: union_members,
149+
..
150+
}) => {
151+
stack.extend(union_members.into_iter());
152+
}
153+
_ => {}
154+
}
155+
}
156+
members
157+
}
158+
159+
fn typed_dict_field_type_in_member<'b>(
160+
solver: &crate::alt::answers_solver::AnswersSolver<
161+
crate::state::lsp::TransactionHandle<'b>,
162+
>,
163+
member: &Type,
164+
key: &str,
165+
) -> Option<Type> {
166+
let typed_dict = match member {
167+
Type::TypedDict(td) | Type::PartialTypedDict(td) => td,
168+
_ => return None,
169+
};
170+
solver
171+
.type_order()
172+
.typed_dict_fields(typed_dict)
173+
.iter()
174+
.find_map(|(name, field)| (name == key).then(|| field.ty.clone()))
175+
}
176+
177+
fn narrowed_typed_dict_members_for_dict_literal(
178+
&self,
179+
handle: &Handle,
180+
module: &ModModule,
181+
dict: &ExprDict,
182+
skip_key_range: Option<TextRange>,
183+
skip_value_range: Option<TextRange>,
184+
) -> Option<Vec<Type>> {
185+
let base_type = self.dict_literal_contextual_type(handle, module, dict)?;
186+
self.ad_hoc_solve(handle, "dict_literal_typed_dict_members", |solver| {
187+
let members = Self::typed_dict_members(base_type);
188+
if members.is_empty() {
189+
return Vec::new();
190+
}
191+
let narrowed = members
192+
.iter()
193+
.filter(|member| {
194+
dict.items.iter().all(|item| {
195+
let Some(key_expr) = item.key.as_ref() else {
196+
return true;
197+
};
198+
let value_expr = &item.value;
199+
let Expr::StringLiteral(key_lit) = key_expr else {
200+
return true;
201+
};
202+
if skip_key_range == Some(key_lit.range())
203+
|| skip_value_range == Some(value_expr.range())
204+
{
205+
return true;
206+
}
207+
let Some(field_ty) = Self::typed_dict_field_type_in_member(
208+
&solver,
209+
member,
210+
key_lit.value.to_str(),
211+
) else {
212+
return false;
213+
};
214+
let Some(value_ty) = self.get_type_trace(handle, value_expr.range()) else {
215+
return true;
216+
};
217+
solver.is_subset_eq(&value_ty, &field_ty)
218+
})
219+
})
220+
.cloned()
221+
.collect::<Vec<_>>();
222+
if narrowed.is_empty() {
223+
members
224+
} else {
225+
narrowed
226+
}
227+
})
228+
}
229+
230+
fn typed_dict_field_type_from_members(
231+
&self,
232+
handle: &Handle,
233+
members: Vec<Type>,
234+
key: &str,
235+
) -> Option<Type> {
236+
self.ad_hoc_solve(handle, "typed_dict_field_type", |solver| {
237+
let field_types = members
238+
.iter()
239+
.filter_map(|member| Self::typed_dict_field_type_in_member(&solver, member, key))
240+
.collect::<Vec<_>>();
241+
match field_types.len() {
242+
0 => None,
243+
1 => field_types.into_iter().next(),
244+
_ => Some(solver.unions(field_types)),
245+
}
246+
})
247+
.flatten()
248+
}
249+
250+
fn dict_literal_present_keys(
251+
dict: &ExprDict,
252+
skip_key_range: Option<TextRange>,
253+
) -> BTreeMap<String, ()> {
254+
dict.items
255+
.iter()
256+
.filter_map(|item| {
257+
let Expr::StringLiteral(lit) = item.key.as_ref()? else {
258+
return None;
259+
};
260+
(skip_key_range != Some(lit.range())).then(|| (lit.value.to_string(), ()))
261+
})
262+
.collect()
263+
}
264+
86265
fn expr_has_typed_dict_type(&self, handle: &Handle, expr: &Expr) -> bool {
87266
self.get_type_trace(handle, expr.range())
88267
.map(|ty| Self::type_contains_typed_dict(&ty))
@@ -244,6 +423,58 @@ impl<'a> Transaction<'a> {
244423
best.map(|(_, _, dict, literal)| (dict, literal))
245424
}
246425

426+
fn dict_literal_value_string_literal_at(
427+
module: &ModModule,
428+
position: TextSize,
429+
) -> Option<(ExprDict, ExprStringLiteral, ExprStringLiteral)> {
430+
let nodes = Ast::locate_node(module, position);
431+
let mut best: Option<(u8, TextSize, ExprDict, ExprStringLiteral, ExprStringLiteral)> = None;
432+
for node in nodes {
433+
let AnyNodeRef::ExprDict(dict) = node else {
434+
continue;
435+
};
436+
let mut best_in_dict: Option<(u8, TextSize, ExprStringLiteral, ExprStringLiteral)> =
437+
None;
438+
for item in &dict.items {
439+
let Some(Expr::StringLiteral(key_lit)) = item.key.as_ref() else {
440+
continue;
441+
};
442+
let Expr::StringLiteral(value_lit) = &item.value else {
443+
continue;
444+
};
445+
let (priority, dist) = Self::string_literal_priority(position, value_lit.range());
446+
let should_update = match &best_in_dict {
447+
Some((best_prio, best_dist, _, _)) => {
448+
priority < *best_prio || (priority == *best_prio && dist < *best_dist)
449+
}
450+
None => true,
451+
};
452+
if should_update {
453+
best_in_dict = Some((priority, dist, key_lit.clone(), value_lit.clone()));
454+
if priority == 0 && dist == TextSize::from(0) {
455+
break;
456+
}
457+
}
458+
}
459+
let Some((priority, dist, key_lit, value_lit)) = best_in_dict else {
460+
continue;
461+
};
462+
let should_update = match &best {
463+
Some((best_prio, best_dist, _, _, _)) => {
464+
priority < *best_prio || (priority == *best_prio && dist < *best_dist)
465+
}
466+
None => true,
467+
};
468+
if should_update {
469+
best = Some((priority, dist, dict.clone(), key_lit, value_lit));
470+
if priority == 0 && dist == TextSize::from(0) {
471+
break;
472+
}
473+
}
474+
}
475+
best.map(|(_, _, dict, key_lit, value_lit)| (dict, key_lit, value_lit))
476+
}
477+
247478
fn expression_facets(expr: &Expr) -> Option<(Identifier, Vec<FacetKind>)> {
248479
let mut facets = Vec::new();
249480
let mut current = expr;
@@ -279,25 +510,52 @@ impl<'a> Transaction<'a> {
279510
) -> Option<BTreeMap<String, Type>> {
280511
self.ad_hoc_solve(handle, "typed_dict_keys", |solver| {
281512
let mut map = BTreeMap::new();
282-
let mut stack = vec![base_type];
283-
while let Some(ty) = stack.pop() {
284-
match ty {
285-
Type::TypedDict(td) | Type::PartialTypedDict(td) => {
286-
for (name, field) in solver.type_order().typed_dict_fields(&td) {
287-
map.entry(name.to_string())
288-
.or_insert_with(|| field.ty.clone());
289-
}
290-
}
291-
Type::Union(box Union { members, .. }) => {
292-
stack.extend(members.into_iter());
293-
}
294-
_ => {}
513+
for member in Self::typed_dict_members(base_type) {
514+
let typed_dict = match member {
515+
Type::TypedDict(td) | Type::PartialTypedDict(td) => td,
516+
_ => continue,
517+
};
518+
for (name, field) in solver.type_order().typed_dict_fields(&typed_dict) {
519+
map.entry(name.to_string())
520+
.or_insert_with(|| field.ty.clone());
295521
}
296522
}
297523
map
298524
})
299525
}
300526

527+
pub(crate) fn add_dict_value_literal_completions(
528+
&self,
529+
handle: &Handle,
530+
module: &ModModule,
531+
position: TextSize,
532+
completions: &mut Vec<RankedCompletion>,
533+
) {
534+
let Some((dict, key_lit, value_lit)) =
535+
Self::dict_literal_value_string_literal_at(module, position)
536+
else {
537+
return;
538+
};
539+
if position < value_lit.range().start() || position > value_lit.range().end() {
540+
return;
541+
}
542+
let Some(members) = self.narrowed_typed_dict_members_for_dict_literal(
543+
handle,
544+
module,
545+
&dict,
546+
Some(key_lit.range()),
547+
Some(value_lit.range()),
548+
) else {
549+
return;
550+
};
551+
let Some(field_ty) =
552+
self.typed_dict_field_type_from_members(handle, members, key_lit.value.to_str())
553+
else {
554+
return;
555+
};
556+
Self::add_literal_completions_from_type(&field_ty, completions, true);
557+
}
558+
301559
/// Adds dict key completions for the given position. Returns `true` if this function
302560
/// claimed the position (i.e., we are inside a dict/TypedDict key string literal), in
303561
/// which case the caller should skip overload-based literal completions to avoid showing
@@ -366,12 +624,45 @@ impl<'a> Transaction<'a> {
366624
}
367625
}
368626

369-
// For key access we query the container expression; for literals we query the
370-
// literal itself to pick up contextual TypedDict typing from assignments.
371-
if let Some(base_type) = self.get_type_trace(handle, context.base_range())
372-
&& let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type)
627+
let dict_literal_members = match &context {
628+
DictKeyLiteralContext::DictLiteral { dict, literal } => self
629+
.narrowed_typed_dict_members_for_dict_literal(
630+
handle,
631+
module,
632+
dict,
633+
Some(literal.range()),
634+
None,
635+
),
636+
DictKeyLiteralContext::KeyAccess { .. } => None,
637+
};
638+
639+
// For key access we query the container expression; for literals we recover the
640+
// contextual type because incomplete dict literals may infer as plain `dict[...]`.
641+
if let Some(base_type) = match (&context, dict_literal_members.as_ref()) {
642+
(DictKeyLiteralContext::DictLiteral { .. }, Some(members)) => self
643+
.ad_hoc_solve(
644+
handle,
645+
"dict_literal_typed_dict_union",
646+
|solver| match members.len() {
647+
0 => None,
648+
1 => members.first().cloned(),
649+
_ => Some(solver.unions(members.clone())),
650+
},
651+
)
652+
.flatten(),
653+
_ => self.get_type_trace(handle, context.base_range()),
654+
} && let Some(typed_keys) = self.collect_typed_dict_keys(handle, base_type)
373655
{
656+
let present_keys = match &context {
657+
DictKeyLiteralContext::DictLiteral { dict, literal } => {
658+
Self::dict_literal_present_keys(dict, Some(literal.range()))
659+
}
660+
DictKeyLiteralContext::KeyAccess { .. } => BTreeMap::new(),
661+
};
374662
for (key, ty) in typed_keys {
663+
if present_keys.contains_key(&key) {
664+
continue;
665+
}
375666
let entry = suggestions.entry(key).or_insert(None);
376667
if entry.is_none() {
377668
*entry = Some(ty);

0 commit comments

Comments
 (0)