diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index da7557002..be4809620 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -128,6 +128,7 @@ --- built-in type for Rawget --- @alias std.RawGet unknown +--- @deprecated use `const T` as a replacement, for example `---@generic const T`. --- --- built-in type for generic template, for match integer const and true/false --- @alias std.ConstTpl unknown @@ -170,9 +171,12 @@ --- attribute +--- @class Attribute + --- --- Deprecated. Receives an optional message parameter. ---- @attribute deprecated(message: string?) +--- @class deprecated: Attribute +--- @overload fun(message?: string) --- --- Language Server Optimization Items. @@ -181,13 +185,15 @@ --- - `check_table_field`: Skip the assign check for table fields. It is recommended to use this option for all large configuration tables. --- - `delayed_definition`: Indicates that the type of the variable is determined by the first assignment. --- Only valid for `local` declarations with no initial value. ---- @attribute lsp_optimization(code: "check_table_field"|"delayed_definition") +--- @class lsp_optimization: Attribute +--- @overload fun(code: "check_table_field"|"delayed_definition") --- --- Index field alias, will be displayed in `hint` and `completion`. --- --- Receives a string parameter for the alias name. ---- @attribute index_alias(name: string) +--- @class index_alias: Attribute +--- @overload fun(name: string) --- --- This attribute must be applied to function parameters, and the function parameter's type must be a string template generic, @@ -200,7 +206,8 @@ --- - `return_mode`: Constructor return strategy. `"self"` forces `self`, `"doc"` uses the documented return type, --- and `"default"` prefers the documented return type and falls back to `self`. --- Defaults to `"default"` ---- @attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) +--- @class constructor: Attribute +--- @overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") --- --- Associates `getter` and `setter` methods with a field. Currently provides only definition navigation functionality, @@ -210,4 +217,5 @@ --- - `convention`: Naming convention, defaults to `camelCase`. Implicitly adds `get` and `set` prefixes. eg: `_age` -> `getAge`, `setAge`. --- - `getter`: Getter method name. Takes precedence over `convention`. --- - `setter`: Setter method name. Takes precedence over `convention`. ---- @attribute field_accessor(convention: "camelCase"|"PascalCase"|"snake_case"|nil, getter: string?, setter: string?) +--- @class field_accessor: Attribute +--- @overload fun(convention?: "camelCase"|"PascalCase"|"snake_case", getter?: string, setter?: string) diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index 220ca2de9..b68dd7a52 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -277,9 +277,9 @@ function rawequal(v1, v2) end --- --- Gets the real value of `table[index]`, the `__index` metamethod. `table` --- must be a table; `index` may be any value. ---- @generic T, K +--- @generic const T, const K --- @param table T ---- @param index std.ConstTpl +--- @param index K --- @return std.RawGet function rawget(table, index) end @@ -340,8 +340,8 @@ function require(modname) end --- `index`. a negative number indexes from the end (-1 is the last argument). --- Otherwise, `index` must be the string "#", and `select` returns --- the total number of extra arguments it received. ---- @generic T, Num: integer | '#' ---- @param index std.ConstTpl +--- @generic T, const Num: integer | '#' +--- @param index Num --- @param ... T... --- @return std.Select function select(index, ...) end @@ -460,9 +460,9 @@ function xpcall(f, msgh, ...) end --- @version 5.1, JIT --- ---- @generic T, Start: integer, End: integer ---- @param i? std.ConstTpl ---- @param j? std.ConstTpl +--- @generic const T, const Start: integer, const End: integer +--- @param i? Start +--- @param j? End --- @param list T --- @return std.Unpack function unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/resources/std/table.lua b/crates/emmylua_code_analysis/resources/std/table.lua index ee60a7343..8ab495e5e 100644 --- a/crates/emmylua_code_analysis/resources/std/table.lua +++ b/crates/emmylua_code_analysis/resources/std/table.lua @@ -106,9 +106,9 @@ function table.sort(list, comp) end --- Returns the elements from the given list. This function is equivalent to --- return `list[i]`, `list[i+1]`, `···`, `list[j]` --- By default, i is 1 and j is #list. ---- @generic T, Start: integer, End: integer ---- @param i? std.ConstTpl ---- @param j? std.ConstTpl +--- @generic const T, const Start: integer, const End: integer +--- @param i? Start +--- @param j? End --- @param list T --- @return std.Unpack function table.unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs index 37682e436..60345dafa 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs @@ -1,7 +1,6 @@ use emmylua_parser::{ - LuaAstNode, LuaAstToken, LuaComment, LuaDocTag, LuaDocTagAlias, LuaDocTagAttribute, - LuaDocTagClass, LuaDocTagEnum, LuaDocTagMeta, LuaDocTagNamespace, LuaDocTagUsing, - LuaDocTypeFlag, + LuaAstNode, LuaAstToken, LuaComment, LuaDocTag, LuaDocTagAlias, LuaDocTagClass, LuaDocTagEnum, + LuaDocTagMeta, LuaDocTagNamespace, LuaDocTagUsing, LuaDocTypeFlag, }; use flagset::FlagSet; use rowan::TextRange; @@ -100,24 +99,6 @@ pub fn analyze_doc_tag_alias(analyzer: &mut DeclAnalyzer, alias: LuaDocTagAlias) Some(()) } -pub fn analyze_doc_tag_attribute( - analyzer: &mut DeclAnalyzer, - attribute: LuaDocTagAttribute, -) -> Option<()> { - let name_token = attribute.get_name_token()?; - let name = name_token.get_name_text().to_string(); - let range = name_token.syntax().text_range(); - - add_type_decl( - analyzer, - &name, - range, - LuaDeclTypeKind::Attribute, - FlagSet::default(), - ); - Some(()) -} - pub fn analyze_doc_tag_namespace( analyzer: &mut DeclAnalyzer, namespace: LuaDocTagNamespace, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs index 859111678..841770e63 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs @@ -107,9 +107,6 @@ fn walk_node_enter(analyzer: &mut DeclAnalyzer, node: LuaAst) { LuaAst::LuaDocTagAlias(doc_tag) => { docs::analyze_doc_tag_alias(analyzer, doc_tag); } - LuaAst::LuaDocTagAttribute(doc_tag) => { - docs::analyze_doc_tag_attribute(analyzer, doc_tag); - } LuaAst::LuaDocTagNamespace(doc_tag) => { docs::analyze_doc_tag_namespace(analyzer, doc_tag); } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs index 08f271bac..900260401 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs @@ -11,6 +11,7 @@ use crate::{ infer_type::infer_type, tags::{get_owner_id, report_orphan_tag}, }, + get_attribute_constructor_params, is_attribute_class, }; pub fn analyze_tag_attribute_use( @@ -64,27 +65,19 @@ pub fn infer_attribute_uses( LuaDocType::Name(attribute_use.get_type()?), ); if let LuaType::Ref(type_id) = attribute_type { + if !is_attribute_class(analyzer.type_context.db, &type_id) { + continue; + } + let arg_types: Vec = attribute_use .get_arg_list() .map(|arg_list| arg_list.get_args().map(infer_attribute_arg_type).collect()) .unwrap_or_default(); - let param_names = analyzer - .type_context - .db - .get_type_index() - .get_type_decl(&type_id) - .and_then(|decl| decl.get_attribute_type()) - .and_then(|typ| match typ { - LuaType::DocAttribute(attr_type) => Some( - attr_type - .get_params() - .iter() - .map(|(name, _)| name.clone()) - .collect::>(), - ), - _ => None, - }) - .unwrap_or_default(); + let param_names: Vec = + get_attribute_constructor_params(analyzer.type_context.db, &type_id, &arg_types) + .into_iter() + .map(|(name, _)| name) + .collect(); let mut params = Vec::new(); for (idx, arg_type) in arg_types.into_iter().enumerate() { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index e3ad351d2..de39d8a0e 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -4,24 +4,32 @@ use rowan::{TextRange, TextSize}; use smol_str::SmolStr; use std::sync::Arc; -use crate::{GenericParam, GenericTpl, GenericTplId, LuaType}; +use crate::{GenericParam, GenericTpl, GenericTplId}; pub trait GenericIndex: std::fmt::Debug { fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam); + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option; fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { for param in params { - self.append_generic_param(scope_id, param); + let _ = self.append_generic_param(scope_id, param); } } - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)>; + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)>; + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam>; + + fn mark_generic_const(&mut self, tpl_id: GenericTplId) -> Option { + let param = self.generic_param_mut(tpl_id)?; + param.is_const = true; + Some(param.clone()) + } } #[derive(Debug, Clone)] @@ -63,36 +71,38 @@ impl GenericIndex for FileGenericIndex { scope_id } - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option { if let Some(scope) = self.scopes.get_mut(scope_id.id) { - scope.insert_param(param); - } - } - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { - for param in params { - self.append_generic_param(scope_id, param); + return Some(scope.insert_param(param)); } + None } /// Find generic parameter by position and name. - /// return (GenericTplId, constraint, default) - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)> { + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { for scope in self.scopes.iter().rev() { if !scope.contains(position) { continue; } if let Some((id, param)) = scope.params.get(name) { - return Some(( - *id, - param.type_constraint.clone(), - param.default_type.clone(), - )); + return Some((*id, param.clone())); + } + } + + None + } + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { + for scope in self.scopes.iter_mut().rev() { + for (id, param) in scope.params.values_mut() { + if *id == tpl_id { + return Some(param); + } } } @@ -131,10 +141,11 @@ impl FileGenericScope { self.next_tpl_id.is_func() } - fn insert_param(&mut self, param: GenericParam) { + fn insert_param(&mut self, param: GenericParam) -> GenericTplId { let tpl_id = self.next_tpl_id; self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); self.params.insert(param.name.to_string(), (tpl_id, param)); + tpl_id } fn contains(&self, position: TextSize) -> bool { @@ -175,18 +186,19 @@ impl ConditionalInferIndex { let tpl_id = GenericTplId::ConditionalInfer(self.next_infer_id); self.next_infer_id += 1; + let param = GenericParam::new(SmolStr::new(name), None, None, false, None); let tpl = Arc::new(GenericTpl::new( tpl_id, - SmolStr::new(name).into(), - None, - None, + param.name.clone(), + param.constraint.clone(), + param.default.clone(), + param.is_const, + param.attributes.clone(), )); let scope = &mut self.scopes[scope_idx]; scope.bindings.insert(name.to_string(), tpl.clone()); - scope - .params - .push(GenericParam::new(SmolStr::new(name), None, None, None)); + scope.params.push(param); Some(tpl) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index fe2ff27d1..51a30524f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaComment, LuaDocAttributeType, LuaDocBinaryType, LuaDocConditionalType, + LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocBinaryType, LuaDocConditionalType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, @@ -13,7 +13,7 @@ use smol_str::SmolStr; use crate::{ AsyncState, DiagnosticCode, FileId, GenericParam, GenericTpl, InFiled, LuaAliasCallKind, - LuaArrayLen, LuaArrayType, LuaAttributeType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, + LuaArrayLen, LuaArrayType, LuaMultiLineUnion, LuaSignatureId, LuaTupleStatus, LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, db_index::{ AnalyzeError, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, @@ -109,6 +109,70 @@ impl<'a> DocTypeAnalyzeContext<'a> { .add_type_reference(self.file_id, type_id, range); } } + + // TODO: 为`std.ConstTpl`实现的兼容性代码, 应在下一版本中移除 + fn mark_generic_const(&mut self, tpl: &GenericTpl) -> GenericTpl { + let tpl_id = tpl.get_tpl_id(); + let param = self + .generic_index + .mark_generic_const(tpl_id) + .unwrap_or_else(|| { + let mut param = tpl.get_param().clone(); + param.is_const = true; + param + }); + + if tpl_id.is_func() + && let Some(signature_id) = self.current_signature_id() + && let Some(signature) = self.db.get_signature_index_mut().get_mut(&signature_id) + { + if let Some(signature_param) = signature.generic_params.get_mut(tpl_id.get_idx()) { + signature_param.is_const = true; + } + + for overload in &mut signature.overloads { + let mut generic_params = overload.get_generic_params().to_vec(); + let mut changed = false; + for generic_param in &mut generic_params { + if generic_param.get_tpl_id() == tpl_id && !generic_param.is_const() { + *generic_param = generic_param.with_const(true); + changed = true; + } + } + + if changed { + *overload = Arc::new(LuaFunctionType::new( + overload.get_async_state(), + overload.is_colon_define(), + overload.is_variadic(), + overload.get_params().to_vec(), + overload.get_ret().clone(), + Some(generic_params), + )); + } + } + } + + GenericTpl::new( + tpl_id, + param.name, + param.constraint, + param.default, + true, + param.attributes, + ) + } + + fn current_signature_id(&self) -> Option { + let owner = self.comment.as_ref()?.get_owner()?; + let closure = match owner { + LuaAst::LuaFuncStat(func) => func.get_closure(), + LuaAst::LuaLocalFuncStat(local_func) => local_func.get_closure(), + owner => owner.descendants::().next(), + }?; + + Some(LuaSignatureId::from_closure(self.file_id, &closure)) + } } pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) -> LuaType { @@ -199,9 +263,6 @@ pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) -> LuaDocType::MultiLineUnion(multi_union) => { return infer_multi_line_union_type(analyzer, multi_union); } - LuaDocType::Attribute(attribute_type) => { - return infer_attribute_type(analyzer, attribute_type); - } LuaDocType::Conditional(cond_type) => { return infer_conditional_type(analyzer, cond_type); } @@ -256,14 +317,14 @@ fn infer_buildin_or_ref_type( return LuaType::TplRef(tpl); } - if let Some((tpl_id, constraint, default_type)) = - analyzer.generic_index.find_generic(position, name) - { + if let Some((tpl_id, param)) = analyzer.generic_index.find_generic(position, name) { return LuaType::TplRef(Arc::new(GenericTpl::new( tpl_id, - SmolStr::new(name).into(), - constraint, - default_type, + param.name, + param.constraint, + param.default, + param.is_const, + param.attributes, ))); } @@ -484,7 +545,8 @@ fn infer_special_generic_type( let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; let first_param = infer_type(analyzer, first_doc_param_type); if let LuaType::TplRef(tpl) = first_param { - return Some(LuaType::ConstTplRef(tpl)); + let const_tpl = analyzer.mark_generic_const(&tpl); + return Some(LuaType::TplRef(Arc::new(const_tpl))); } } "Language" => { @@ -628,9 +690,11 @@ fn infer_unary_type( } fn infer_func_type(analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncType) -> LuaType { - if let Some(generic_list) = func.get_generic_decl_list() { - register_inline_func_generics(analyzer, func, generic_list); - } + let generic_params = if let Some(generic_list) = func.get_generic_decl_list() { + register_inline_func_generics(analyzer, func, generic_list) + } else { + Vec::new() + }; let mut params_result = Vec::new(); let mut is_variadic = false; @@ -711,6 +775,7 @@ fn infer_func_type(analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncTy is_variadic, params_result, return_type, + Some(generic_params), ) .into(), ) @@ -720,10 +785,11 @@ fn register_inline_func_generics( analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncType, generic_list: LuaDocGenericDeclList, -) { +) -> Vec { let scope_id = analyzer .generic_index .add_generic_scope(vec![func.get_range()], true); + let mut generic_params = Vec::new(); for param in generic_list.get_generic_decl() { let Some(name_token) = param.get_name_token() else { continue; @@ -733,16 +799,28 @@ fn register_inline_func_generics( .get_constraint_type() .map(|ty| infer_type(analyzer, ty)); let default_type = param.get_default_type().map(|ty| infer_type(analyzer, ty)); - analyzer.generic_index.append_generic_param( - scope_id, - GenericParam::new( - SmolStr::new(name_token.get_name_text()), - constraint, - default_type, - None, - ), + let generic_param = GenericParam::new( + SmolStr::new(name_token.get_name_text()), + constraint, + default_type, + param.has_const_modifier(), + None, ); + if let Some(tpl_id) = analyzer + .generic_index + .append_generic_param(scope_id, generic_param.clone()) + { + generic_params.push(GenericTpl::new( + tpl_id, + generic_param.name, + generic_param.constraint, + generic_param.default, + generic_param.is_const, + generic_param.attributes, + )); + } } + generic_params } fn get_colon_define(analyzer: &mut DocTypeAnalyzeContext<'_>) -> Option { @@ -872,38 +950,6 @@ fn infer_multi_line_union_type( LuaType::MultiLineUnion(LuaMultiLineUnion::new(union_members).into()) } -fn infer_attribute_type( - analyzer: &mut DocTypeAnalyzeContext<'_>, - attribute_type: &LuaDocAttributeType, -) -> LuaType { - let mut params_result = Vec::new(); - for param in attribute_type.get_params() { - let name = if let Some(param) = param.get_name_token() { - param.get_name_text().to_string() - } else if param.is_dots() { - "...".to_string() - } else { - continue; - }; - - let nullable = param.is_nullable(); - - let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_type(analyzer, type_ref); - if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(analyzer.db, &typ, &LuaType::Nil); - } - Some(typ) - } else { - None - }; - - params_result.push((name, type_ref)); - } - - LuaType::DocAttribute(LuaAttributeType::new(params_result).into()) -} - fn infer_conditional_type( analyzer: &mut DocTypeAnalyzeContext<'_>, cond_type: &LuaDocConditionalType, @@ -963,7 +1009,13 @@ fn infer_mapped_type( let constraint = generic_decl .get_constraint_type() .map(|constraint| infer_type(analyzer, constraint)); - let param = GenericParam::new(SmolStr::new(name), constraint, None, None); + let param = GenericParam::new( + SmolStr::new(name), + constraint, + None, + generic_decl.has_const_modifier(), + None, + ); let scope_id = analyzer .generic_index @@ -972,7 +1024,7 @@ fn infer_mapped_type( .generic_index .append_generic_param(scope_id, param.clone()); let position = mapped_type.get_range().start(); - let (id, _, _) = analyzer.generic_index.find_generic(position, name)?; + let (id, _) = analyzer.generic_index.find_generic(position, name)?; let doc_type = mapped_type.get_value_type()?; let value_type = infer_type(analyzer, doc_type); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs index 5e4e53a7a..8dfea9613 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs @@ -5,10 +5,10 @@ use crate::{ use super::{ DocAnalyzer, - tags::{find_owner_closure_or_report, get_owner_id_or_report}, + tags::{find_owner_closure_or_report, get_owner_id, get_owner_id_or_report}, }; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTagAsync, LuaDocTagDeprecated, + LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAsync, LuaDocTagDeprecated, LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion, LuaDocTagVisibility, LuaExpr, }; @@ -105,15 +105,81 @@ pub fn analyze_deprecated(analyzer: &mut DocAnalyzer, tag: LuaDocTagDeprecated) let message = tag .get_description() .map(|desc| desc.get_description_text().to_string()); + + let mut type_owner_id = None; + if let Some(current_type_id) = &analyzer.current_type_id { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(current_type_id.clone())); + } else { + let file_id = analyzer.file_id; + let workspace_id = analyzer.workspace_id; + let tags = analyzer.comment.get_doc_tags(); + for tag in tags { + match tag { + LuaDocTag::Class(class) => { + if let Some(name_token) = class.get_name_token() { + let name = name_token.get_name_text().to_string(); + if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( + file_id, + &name, + Some(workspace_id), + ) { + if decl.is_class() { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); + break; + } + } + } + } + LuaDocTag::Alias(alias) => { + if let Some(name_token) = alias.get_name_token() { + let name = name_token.get_name_text().to_string(); + if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( + file_id, + &name, + Some(workspace_id), + ) { + if decl.is_alias() { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); + break; + } + } + } + } + _ => {} + } + } + } + + if let Some(type_owner_id) = type_owner_id { + add_deprecated(analyzer, type_owner_id, message.clone())?; + let mut compat_owner_id = None; + if let Some(owner) = get_owner_id(analyzer, None, true) { + if let owner @ (LuaSemanticDeclId::LuaDecl(_) | LuaSemanticDeclId::Member(_)) = owner { + compat_owner_id = Some(owner); + } + } + if let Some(compat_owner_id) = compat_owner_id { + add_deprecated(analyzer, compat_owner_id, message)?; + } + return Some(()); + } + let owner_id = get_owner_id_or_report(analyzer, &tag)?; + add_deprecated(analyzer, owner_id, message)?; + + Some(()) +} +fn add_deprecated( + analyzer: &mut DocAnalyzer, + owner_id: LuaSemanticDeclId, + message: Option, +) -> Option<()> { analyzer .type_context .db .get_property_index_mut() - .add_deprecated(analyzer.file_id, owner_id, message); - - Some(()) + .add_deprecated(analyzer.file_id, owner_id, message) } pub fn analyze_version(analyzer: &mut DocAnalyzer, version: LuaDocTagVersion) -> Option<()> { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs index 683dc1427..be8fad835 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs @@ -6,7 +6,7 @@ use crate::{ AnalyzeError, DiagnosticCode, LuaDeclId, compilation::analyzer::doc::{ attribute_tags::analyze_tag_attribute_use, property_tags::analyze_readonly, - type_def_tags::analyze_attribute, type_ref_tags::analyze_doc_tag_schema, + type_ref_tags::analyze_doc_tag_schema, }, db_index::{LuaMemberId, LuaSemanticDeclId, LuaSignatureId}, }; @@ -41,9 +41,6 @@ pub fn analyze_tag(analyzer: &mut DocAnalyzer, tag: LuaDocTag) -> Option<()> { LuaDocTag::Alias(alias) => { analyze_alias(analyzer, alias)?; } - LuaDocTag::Attribute(attribute) => { - analyze_attribute(analyzer, attribute)?; - } // ref LuaDocTag::Type(type_tag) => { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index d63ed8290..ed2fe123a 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -1,8 +1,8 @@ use emmylua_parser::{ LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCommentOwner, LuaDocDescription, - LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAlias, LuaDocTagAttribute, LuaDocTagClass, - LuaDocTagEnum, LuaDocTagGeneric, LuaFuncStat, LuaLocalName, LuaLocalStat, LuaNameExpr, - LuaSyntaxId, LuaSyntaxKind, LuaTokenKind, LuaVarExpr, + LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAlias, LuaDocTagClass, LuaDocTagEnum, + LuaDocTagGeneric, LuaFuncStat, LuaLocalName, LuaLocalStat, LuaNameExpr, LuaSyntaxId, + LuaSyntaxKind, LuaTokenKind, LuaVarExpr, }; use rowan::TextRange; use smol_str::SmolStr; @@ -10,15 +10,13 @@ use smol_str::SmolStr; use super::{ DocAnalyzer, infer_type::infer_type, preprocess_description, tags::find_owner_closure, }; -use crate::GenericParam; use crate::compilation::analyzer::doc::tags::report_orphan_tag; use crate::{ DbIndex, LuaTypeCache, LuaTypeDeclId, compilation::analyzer::common::bind_type, - db_index::{ - LuaDeclId, LuaGenericParamInfo, LuaMemberId, LuaSemanticDeclId, LuaSignatureId, LuaType, - }, + db_index::{LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaSignatureId, LuaType}, }; +use crate::{GenericParam, LuaFunctionType}; use std::{collections::HashSet, sync::Arc, vec}; pub fn analyze_class(analyzer: &mut DocAnalyzer, tag: LuaDocTagClass) -> Option<()> { @@ -206,34 +204,6 @@ fn alias_chain_ref(typ: &LuaType) -> Option { } } -/// 分析属性定义 -pub fn analyze_attribute(analyzer: &mut DocAnalyzer, tag: LuaDocTagAttribute) -> Option<()> { - let file_id = analyzer.file_id; - let workspace_id = analyzer.workspace_id; - let name = tag.get_name_token()?.get_name_text().to_string(); - - let decl_id = { - let decl = analyzer.get_db().get_type_index().find_type_decl( - file_id, - &name, - Some(workspace_id), - )?; - if !decl.is_attribute() { - return None; - } - decl.get_id() - }; - let attribute_type = infer_type(&mut analyzer.type_context, tag.get_type()?); - let attribute_decl = analyzer - .get_db() - .get_type_index_mut() - .get_type_decl_mut(&decl_id)?; - attribute_decl.add_attribute_type(attribute_type); - - add_description_for_type_decl(analyzer, &decl_id, tag.get_descriptions()); - Some(()) -} - fn get_type_generic_params( analyzer: &mut DocAnalyzer, type_decl_id: &LuaTypeDeclId, @@ -384,8 +354,7 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - let Some(name_token) = param.get_name_token() else { continue; }; - let name_text = name_token.get_name_text().to_string(); - let smol_name = SmolStr::new(name_text.as_str()); + let smol_name = SmolStr::new(name_token.get_name_text()); let type_ref = param .get_constraint_type() @@ -394,22 +363,18 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - .get_default_type() .map(|type_ref| infer_type(&mut analyzer.type_context, type_ref)); - analyzer.type_context.generic_index.append_generic_param( - scope_id, - GenericParam::new( - smol_name.clone(), - type_ref.clone(), - default_type.clone(), - None, - ), - ); - - param_info.push(Arc::new(LuaGenericParamInfo::new( - name_text, + let generic_param = GenericParam::new( + smol_name, type_ref, default_type, + param.has_const_modifier(), None, - ))); + ); + analyzer + .type_context + .generic_index + .append_generic_param(scope_id, generic_param.clone()); + param_info.push(generic_param); } } @@ -420,6 +385,26 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - .get_signature_index_mut() .get_or_create(signature_id); signature.generic_params = param_info; + let signature_generic_params = signature.get_function_generic_params(); + for overload in &mut signature.overloads { + let mut generic_params = signature_generic_params.clone(); + for generic_param in overload.get_generic_params() { + if !generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + generic_params.push(generic_param.clone()); + } + } + *overload = Arc::new(LuaFunctionType::new( + overload.get_async_state(), + overload.is_colon_define(), + overload.is_variadic(), + overload.get_params().to_vec(), + overload.get_ret().clone(), + Some(generic_params), + )); + } Some(()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index ed8601db4..8d9048e5c 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -59,13 +59,14 @@ fn normalize_generic_params(db: &DbIndex, params: &[GenericParam]) -> Vec Option { + let scope = self.scopes.get_mut(scope_id.id)?; let tpl_id = scope.next_tpl_id; scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); scope.params.push((tpl_id, param)); + Some(tpl_id) } - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)> { + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { for scope in self.scopes.iter().rev() { if !scope.contains(position) { continue; @@ -181,11 +187,17 @@ impl GenericIndex for HeaderGenericIndex { .rev() .find(|(_, param)| param.name == name) { - return Some(( - *tpl_id, - param.type_constraint.clone(), - param.default_type.clone(), - )); + return Some((*tpl_id, param.clone())); + } + } + + None + } + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { + for scope in self.scopes.iter_mut().rev() { + if let Some((_, param)) = scope.params.iter_mut().find(|(id, _)| *id == tpl_id) { + return Some(param); } } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index de0cc3320..3a1c8a821 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -4,6 +4,7 @@ use emmylua_parser::{ LuaDocTagReturnCast, LuaDocTagReturnOverload, LuaDocTagSchema, LuaDocTagSee, LuaDocTagType, LuaExpr, LuaLocalName, LuaTokenKind, LuaVarExpr, }; +use std::sync::Arc; use super::{ DocAnalyzer, @@ -12,8 +13,8 @@ use super::{ tags::{find_owner_closure, get_owner_id_or_report}, }; use crate::{ - InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, OperatorFunction, - SignatureReturnStatus, TypeOps, + InFiled, JsonSchemaFile, LuaFunctionType, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, + OperatorFunction, SignatureReturnStatus, TypeOps, compilation::analyzer::common::bind_type, db_index::{ LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaMemberId, @@ -375,6 +376,23 @@ pub fn analyze_overload(analyzer: &mut DocAnalyzer, tag: LuaDocTagOverload) -> O .db .get_signature_index_mut() .get_or_create(id); + let mut generic_params = signature.get_function_generic_params(); + for generic_param in func.get_generic_params() { + if !generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + generic_params.push(generic_param.clone()); + } + } + let func = Arc::new(LuaFunctionType::new( + func.get_async_state(), + func.is_colon_define(), + func.is_variadic(), + func.get_params().to_vec(), + func.get_ret().clone(), + Some(generic_params), + )); signature.overloads.push(func); } } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index d6605830b..8ece392f4 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -3,7 +3,7 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function, + compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_type_generic, tpl_pattern_match_args, }; @@ -145,6 +145,12 @@ pub fn infer_for_range_iter_expr_func( return Ok(doc_function.get_variadic_ret()); }; let mut substitutor = TypeSubstitutor::new(); + let generic_tpls = doc_function + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .collect(); + substitutor.add_need_infer_tpls(generic_tpls); let mut context = TplContext { db, cache, @@ -159,8 +165,9 @@ pub fn infer_for_range_iter_expr_func( tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; + let doc_function_type = LuaType::DocFunction(doc_function.clone()); let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_doc_function(db, &doc_function, &substitutor) + instantiate_type_generic(db, &doc_function_type, &substitutor) { f } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs index 120588356..c37561eb1 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs @@ -409,6 +409,7 @@ fn resolve_closure_member_type( signature.is_vararg, final_params, final_ret, + Some(signature.get_function_generic_params()), ), self_type, ) diff --git a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs index a0828719d..da8503db9 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs @@ -16,7 +16,9 @@ mod test { ( "meta.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("__init")] @@ -45,6 +47,24 @@ mod test { ); } + #[test] + fn test_attribute_overload_uses_arg_type_for_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AttributeParamTypeMismatch, + r#" + ---@class Attribute + ---@class custom_attribute: Attribute + ---@overload fun(value: string) + ---@overload fun(value: integer) + + ---@[custom_attribute(1)] + local value + "#, + )); + } + #[test] fn test_delayed_definition() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -96,7 +116,9 @@ mod test { ( "3_meta.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@class class ---@field is_class true @@ -121,7 +143,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("init")] @@ -154,7 +178,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("init")] @@ -189,7 +215,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("__init")] diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3a1b462b9..b85997659 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -764,10 +764,7 @@ mod test { .expect("Box generic params"); assert_eq!(box_params.len(), 1); assert_eq!(box_params[0].name.as_str(), "T"); - let box_default = box_params[0] - .default_type - .clone() - .expect("Box default type"); + let box_default = box_params[0].default.clone().expect("Box default type"); assert_eq!(ws.humanize_type(box_default), "string"); let optional_params = ws @@ -780,7 +777,7 @@ mod test { assert_eq!(optional_params.len(), 1); assert_eq!(optional_params[0].name.as_str(), "T"); let optional_default = optional_params[0] - .default_type + .default .clone() .expect("Optional default type"); assert_eq!(ws.humanize_type(optional_default), "number"); @@ -810,12 +807,105 @@ mod test { assert_eq!(signature.generic_params.len(), 1); assert_eq!(signature.generic_params[0].name, "T"); let default_type = signature.generic_params[0] - .default_type + .default .clone() .expect("signature default type"); assert_eq!(ws.humanize_type(default_type), "string"); } + #[test] + fn test_generic_const_metadata_storage() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@class Box + + ---@generic const R, S + ---@return R + local function id() + end + + ---@alias Mapper fun(value: A): B + "#, + ); + + let db = ws.analysis.compilation.get_db(); + let box_params = db + .get_type_index() + .get_generic_params(&LuaTypeDeclId::global("Box")) + .expect("Box generic params"); + assert_eq!(box_params.len(), 2); + assert_eq!(box_params[0].name.as_str(), "T"); + assert!(box_params[0].is_const); + assert_eq!(box_params[1].name.as_str(), "U"); + assert!(!box_params[1].is_const); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + let signature = db + .get_signature_index() + .get(&signature_id) + .expect("signature"); + assert_eq!(signature.generic_params.len(), 2); + assert_eq!(signature.generic_params[0].name.as_str(), "R"); + assert!(signature.generic_params[0].is_const); + assert_eq!(signature.generic_params[1].name.as_str(), "S"); + assert!(!signature.generic_params[1].is_const); + + let function_generic_params = signature.get_function_generic_params(); + assert!(function_generic_params[0].is_const()); + assert!(!function_generic_params[1].is_const()); + + let mapper_decl = db + .get_type_index() + .get_type_decl(&LuaTypeDeclId::global("Mapper")) + .expect("Mapper alias"); + let mapper_origin = mapper_decl.get_alias_ref().expect("Mapper alias origin"); + let LuaType::DocFunction(mapper_func) = mapper_origin else { + panic!("expected Mapper alias to be a function type"); + }; + let mapper_generic_params = mapper_func.get_generic_params(); + assert_eq!(mapper_generic_params.len(), 2); + assert_eq!(mapper_generic_params[0].get_name(), "A"); + assert!(mapper_generic_params[0].is_const()); + assert_eq!(mapper_generic_params[1].get_name(), "B"); + assert!(!mapper_generic_params[1].is_const()); + } + + #[test] + fn test_legacy_const_tpl_marks_generic_param_metadata() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function id(value) + end + + result = id(1) + "#, + ); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + { + let signature = ws + .analysis + .compilation + .get_db() + .get_signature_index() + .get(&signature_id) + .expect("signature"); + assert_eq!(signature.generic_params.len(), 1); + assert!(signature.generic_params[0].is_const); + } + + assert_eq!(ws.expr_ty("result"), LuaType::IntegerConst(1)); + } + #[test] fn test_bare_generic_type_uses_default() { let mut ws = VirtualWorkspace::new(); @@ -954,7 +1044,7 @@ mod test { .get_type_index() .get_generic_params(&LuaTypeDeclId::global("B")) .expect("B generic params"); - let default_type = b_params[0].default_type.clone().expect("B default type"); + let default_type = b_params[0].default.clone().expect("B default type"); assert_eq!(ws.humanize_type(default_type), "A"); } @@ -982,7 +1072,7 @@ mod test { .get_type_index() .get_generic_params(&LuaTypeDeclId::global("B")) .expect("B generic params"); - let default_type = b_params[0].default_type.clone().expect("B default type"); + let default_type = b_params[0].default.clone().expect("B default type"); assert_eq!(ws.humanize_type(default_type), "A"); } @@ -1223,8 +1313,6 @@ mod test { r#" ---@alias std.RawGet unknown - ---@alias std.ConstTpl unknown - ---@generic T, K extends keyof T ---@param object T ---@param key K diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 00307054f..b8d53f0c5 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -1,8 +1,11 @@ #[cfg(test)] mod test { + use emmylua_parser::{LuaAstNode, LuaTableField}; use smol_str::SmolStr; - use crate::{LuaType, LuaUnionType, VirtualWorkspace}; + use crate::{ + LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaUnionType, VirtualWorkspace, + }; #[test] fn test_issue_318() { @@ -582,4 +585,64 @@ mod test { assert_eq!(ws.expr_ty("result"), ws.ty("integer?")); } + + #[test] + fn test_member_origin_owner_switches_cache_for_cross_file_table_field() { + let mut ws = VirtualWorkspace::new(); + let defs_file = ws.def_file( + "defs.lua", + r#" + ---@class CrossFileOwner + ---@field fn fun(): string + + local function fallback() + end + + ---@type CrossFileOwner + local value = { + fn = fallback, + } + "#, + ); + let main_file = ws.def_file("main.lua", "local main = 1"); + + let root = ws + .analysis + .compilation + .get_db() + .get_vfs() + .get_syntax_tree(&defs_file) + .expect("defs tree must exist") + .get_chunk_node(); + let table_field = root + .descendants::() + .next() + .expect("table field must exist"); + let member_id = LuaMemberId::new(table_field.get_syntax_id(), defs_file); + + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(main_file) + .expect("main model must exist"); + let origin = semantic_model + .get_member_origin_owner(member_id) + .expect("origin owner must resolve"); + let LuaSemanticDeclId::Member(origin_member_id) = origin else { + panic!("expected member origin, got {origin:?}"); + }; + let origin_member = ws + .analysis + .compilation + .get_db() + .get_member_index() + .get_member(&origin_member_id) + .expect("origin member must exist"); + + assert_eq!( + origin_member.get_key(), + &LuaMemberKey::Name(SmolStr::new("fn")) + ); + assert!(origin_member.is_field()); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs index de64a4739..26538fbe9 100644 --- a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs +++ b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs @@ -139,6 +139,7 @@ impl LuaOperator { ("arg0".to_string(), Some(param.clone())), ], ret.clone(), + None, ) .into(), OperatorFunction::UnOp { ret } => LuaFunctionType::new( @@ -147,6 +148,7 @@ impl LuaOperator { false, vec![("self".to_string(), Some(LuaType::SelfInfer))], ret.clone(), + None, ) .into(), OperatorFunction::Call { params, ret } => { @@ -165,8 +167,15 @@ impl LuaOperator { }) .collect(); - LuaFunctionType::new(AsyncState::None, false, is_variadic, params, ret.clone()) - .into() + LuaFunctionType::new( + AsyncState::None, + false, + is_variadic, + params, + ret.clone(), + None, + ) + .into() } OperatorFunction::Overload(func) => { LuaType::DocFunction(func.to_call_operator_func_type()) @@ -183,6 +192,7 @@ impl LuaOperator { signature.is_vararg, signature.get_type_params(), get_constructor_return_type(signature, return_mode), + Some(signature.get_function_generic_params()), ) .into(), None => LuaType::Signature(*id), diff --git a/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs b/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs index 72abee1b3..d45a26bff 100644 --- a/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs +++ b/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs @@ -1,4 +1,86 @@ -use crate::{LuaType, LuaTypeDeclId}; +use std::sync::Arc; + +use crate::{ + DbIndex, LuaFunctionType, LuaOperatorMetaMethod, LuaType, LuaTypeDeclId, callable_accepts_args, + is_sub_type_of, +}; + +const ATTRIBUTE_BASE_TYPE_NAME: &str = "Attribute"; + +pub fn is_attribute_class(db: &DbIndex, type_id: &LuaTypeDeclId) -> bool { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return false; + }; + if !type_decl.is_class() { + return false; + } + + let attribute_type_id = LuaTypeDeclId::global(ATTRIBUTE_BASE_TYPE_NAME); + is_sub_type_of(db, type_id, &attribute_type_id) +} + +pub fn get_attribute_constructor_params( + db: &DbIndex, + type_id: &LuaTypeDeclId, + arg_types: &[LuaType], +) -> Vec<(String, Option)> { + select_attribute_constructor_func(db, type_id, arg_types) + .map(|func| func.get_params().to_vec()) + .unwrap_or_default() +} + +fn select_attribute_constructor_func( + db: &DbIndex, + type_id: &LuaTypeDeclId, + arg_types: &[LuaType], +) -> Option> { + let arg_count = arg_types.len(); + let operator_ids = db + .get_operator_index() + .get_operators(&type_id.clone().into(), LuaOperatorMetaMethod::Call)?; + + let mut fallback = None; + let mut count_fallback = None; + let only_candidate = operator_ids.len() == 1; + for operator_id in operator_ids { + let Some(operator) = db.get_operator_index().get_operator(operator_id) else { + continue; + }; + let LuaType::DocFunction(func) = operator.get_operator_func(db) else { + continue; + }; + + let params = func.get_params(); + fallback.get_or_insert_with(|| Arc::clone(&func)); + if !attribute_params_accept_arg_count(¶ms, arg_count) { + continue; + } + + count_fallback.get_or_insert_with(|| Arc::clone(&func)); + if only_candidate || callable_accepts_args(db, &func, arg_types, false, Some(arg_count)) { + return Some(func); + } + } + + count_fallback.or(fallback) +} + +fn attribute_params_accept_arg_count( + def_params: &[(String, Option)], + arg_count: usize, +) -> bool { + let required_count = def_params + .iter() + .take_while(|(name, typ)| name != "..." && !typ.as_ref().is_some_and(LuaType::is_variadic)) + .filter(|(_, typ)| !typ.as_ref().is_some_and(LuaType::is_optional)) + .count(); + + let allows_more = def_params + .last() + .is_some_and(|(name, typ)| name == "..." || typ.as_ref().is_some_and(LuaType::is_variadic)); + + arg_count >= required_count && (allows_more || arg_count <= def_params.len()) +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LuaBuiltinAttributeKind { diff --git a/crates/emmylua_code_analysis/src/db_index/property/mod.rs b/crates/emmylua_code_analysis/src/db_index/property/mod.rs index 67fc74d9a..ea61545c0 100644 --- a/crates/emmylua_code_analysis/src/db_index/property/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/property/mod.rs @@ -10,7 +10,7 @@ pub use builtin_attribute::{ LuaAttributeCollectionExt, LuaAttributeUse, LuaBuiltinAttributeKind, LuaConstructorAttribute, LuaConstructorReturnMode, LuaDeprecatedAttribute, LuaFieldAccessorAttribute, LuaFieldAccessorConvention, LuaIndexAliasAttribute, LuaLspOptimizationAttribute, - LuaLspOptimizationCode, + LuaLspOptimizationCode, get_attribute_constructor_params, is_attribute_class, }; pub use decl_feature::{DeclFeatureFlag, PropertyDeclFeature}; use emmylua_parser::{LuaAstNode, LuaDocTagField, LuaDocType, LuaVersionCondition, VisibilityKind}; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs index 46983b02c..25e03f603 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs @@ -7,8 +7,8 @@ use hashbrown::{HashMap, HashSet}; pub use async_state::AsyncState; pub use signature::{ - LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaGenericParamInfo, LuaNoDiscard, - LuaSignature, LuaSignatureId, SignatureReturnStatus, + LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaNoDiscard, LuaSignature, + LuaSignatureId, SignatureReturnStatus, }; use crate::FileId; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index d2a6b818d..0271c3213 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -9,7 +9,7 @@ use rowan::TextSize; use super::return_rows; use crate::db_index::signature::async_state::AsyncState; use crate::{ - FileId, + FileId, GenericParam, GenericTpl, GenericTplId, db_index::{LuaFunctionType, LuaType}, }; use crate::{ @@ -19,7 +19,7 @@ use crate::{ #[derive(Debug)] pub struct LuaSignature { - pub generic_params: Vec>, + pub generic_params: Vec, pub overloads: Vec>, pub param_docs: HashMap, pub params: Vec, @@ -172,6 +172,7 @@ impl LuaSignature { is_vararg, params, return_type, + Some(self.get_function_generic_params()), ); Arc::new(func_type) } @@ -183,10 +184,33 @@ impl LuaSignature { } let return_type = self.get_return_type(); - let func_type = - LuaFunctionType::new(self.async_state, false, self.is_vararg, params, return_type); + let func_type = LuaFunctionType::new( + self.async_state, + false, + self.is_vararg, + params, + return_type, + Some(self.get_function_generic_params()), + ); Arc::new(func_type) } + + pub fn get_function_generic_params(&self) -> Vec { + self.generic_params + .iter() + .enumerate() + .map(|(idx, param)| { + GenericTpl::new( + GenericTplId::Func(idx as u32), + param.name.clone(), + param.constraint.clone(), + param.default.clone(), + param.is_const, + param.attributes.clone(), + ) + }) + .collect() + } } #[derive(Debug)] @@ -306,27 +330,3 @@ pub enum SignatureReturnStatus { DocResolve, InferResolve, } - -#[derive(Debug, Clone)] -pub struct LuaGenericParamInfo { - pub name: String, - pub constraint: Option, - pub default_type: Option, - pub attributes: Option>, -} - -impl LuaGenericParamInfo { - pub fn new( - name: String, - constraint: Option, - default_type: Option, - attributes: Option>, - ) -> Self { - Self { - name, - constraint, - default_type, - attributes, - } - } -} diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index 1a66b2031..1a0895af8 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -5,22 +5,25 @@ use crate::{LuaAttributeUse, LuaType}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericParam { pub name: SmolStr, - pub type_constraint: Option, - pub default_type: Option, + pub constraint: Option, + pub default: Option, + pub is_const: bool, pub attributes: Option>, } impl GenericParam { pub fn new( name: SmolStr, - type_constraint: Option, - default_type: Option, + constraint: Option, + default: Option, + is_const: bool, attributes: Option>, ) -> Self { Self { name, - type_constraint, - default_type, + constraint, + default, + is_const, attributes, } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index f26732263..1103d689c 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -215,7 +215,6 @@ impl<'a> TypeHumanizer<'a> { self.level = saved; w.write_char('>') } - LuaType::ConstTplRef(const_tpl) => w.write_str(const_tpl.get_name()), LuaType::Language(s) => w.write_str(s), LuaType::Conditional(c) => self.write_conditional_type(c, w), LuaType::Never => w.write_str("never"), diff --git a/crates/emmylua_code_analysis/src/db_index/type/mod.rs b/crates/emmylua_code_analysis/src/db_index/type/mod.rs index 2d9f4c307..8e0a78245 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/mod.rs @@ -618,7 +618,7 @@ pub fn first_param_may_not_self(typ: &LuaType) -> bool { if typ.is_table() || matches!( typ, - LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::Any + LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::Any | LuaType::Unknown ) { return true; diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs index ba8ef4a05..a7b5ac0c5 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs @@ -18,7 +18,6 @@ pub enum LuaDeclTypeKind { Class, Enum, Alias, - Attribute, } flags! { @@ -64,7 +63,6 @@ impl LuaTypeDecl { LuaDeclTypeKind::Enum => LuaTypeExtra::Enum { base: None }, LuaDeclTypeKind::Class => LuaTypeExtra::Class, LuaDeclTypeKind::Alias => LuaTypeExtra::Alias { origin: None }, - LuaDeclTypeKind::Attribute => LuaTypeExtra::Attribute { typ: None }, }, } } @@ -93,10 +91,6 @@ impl LuaTypeDecl { matches!(self.extra, LuaTypeExtra::Alias { .. }) } - pub fn is_attribute(&self) -> bool { - matches!(self.extra, LuaTypeExtra::Attribute { .. }) - } - pub fn is_exact(&self) -> bool { self.locations .iter() @@ -178,20 +172,6 @@ impl LuaTypeDecl { } } - pub fn add_attribute_type(&mut self, attribute_type: LuaType) { - if let LuaTypeExtra::Attribute { typ } = &mut self.extra { - *typ = Some(attribute_type); - } - } - - pub fn get_attribute_type(&self) -> Option<&LuaType> { - if let LuaTypeExtra::Attribute { typ: Some(typ) } = &self.extra { - Some(typ) - } else { - None - } - } - pub fn merge_decl(&mut self, other: LuaTypeDecl) { self.locations.extend(other.locations); } @@ -398,5 +378,4 @@ pub enum LuaTypeExtra { Enum { base: Option }, Class, Alias { origin: Option }, - Attribute { typ: Option }, } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs index 1e7750c37..d9aa6a8fd 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs @@ -5,7 +5,9 @@ use smol_str::SmolStr; use std::{ops::Deref, sync::Arc}; use crate::db_index::LuaMemberKey; -use crate::{AsyncState, DbIndex, InFiled, SemanticModel, first_param_may_not_self}; +use crate::{ + AsyncState, DbIndex, InFiled, LuaAttributeUse, SemanticModel, first_param_may_not_self, +}; use super::super::basic_union::{BasicTypeKind, BasicTypeUnion}; use super::super::generic_param::GenericParam; @@ -107,6 +109,7 @@ pub struct LuaFunctionType { async_state: AsyncState, is_colon_define: bool, is_variadic: bool, + generic_params: Option>, params: Vec<(String, Option)>, ret: LuaType, } @@ -118,11 +121,14 @@ impl LuaFunctionType { is_variadic: bool, params: Vec<(String, Option)>, ret: LuaType, + generic_params: Option>, ) -> Self { + let generic_params = generic_params.filter(|params| !params.is_empty()); Self { async_state, is_colon_define, is_variadic, + generic_params, params, ret, } @@ -140,6 +146,10 @@ impl LuaFunctionType { &self.params } + pub fn get_generic_params(&self) -> &[GenericTpl] { + self.generic_params.as_deref().unwrap_or(&[]) + } + pub fn get_ret(&self) -> &LuaType { &self.ret } @@ -213,6 +223,7 @@ impl LuaFunctionType { self.is_variadic, params, self.ret.clone(), + self.generic_params.clone(), )) } } @@ -745,26 +756,24 @@ impl GenericTplId { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct GenericTpl { tpl_id: GenericTplId, - name: ArcIntern, - constraint: Option, - default_type: Option, + param: GenericParam, } impl GenericTpl { pub fn new( tpl_id: GenericTplId, - name: ArcIntern, + name: SmolStr, constraint: Option, default_type: Option, + is_const: bool, + attributes: Option>, ) -> Self { Self { tpl_id, - name, - constraint, - default_type, + param: GenericParam::new(name, constraint, default_type, is_const, attributes), } } @@ -772,16 +781,33 @@ impl GenericTpl { self.tpl_id } + pub fn get_param(&self) -> &GenericParam { + &self.param + } + pub fn get_name(&self) -> &str { - &self.name + self.param.name.as_str() + } + + pub fn is_const(&self) -> bool { + self.param.is_const + } + + pub fn with_const(&self, is_const: bool) -> Self { + let mut param = self.param.clone(); + param.is_const = is_const; + Self { + tpl_id: self.tpl_id, + param, + } } pub fn get_constraint(&self) -> Option<&LuaType> { - self.constraint.as_ref() + self.param.constraint.as_ref() } pub fn get_default_type(&self) -> Option<&LuaType> { - self.default_type.as_ref() + self.param.default.as_ref() } } @@ -897,21 +923,6 @@ impl LuaArrayType { } } -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct LuaAttributeType { - params: Vec<(String, Option)>, -} - -impl LuaAttributeType { - pub fn new(params: Vec<(String, Option)>) -> Self { - Self { params } - } - - pub fn get_params(&self) -> &[(String, Option)] { - &self.params - } -} - #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct LuaConditionalType { checked_type: LuaType, diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs index 1ae5acb1d..54611e6ac 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs @@ -8,9 +8,9 @@ use crate::{FileId, InFiled}; use super::super::type_decl::LuaTypeDeclId; use super::complex::{ - GenericTpl, LuaAliasCallType, LuaArrayType, LuaAttributeType, LuaConditionalType, - LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaMappedType, - LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleType, LuaUnionType, VariadicType, + GenericTpl, LuaAliasCallType, LuaArrayType, LuaConditionalType, LuaFunctionType, + LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, + LuaObjectType, LuaStringTplType, LuaTupleType, LuaUnionType, VariadicType, }; #[derive(Debug, Clone)] @@ -57,10 +57,8 @@ pub enum LuaType { Call(Arc), MultiLineUnion(Arc), TypeGuard(Arc), - ConstTplRef(Arc), Language(ArcIntern), ModuleRef(FileId), - DocAttribute(Arc), Conditional(Arc), Mapped(Arc), } @@ -110,10 +108,8 @@ impl PartialEq for LuaType { (LuaType::MultiLineUnion(a), LuaType::MultiLineUnion(b)) => a == b, (LuaType::TypeGuard(a), LuaType::TypeGuard(b)) => a == b, (LuaType::Never, LuaType::Never) => true, - (LuaType::ConstTplRef(a), LuaType::ConstTplRef(b)) => a == b, (LuaType::Language(a), LuaType::Language(b)) => a == b, (LuaType::ModuleRef(a), LuaType::ModuleRef(b)) => a == b, - (LuaType::DocAttribute(a), LuaType::DocAttribute(b)) => a == b, (LuaType::Conditional(a), LuaType::Conditional(b)) => a == b, (LuaType::Mapped(a), LuaType::Mapped(b)) => a == b, _ => false, @@ -168,12 +164,10 @@ impl Hash for LuaType { LuaType::MultiLineUnion(a) => (43, Arc::as_ptr(a)).hash(state), LuaType::TypeGuard(a) => (44, Arc::as_ptr(a)).hash(state), LuaType::Never => 45.hash(state), - LuaType::ConstTplRef(a) => (46, Arc::as_ptr(a)).hash(state), - LuaType::Language(a) => (47, a).hash(state), - LuaType::ModuleRef(a) => (48, a).hash(state), - LuaType::Conditional(a) => (49, Arc::as_ptr(a)).hash(state), - LuaType::Mapped(a) => (50, Arc::as_ptr(a)).hash(state), - LuaType::DocAttribute(a) => (51, a).hash(state), + LuaType::Language(a) => (46, a).hash(state), + LuaType::ModuleRef(a) => (47, a).hash(state), + LuaType::Conditional(a) => (48, Arc::as_ptr(a)).hash(state), + LuaType::Mapped(a) => (49, Arc::as_ptr(a)).hash(state), } } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs index 37cfc40c9..edcb11819 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs @@ -261,9 +261,10 @@ impl LuaType { match ty { LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer - | LuaType::Mapped(_) => return true, + | LuaType::Mapped(_) => { + return true; + } _ => ty.push_direct_children(&mut stack), } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs index ec63747de..f0ec046f6 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use internment::ArcIntern; + use smol_str::SmolStr; use std::mem::ManuallyDrop; @@ -25,8 +25,10 @@ mod tests { let mut ty = LuaType::TplRef( GenericTpl::new( GenericTplId::Type(0), - ArcIntern::new(SmolStr::new("T")), + SmolStr::new("T"), + None, None, + false, None, ) .into(), diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs index 9cb73bdcc..1d4c8f016 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs @@ -1,8 +1,8 @@ use super::super::type_visit_trait::TypeVisitTrait; use super::{ - LuaAliasCallType, LuaArrayType, LuaAttributeType, LuaConditionalType, LuaFunctionType, - LuaGenericType, LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, LuaObjectType, - LuaTupleType, LuaType, LuaUnionType, VariadicType, + LuaAliasCallType, LuaArrayType, LuaConditionalType, LuaFunctionType, LuaGenericType, + LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, LuaObjectType, LuaTupleType, LuaType, + LuaUnionType, VariadicType, }; pub trait LuaTypeNode { @@ -49,7 +49,6 @@ pub trait LuaTypeNode { ty, LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer | LuaType::Mapped(_) ) @@ -62,7 +61,6 @@ pub trait LuaTypeNode { ty, LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer | LuaType::Mapped(_) ) @@ -222,16 +220,6 @@ impl LuaTypeNode for LuaArrayType { } } -impl LuaTypeNode for LuaAttributeType { - fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { - for (_, ty) in self.get_params().iter().rev() { - if let Some(ty) = ty { - stack.push(ty); - } - } - } -} - impl LuaTypeNode for LuaConditionalType { fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { stack.push(self.get_false_type()); @@ -244,10 +232,10 @@ impl LuaTypeNode for LuaConditionalType { impl LuaTypeNode for LuaMappedType { fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { stack.push(&self.value); - if let Some(constraint) = self.param.1.type_constraint.as_ref() { + if let Some(constraint) = self.param.1.constraint.as_ref() { stack.push(constraint); } - if let Some(default_type) = self.param.1.default_type.as_ref() { + if let Some(default_type) = self.param.1.default.as_ref() { stack.push(default_type); } } @@ -280,7 +268,6 @@ impl_type_visit_trait!( super::LuaInstanceType, LuaMultiLineUnion, LuaArrayType, - LuaAttributeType, LuaConditionalType, LuaMappedType, ); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs index 03dfaea01..f3ab11a7e 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs @@ -1,8 +1,7 @@ -use std::collections::HashSet; - use crate::{ DiagnosticCode, DocTypeInferContext, LuaType, SemanticModel, TypeCheckFailReason, - TypeCheckResult, diagnostic::checker::humanize_lint_type, infer_doc_type, + TypeCheckResult, diagnostic::checker::humanize_lint_type, get_attribute_constructor_params, + infer_doc_type, is_attribute_class, }; use emmylua_parser::{ LuaAstNode, LuaDocAttributeUse, LuaDocTagAttributeUse, LuaDocType, LuaExpr, LuaLiteralExpr, @@ -42,28 +41,35 @@ fn check_attribute_use( let LuaType::Ref(type_id) = attribute_type else { return None; }; - let type_decl = semantic_model - .get_db() - .get_type_index() - .get_type_decl(&type_id)?; - if !type_decl.is_attribute() { + if !is_attribute_class(semantic_model.get_db(), &type_id) { return None; } - let LuaType::DocAttribute(attr_def) = type_decl.get_attribute_type()? else { - return None; - }; - - let def_params = attr_def.get_params(); let args = match attribute_use.get_arg_list() { Some(arg_list) => arg_list.get_args().collect::>(), None => vec![], }; + let call_arg_types = infer_attribute_arg_types(semantic_model, &args); + let def_params = + get_attribute_constructor_params(semantic_model.get_db(), &type_id, &call_arg_types); check_param_count(context, &def_params, &attribute_use, &args); - check_param(context, semantic_model, &def_params, args); + check_param(context, semantic_model, &def_params, &args, &call_arg_types); Some(()) } +fn infer_attribute_arg_types( + semantic_model: &SemanticModel, + args: &[LuaLiteralExpr], +) -> Vec { + args.iter() + .map(|arg| { + semantic_model + .infer_expr(LuaExpr::LiteralExpr(arg.clone())) + .unwrap_or(LuaType::Unknown) + }) + .collect() +} + /// 检查参数数量是否匹配 fn check_param_count( context: &mut DiagnosticContext, @@ -78,7 +84,7 @@ fn check_param_count( if def_param.0 == "..." { break; } - if def_param.1.as_ref().is_some_and(is_nullable) { + if def_param.1.as_ref().is_some_and(LuaType::is_optional) { continue; } context.add_diagnostic( @@ -128,30 +134,23 @@ fn check_param( context: &mut DiagnosticContext, semantic_model: &SemanticModel, def_params: &[(String, Option)], - args: Vec, + args: &[LuaLiteralExpr], + call_arg_types: &[LuaType], ) -> Option<()> { - let mut call_arg_types = Vec::new(); - for arg in &args { - let arg_type = semantic_model - .infer_expr(LuaExpr::LiteralExpr(arg.clone())) - .ok()?; - call_arg_types.push(arg_type); - } - for (idx, param) in def_params.iter().enumerate() { if param.0 == "..." { if call_arg_types.len() < idx { break; } - if let Some(variadic_type) = param.1.clone() { - for arg_type in call_arg_types[idx..].iter() { - let result = semantic_model.type_check_detail(&variadic_type, arg_type); + if let Some(variadic_type) = param.1.as_ref() { + for (arg_idx, arg_type) in call_arg_types[idx..].iter().enumerate() { + let result = semantic_model.type_check_detail(variadic_type, arg_type); if result.is_err() { add_type_check_diagnostic( context, semantic_model, - args.get(idx)?.get_range(), - &variadic_type, + args.get(idx + arg_idx)?.get_range(), + variadic_type, arg_type, result, ); @@ -160,15 +159,15 @@ fn check_param( } break; } - if let Some(param_type) = param.1.clone() { + if let Some(param_type) = param.1.as_ref() { let arg_type = call_arg_types.get(idx).unwrap_or(&LuaType::Any); - let result = semantic_model.type_check_detail(¶m_type, arg_type); + let result = semantic_model.type_check_detail(param_type, arg_type); if result.is_err() { add_type_check_diagnostic( context, semantic_model, args.get(idx)?.get_range(), - ¶m_type, + param_type, arg_type, result, ); @@ -212,25 +211,3 @@ fn add_type_check_diagnostic( } } } - -fn is_nullable(typ: &LuaType) -> bool { - let mut stack: Vec = Vec::new(); - stack.push(typ.clone()); - let mut visited = HashSet::new(); - while let Some(typ) = stack.pop() { - if visited.contains(&typ) { - continue; - } - visited.insert(typ.clone()); - match typ { - LuaType::Any | LuaType::Unknown | LuaType::Nil => return true, - LuaType::Union(u) => { - for t in u.into_vec() { - stack.push(t); - } - } - _ => {} - } - } - false -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs index 2da6fa2de..6aa335139 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs @@ -150,7 +150,7 @@ fn has_non_callable_member(db: &DbIndex, typ: &LuaType) -> bool { LuaType::Any | LuaType::Unknown | LuaType::SelfInfer | LuaType::Global | LuaType::Nil => { false } - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl + LuaType::TplRef(tpl) => tpl .get_constraint() .is_some_and(|constraint| has_non_callable_member(db, constraint)), LuaType::StrTplRef(str_tpl) => str_tpl diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs deleted file mode 100644 index 2e89c8f43..000000000 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs +++ /dev/null @@ -1,281 +0,0 @@ -use std::collections::HashSet; - -use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaGeneralToken, - LuaLiteralToken, -}; - -use crate::{DbIndex, DiagnosticCode, LuaSignatureId, LuaType, SemanticModel}; - -use super::{Checker, DiagnosticContext}; - -pub struct CheckParamCountChecker; - -impl Checker for CheckParamCountChecker { - const CODES: &[DiagnosticCode] = &[ - DiagnosticCode::MissingParameter, - DiagnosticCode::RedundantParameter, - ]; - - fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { - for node in semantic_model.get_root().descendants::() { - match node { - LuaAst::LuaCallExpr(call_expr) => { - check_call_expr(context, semantic_model, call_expr); - } - LuaAst::LuaClosureExpr(closure_expr) => { - check_closure_expr(context, semantic_model, &closure_expr); - } - _ => {} - } - } - } -} - -/// 处理左值已绑定类型但右值为匿名函数的情况 -fn check_closure_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - closure_expr: &LuaClosureExpr, -) -> Option<()> { - let current_signature = context - .db - .get_signature_index() - .get(&LuaSignatureId::from_closure( - semantic_model.get_file_id(), - closure_expr, - ))?; - - let source_typ = semantic_model.infer_bind_value_type(closure_expr.clone().into())?; - - let source_params_len = match &source_typ { - LuaType::DocFunction(func_type) => { - let params = func_type.get_params(); - get_params_len(params) - } - LuaType::Signature(signature_id) => { - let signature = context.db.get_signature_index().get(signature_id)?; - let params = signature.get_type_params(); - get_params_len(¶ms) - } - _ => return Some(()), - }?; - - // 只检查右值参数多于左值参数的情况, 右值参数少于左值参数的情况是能够接受的 - if source_params_len > current_signature.params.len() { - return Some(()); - } - let params = closure_expr - .get_params_list()? - .get_params() - .collect::>(); - - for param in params[source_params_len..].iter() { - context.add_diagnostic( - DiagnosticCode::RedundantParameter, - param.get_range(), - t!( - "expected %{num} parameters but found %{found_num}", - num = source_params_len, - found_num = current_signature.params.len(), - ) - .to_string(), - None, - ); - } - - Some(()) -} - -fn check_call_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - call_expr: LuaCallExpr, -) -> Option<()> { - let func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; - let mut fake_params = func.get_params().to_vec(); - let call_args = call_expr.get_args_list()?.get_args().collect::>(); - let mut call_args_count = call_args.len(); - let last_arg_is_dots = call_args.last().is_some_and(is_dots_expr); - // 根据冒号定义与冒号调用的情况来调整调用参数的数量 - let colon_call = call_expr.is_colon_call(); - let colon_define = func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) => {} - (false, true) => { - fake_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer))); - } - (true, false) => { - call_args_count += 1; - } - } - - // Check for missing parameters - if call_args_count < fake_params.len() { - // 调用参数包含 `...` - for arg in call_args.iter() { - if let LuaExpr::LiteralExpr(literal_expr) = arg - && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() - { - return Some(()); - } - } - // 对调用参数的最后一个参数进行特殊处理 - if let Some(last_arg) = call_args.last() - && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) - { - let len = match variadic.get_max_len() { - Some(len) => len, - None => { - return Some(()); - } - }; - call_args_count = call_args_count + len - 1; - if call_args_count >= fake_params.len() { - return Some(()); - } - } - - let mut miss_parameter_info = Vec::new(); - - for i in call_args_count..fake_params.len() { - let param_info = fake_params.get(i)?; - if param_info.0 == "..." { - break; - } - - let typ = param_info.1.clone(); - if let Some(typ) = typ - && !is_nullable(context.db, &typ) - { - miss_parameter_info.push(t!("missing parameter: %{name}", name = param_info.0,)); - } - } - - if !miss_parameter_info.is_empty() { - let right_paren = call_expr - .get_args_list()? - .tokens::() - .last()?; - context.add_diagnostic( - DiagnosticCode::MissingParameter, - right_paren.get_range(), - t!( - "expected %{num} parameters but found %{found_num}. %{infos}", - num = fake_params.len(), - found_num = call_args_count, - infos = miss_parameter_info.join(" \n ") - ) - .to_string(), - None, - ); - } - } - // Check for redundant parameters - else { - if func.is_variadic() { - return Some(()); - } - - let mut min_call_args_count = call_args_count; - if last_arg_is_dots { - min_call_args_count = min_call_args_count.saturating_sub(1); - } - - if min_call_args_count <= fake_params.len() { - return Some(()); - } - - // 参数定义中最后一个参数是 `...` - if fake_params.last().is_some_and(|(name, typ)| { - name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) - }) { - return Some(()); - } - - let mut adjusted_index = 0; - if colon_call != colon_define { - adjusted_index = if colon_define && !colon_call { -1 } else { 1 }; - } - - for (i, arg) in call_args.iter().enumerate() { - if last_arg_is_dots && i + 1 == call_args.len() { - continue; - } - - let param_index = i as isize + adjusted_index; - - if param_index < 0 || param_index < fake_params.len() as isize { - continue; - } - - context.add_diagnostic( - DiagnosticCode::RedundantParameter, - arg.get_range(), - t!( - "expected %{num} parameters but found %{found_num}", - num = fake_params.len(), - found_num = min_call_args_count, - ) - .to_string(), - None, - ); - } - } - - Some(()) -} - -fn is_dots_expr(expr: &LuaExpr) -> bool { - if let LuaExpr::LiteralExpr(literal_expr) = expr - && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() - { - return true; - } - false -} - -fn get_params_len(params: &[(String, Option)]) -> Option { - if let Some((name, typ)) = params.last() { - // 如果最后一个参数是可变参数, 则直接返回, 不需要检查 - if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { - return None; - } - } - Some(params.len()) -} - -fn is_nullable(db: &DbIndex, typ: &LuaType) -> bool { - let mut stack: Vec = Vec::new(); - stack.push(typ.clone()); - let mut visited = HashSet::new(); - while let Some(typ) = stack.pop() { - if visited.contains(&typ) { - continue; - } - visited.insert(typ.clone()); - match typ { - LuaType::Any | LuaType::Unknown | LuaType::Nil => return true, - LuaType::Ref(decl_id) => { - if let Some(decl) = db.get_type_index().get_type_decl(&decl_id) - && decl.is_alias() - && let Some(alias_origin) = decl.get_alias_ref() - { - stack.push(alias_origin.clone()); - } - } - LuaType::Union(u) => { - for t in u.into_vec() { - stack.push(t); - } - } - LuaType::MultiLineUnion(m) => { - for (t, _) in m.get_unions() { - stack.push(t.clone()); - } - } - _ => {} - } - } - false -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs index 96af17edd..e4b9f9268 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs @@ -1,4 +1,4 @@ -use emmylua_parser::{LuaAst, LuaAstNode, LuaIndexExpr, LuaNameExpr}; +use emmylua_parser::{LuaAst, LuaAstNode, LuaDocNameType, LuaIndexExpr, LuaNameExpr}; use crate::{ DiagnosticCode, LuaDeclId, LuaDeprecated, LuaMemberId, LuaSemanticDeclId, SemanticDeclLevel, @@ -22,6 +22,9 @@ impl Checker for DeprecatedChecker { LuaAst::LuaIndexExpr(index_expr) => { check_index_expr(context, semantic_model, index_expr); } + LuaAst::LuaDocNameType(name_type) => { + check_doc_name_type(context, semantic_model, name_type); + } _ => {} } } @@ -74,6 +77,32 @@ fn check_index_expr( Some(()) } +fn check_doc_name_type( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + name_type: LuaDocNameType, +) -> Option<()> { + let semantic_decl = semantic_model.find_decl( + rowan::NodeOrToken::Node(name_type.syntax().clone()), + SemanticDeclLevel::default(), + )?; + + let LuaSemanticDeclId::TypeDecl(_) = &semantic_decl else { + return Some(()); + }; + + if let Some(deprecated_message) = get_deprecated_message(semantic_model, &semantic_decl) { + context.add_diagnostic( + DiagnosticCode::Deprecated, + name_type.get_range(), + deprecated_message, + None, + ); + } + + Some(()) +} + fn check_deprecated( context: &mut DiagnosticContext, semantic_model: &SemanticModel, @@ -87,14 +116,11 @@ fn check_deprecated( let Some(property) = property else { return; }; - if let Some(deprecated) = property.deprecated() { - let deprecated_message = match deprecated { - LuaDeprecated::Deprecated => "deprecated".to_string(), - LuaDeprecated::DeprecatedWithMessage(message) => message.to_string(), - }; + if let Some(deprecated_message) = get_deprecated_message(semantic_model, semantic_decl) { context.add_diagnostic(DiagnosticCode::Deprecated, range, deprecated_message, None); } + // 检查特性 if let Some(attribute_uses) = property.attribute_uses() { for attribute_use in attribute_uses.iter() { @@ -105,3 +131,23 @@ fn check_deprecated( } } } + +fn get_deprecated_message( + semantic_model: &SemanticModel, + semantic_decl: &LuaSemanticDeclId, +) -> Option { + let property = semantic_model + .get_db() + .get_property_index() + .get_property(semantic_decl); + let property = property?; + if let Some(deprecated) = property.deprecated() { + let deprecated_message = match deprecated { + LuaDeprecated::Deprecated => "deprecated".to_string(), + LuaDeprecated::DeprecatedWithMessage(message) => message.to_string(), + }; + return Some(deprecated_message); + } + + None +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index cc5594eff..ce67c0cb6 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -103,7 +103,7 @@ fn check_doc_tag_class( .get_generic_params(&type_decl.get_id())?; let generic_param_types = generic_params .iter() - .map(|param| (param.type_constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -133,7 +133,7 @@ fn check_doc_tag_alias( .get_generic_params(&type_decl.get_id())?; let generic_param_types = generic_params .iter() - .map(|param| (param.type_constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -158,7 +158,7 @@ fn check_doc_tag_generic( let generic_param_types = signature .generic_params .iter() - .map(|param| (param.constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -467,7 +467,7 @@ fn check_variadic_default_satisfies_constraint( fn generic_tpl_id(ty: &LuaType) -> Option { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => Some(tpl.get_tpl_id()), + LuaType::TplRef(tpl) => Some(tpl.get_tpl_id()), LuaType::StrTplRef(str_tpl) => Some(str_tpl.get_tpl_id()), _ => None, } @@ -475,7 +475,7 @@ fn generic_tpl_id(ty: &LuaType) -> Option { fn generic_upper_bound(ty: &LuaType) -> Option<&LuaType> { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_constraint(), + LuaType::TplRef(tpl) => tpl.get_constraint(), LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint(), _ => None, } @@ -491,7 +491,7 @@ fn instantiate_decl_default_for_check(ty: &LuaType) -> LuaType { fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) -> LuaType { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + LuaType::TplRef(tpl) => { if use_generic_upper_bound && let Some(constraint) = tpl.get_constraint() { return instantiate_decl_default_for_check(constraint); } @@ -638,7 +638,7 @@ fn check_doc_tag_type( .take(explicit_args.len()) .enumerate() { - let extend_type = generic_params.get(i)?.type_constraint.clone()?; + let extend_type = generic_params.get(i)?.constraint.clone()?; let result = semantic_model.type_check_detail(&extend_type, param_type); if result.is_err() { add_type_check_diagnostic( @@ -702,7 +702,7 @@ fn check_param( extend_type, ); } - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { + LuaType::TplRef(tpl_ref) => { let extend_type = tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs index 7eae02020..54bdfaafc 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs @@ -7,7 +7,6 @@ mod call_non_callable; mod cast_type_mismatch; mod check_export; mod check_field; -mod check_param_count; mod check_return_count; mod circle_doc_class; mod code_style; @@ -24,7 +23,7 @@ mod incomplete_signature_doc; mod local_const_reassign; mod missing_fields; mod need_check_nil; -mod param_type_check; +mod param_check; mod readonly_check; mod redefined_local; mod require_module_visibility; @@ -88,7 +87,7 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); - run_check::(context, semantic_model); + run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); @@ -102,7 +101,6 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); - run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::( diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs new file mode 100644 index 000000000..8425d5272 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs @@ -0,0 +1,230 @@ +use std::{ + cell::{OnceCell, RefCell}, + collections::HashSet, + sync::Arc, +}; + +use emmylua_parser::{LuaCallExpr, LuaExpr, LuaLiteralToken}; +use rowan::TextRange; + +use crate::{ + DbIndex, LuaFunctionType, LuaType, SemanticModel, infer_call_generic, + semantic::{collect_callable_overload_groups, is_func_last_param_variadic}, +}; + +pub(super) struct CallFacts { + pub(super) call_expr: LuaCallExpr, + pub(super) arg_exprs: Vec, + funcs: Vec>, + arg_types_and_ranges: RefCell>>, + base_call_arg_count_range: OnceCell>, +} + +impl CallFacts { + pub(super) fn new(semantic_model: &SemanticModel, call_expr: LuaCallExpr) -> Option { + let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); + let funcs = collect_diagnostic_callables(semantic_model, &call_expr)?; + + Some(Self { + call_expr, + arg_exprs, + funcs, + arg_types_and_ranges: RefCell::new(None), + base_call_arg_count_range: OnceCell::new(), + }) + } + + pub(super) fn funcs(&self) -> &[Arc] { + &self.funcs + } + + pub(super) fn arg_types_and_ranges( + &self, + semantic_model: &SemanticModel, + ) -> (Vec, Vec) { + let mut cached = self.arg_types_and_ranges.borrow_mut(); + if cached.is_none() { + *cached = Some(semantic_model.infer_expr_list_types(&self.arg_exprs, None)); + } + + cached + .as_ref() + .cloned() + .unwrap_or_default() + .into_iter() + .unzip() + } + + // 计算当前调用表达式的实参列表能提供多少个实参槽位. + pub(super) fn call_arg_count_range( + &self, + semantic_model: &SemanticModel, + func: &LuaFunctionType, + ) -> Option { + let mut count = self + .base_call_arg_count_range + .get_or_init(|| get_base_call_arg_count_range(semantic_model, &self.arg_exprs)) + .as_ref() + .copied()?; + if self.call_expr.is_colon_call() && !func.is_colon_define() { + // 冒号调用普通函数时, `obj:foo(x)` 等价于 `obj.foo(obj, x)`. + // 这里要把 receiver 计入调用侧的实参槽位. + count.min += 1; + count.max = count.max.map(|max| max + 1); + } + + Some(count) + } +} + +// 收集所有可调用的候选. +fn collect_diagnostic_callables( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, +) -> Option>> { + let prefix_expr = call_expr.get_prefix_expr()?; + let prefix_type = semantic_model.infer_expr(prefix_expr).ok()?; + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(semantic_model.get_db(), &prefix_type, &mut overload_groups) + .ok()?; + let mut funcs = Vec::new(); + for func in overload_groups.into_iter().flatten() { + let func = if func.contain_tpl() { + infer_call_generic( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + func.as_ref(), + call_expr.clone(), + ) + .map(Arc::new) + .unwrap_or(func) + } else { + func + }; + funcs.push(func); + } + + (!funcs.is_empty()).then_some(funcs) +} + +// 比较调用侧能提供的实参数量, 和函数侧能接受的形参数量是否有交集. +pub(super) fn count_ranges_overlap(call_count: CountRange, param_count: CountRange) -> bool { + let enough_args = call_count.max.is_none_or(|max| max >= param_count.min); + let not_too_many_args = param_count.max.is_none_or(|max| call_count.min <= max); + + enough_args && not_too_many_args +} + +#[derive(Clone, Copy)] +pub(super) struct CountRange { + // 数量下界: 调用侧至少会提供多少, 或函数侧至少要求多少. + pub(super) min: usize, + // 数量上界: 调用侧最多会提供多少, 或函数侧最多接受多少; None 表示无上限. + pub(super) max: Option, +} + +fn get_base_call_arg_count_range( + semantic_model: &SemanticModel, + arg_exprs: &[LuaExpr], +) -> Option { + if arg_exprs.iter().any(is_dots_expr) { + // `...` 无法精确给出数量范围, 交给后续类型检查处理. + return None; + } + + let mut count = CountRange { + min: arg_exprs.len(), + max: Some(arg_exprs.len()), + }; + + if let Some(last_arg) = arg_exprs.last() + && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) + { + let base = arg_exprs.len().saturating_sub(1); + count.min = base + variadic.get_min_len().unwrap_or(0); + count.max = variadic.get_max_len().map(|len| base + len); + } + + Some(count) +} + +// 计算当前候选函数签名能接受多少个形参槽位. +pub(super) fn get_param_count_range( + db: &DbIndex, + func: &LuaFunctionType, + call_expr: &LuaCallExpr, +) -> CountRange { + let params = func.get_params(); + // 如果以点调用但函数是冒号定义, 则表示需要传入 self 参数. + let self_offset = usize::from(!call_expr.is_colon_call() && func.is_colon_define()); + + let mut min = self_offset; + // 最小数量取最后一个非 nullable 形参, 因为前面的可选参数可以省略. + for (idx, (name, typ)) in params.iter().enumerate() { + if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { + break; + } + + if typ.as_ref().is_some_and(|typ| !is_nullable(db, typ)) { + min = idx + self_offset + 1; + } + } + + let adjusted_len = params.len() + self_offset; + let max = if func.is_variadic() + || is_func_last_param_variadic(func) + || params + .last() + .is_some_and(|(_, typ)| typ.as_ref().is_some_and(|typ| typ.is_variadic())) + { + None + } else { + Some(adjusted_len) + }; + + CountRange { min, max } +} + +pub(super) fn is_dots_expr(expr: &LuaExpr) -> bool { + if let LuaExpr::LiteralExpr(literal_expr) = expr + && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() + { + return true; + } + false +} + +pub(super) fn is_nullable(db: &DbIndex, typ: &LuaType) -> bool { + let mut stack: Vec = Vec::new(); + stack.push(typ.clone()); + let mut visited = HashSet::new(); + while let Some(typ) = stack.pop() { + if visited.contains(&typ) { + continue; + } + visited.insert(typ.clone()); + match typ { + LuaType::Any | LuaType::Unknown | LuaType::Nil => return true, + LuaType::Ref(decl_id) => { + if let Some(decl) = db.get_type_index().get_type_decl(&decl_id) + && decl.is_alias() + && let Some(alias_origin) = decl.get_alias_ref() + { + stack.push(alias_origin.clone()); + } + } + LuaType::Union(u) => { + for t in u.into_vec() { + stack.push(t); + } + } + LuaType::MultiLineUnion(m) => { + for (t, _) in m.get_unions() { + stack.push(t.clone()); + } + } + _ => {} + } + } + false +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs new file mode 100644 index 000000000..429e91237 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs @@ -0,0 +1,58 @@ +mod call_facts; +mod param_count; +mod type_mismatch; + +use emmylua_parser::{LuaAst, LuaAstNode}; + +use crate::{DiagnosticCode, SemanticModel}; + +use super::{Checker, DiagnosticContext}; +use call_facts::CallFacts; + +pub struct ParamCheckChecker; + +impl Checker for ParamCheckChecker { + const CODES: &[DiagnosticCode] = &[ + DiagnosticCode::ParamTypeMismatch, + DiagnosticCode::AssignTypeMismatch, + DiagnosticCode::MissingParameter, + DiagnosticCode::RedundantParameter, + ]; + + fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { + let check_param_count = context + .is_checker_enable_by_code(&DiagnosticCode::MissingParameter) + || context.is_checker_enable_by_code(&DiagnosticCode::RedundantParameter); + let check_param_type = context + .is_checker_enable_by_code(&DiagnosticCode::ParamTypeMismatch) + || context.is_checker_enable_by_code(&DiagnosticCode::AssignTypeMismatch); + + let root = semantic_model.get_root().clone(); + for node in root.descendants::() { + match node { + LuaAst::LuaCallExpr(call_expr) if check_param_count || check_param_type => { + let Some(facts) = CallFacts::new(semantic_model, call_expr) else { + continue; + }; + + if check_param_count { + param_count::check_call_param_count(context, semantic_model, &facts); + } + + if check_param_type { + type_mismatch::check_param_types( + context, + semantic_model, + &facts, + check_param_count, + ); + } + } + LuaAst::LuaClosureExpr(closure_expr) if check_param_count => { + param_count::check_closure_param_count(context, semantic_model, &closure_expr); + } + _ => {} + } + } + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs new file mode 100644 index 000000000..ad4736a19 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs @@ -0,0 +1,342 @@ +use std::sync::Arc; + +use emmylua_parser::{ + LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaGeneralToken, +}; + +use crate::{DiagnosticCode, LuaFunctionType, LuaSignatureId, LuaType, SemanticModel}; + +use super::super::DiagnosticContext; +use super::call_facts::{CallFacts, count_ranges_overlap, get_param_count_range, is_nullable}; + +pub(super) fn check_call_param_count( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + facts: &CallFacts, +) { + let mut best_candidate = None; + for func in facts.funcs() { + let Some(call_count) = facts.call_arg_count_range(semantic_model, func) else { + return; + }; + let param_count = get_param_count_range(context.get_db(), func, &facts.call_expr); + + if count_ranges_overlap(call_count, param_count) { + return; + } + + if let Some(max_call_count) = call_count.max + && max_call_count < param_count.min + { + update_best_candidate( + &mut best_candidate, + CountDiagnosticCandidate::Missing { + mismatch: param_count.min - max_call_count, + expected_count: param_count.min, + found_count: max_call_count, + func, + }, + ); + continue; + } + + if let Some(max_param_count) = param_count.max + && call_count.min > max_param_count + { + update_best_candidate( + &mut best_candidate, + CountDiagnosticCandidate::Redundant { + mismatch: call_count.min - max_param_count, + expected_count: max_param_count, + found_count: call_count.min, + func, + }, + ); + } + } + + let Some(candidate) = best_candidate else { + return; + }; + + match candidate { + CountDiagnosticCandidate::Missing { + expected_count, + found_count, + func, + .. + } => emit_missing_parameter(context, &facts.call_expr, expected_count, found_count, func), + CountDiagnosticCandidate::Redundant { + expected_count, + found_count, + func, + .. + } => { + emit_redundant_parameter( + context, + &facts.call_expr, + &facts.arg_exprs, + expected_count, + found_count, + func, + ); + } + } +} + +enum CountDiagnosticCandidate<'a> { + Missing { + mismatch: usize, + expected_count: usize, + found_count: usize, + func: &'a Arc, + }, + Redundant { + mismatch: usize, + expected_count: usize, + found_count: usize, + func: &'a Arc, + }, +} + +fn update_best_candidate<'a>( + best_candidate: &mut Option>, + candidate: CountDiagnosticCandidate<'a>, +) { + if best_candidate + .as_ref() + .is_none_or(|current| candidate.is_better_than(current)) + { + *best_candidate = Some(candidate); + } +} + +impl CountDiagnosticCandidate<'_> { + fn is_better_than(&self, other: &Self) -> bool { + match self.mismatch().cmp(&other.mismatch()) { + std::cmp::Ordering::Less => true, + std::cmp::Ordering::Greater => false, + std::cmp::Ordering::Equal => self.is_better_tie_than(other), + } + } + + fn mismatch(&self) -> usize { + match self { + CountDiagnosticCandidate::Missing { mismatch, .. } + | CountDiagnosticCandidate::Redundant { mismatch, .. } => *mismatch, + } + } + + fn is_better_tie_than(&self, other: &Self) -> bool { + match (self, other) { + ( + CountDiagnosticCandidate::Missing { + expected_count: left, + .. + }, + CountDiagnosticCandidate::Missing { + expected_count: right, + .. + }, + ) => left < right, + ( + CountDiagnosticCandidate::Redundant { + expected_count: left, + .. + }, + CountDiagnosticCandidate::Redundant { + expected_count: right, + .. + }, + ) => left > right, + ( + CountDiagnosticCandidate::Missing { .. }, + CountDiagnosticCandidate::Redundant { .. }, + ) => true, + ( + CountDiagnosticCandidate::Redundant { .. }, + CountDiagnosticCandidate::Missing { .. }, + ) => false, + } + } +} + +fn emit_missing_parameter( + context: &mut DiagnosticContext, + call_expr: &LuaCallExpr, + expected_count: usize, + found_count: usize, + func: &Arc, +) { + let mut miss_parameter_info = Vec::new(); + + for param_index in found_count..expected_count { + add_missing_parameter_info( + context, + call_expr, + func, + param_index, + &mut miss_parameter_info, + ); + } + + if !miss_parameter_info.is_empty() { + let Some(args_list) = call_expr.get_args_list() else { + return; + }; + let Some(right_paren) = args_list.tokens::().last() else { + return; + }; + context.add_diagnostic( + DiagnosticCode::MissingParameter, + right_paren.get_range(), + t!( + "expected %{num} parameters but found %{found_num}. %{infos}", + num = expected_count, + found_num = found_count, + infos = miss_parameter_info.join(" \n ") + ) + .to_string(), + None, + ); + } +} + +fn emit_redundant_parameter( + context: &mut DiagnosticContext, + call_expr: &LuaCallExpr, + call_args: &[LuaExpr], + expected_count: usize, + found_count: usize, + func: &Arc, +) { + let implicit_receiver_offset = + usize::from(call_expr.is_colon_call() && !func.is_colon_define()); + for (i, arg) in call_args.iter().enumerate() { + if i + implicit_receiver_offset < expected_count { + continue; + } + + context.add_diagnostic( + DiagnosticCode::RedundantParameter, + arg.get_range(), + t!( + "expected %{num} parameters but found %{found_num}", + num = expected_count, + found_num = found_count, + ) + .to_string(), + None, + ); + } +} + +fn add_missing_parameter_info( + context: &DiagnosticContext, + call_expr: &LuaCallExpr, + func: &LuaFunctionType, + adjusted_index: usize, + miss_parameter_info: &mut Vec, +) { + if needs_implicit_self_param(call_expr, func) { + if adjusted_index == 0 { + if !is_nullable(context.get_db(), &LuaType::SelfInfer) { + miss_parameter_info + .push(t!("missing parameter: %{name}", name = "self",).to_string()); + } + return; + } + let Some((name, typ)) = func.get_params().get(adjusted_index - 1) else { + return; + }; + if let Some(typ) = typ + && !is_nullable(context.get_db(), typ) + { + miss_parameter_info.push(t!("missing parameter: %{name}", name = name,).to_string()); + } + return; + } + + let Some((name, typ)) = func.get_params().get(adjusted_index) else { + return; + }; + if let Some(typ) = typ + && !is_nullable(context.get_db(), typ) + { + miss_parameter_info.push(t!("missing parameter: %{name}", name = name,).to_string()); + } +} + +fn needs_implicit_self_param(call_expr: &LuaCallExpr, func: &LuaFunctionType) -> bool { + !call_expr.is_colon_call() && func.is_colon_define() +} + +fn get_params_len(params: &[(String, Option)]) -> Option { + if let Some((name, typ)) = params.last() { + // 如果最后一个参数是可变参数, 则直接返回, 不需要检查. + if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { + return None; + } + } + Some(params.len()) +} + +pub(super) fn check_closure_param_count( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + closure_expr: &LuaClosureExpr, +) { + let Some(current_signature) = + context + .get_db() + .get_signature_index() + .get(&LuaSignatureId::from_closure( + semantic_model.get_file_id(), + closure_expr, + )) + else { + return; + }; + + let Some(source_typ) = semantic_model.infer_bind_value_type(closure_expr.clone().into()) else { + return; + }; + + let Some(source_params_len) = (match &source_typ { + LuaType::DocFunction(func_type) => get_params_len(func_type.get_params()), + LuaType::Signature(signature_id) => { + let Some(signature) = context.get_db().get_signature_index().get(signature_id) else { + return; + }; + let params = signature.get_type_params(); + get_params_len(¶ms) + } + _ => return, + }) else { + return; + }; + + // 只检查右值参数多于左值参数的情况, 右值参数少于左值参数的情况是能够接受的. + if source_params_len > current_signature.params.len() { + return; + } + let found_num = current_signature.params.len(); + let Some(params_list) = closure_expr.get_params_list() else { + return; + }; + let params = params_list.get_params().collect::>(); + + for param in params[source_params_len..].iter() { + context.add_diagnostic( + DiagnosticCode::RedundantParameter, + param.get_range(), + t!( + "expected %{num} parameters but found %{found_num}", + num = source_params_len, + found_num = found_num, + ) + .to_string(), + None, + ); + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs new file mode 100644 index 000000000..6b27f35e3 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs @@ -0,0 +1,261 @@ +use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr}; +use rowan::{NodeOrToken, TextRange}; + +use crate::{ + DiagnosticCode, LuaFunctionType, LuaType, RenderLevel, SemanticModel, TypeCheckFailReason, + TypeCheckResult, diagnostic::checker::assign_type_mismatch::check_table_expr, humanize_type, + semantic::get_func_param_type, +}; + +use super::super::DiagnosticContext; +use super::call_facts::{CallFacts, count_ranges_overlap, get_param_count_range}; + +pub(super) fn check_param_types( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + facts: &CallFacts, + count_diagnostics_enabled: bool, +) -> Option<()> { + let db = semantic_model.get_db(); + let mut candidates = facts + .funcs() + .iter() + .filter(|func| { + let Some(call_count) = facts.call_arg_count_range(semantic_model, func) else { + // 如果调用数量无法确定, 保守保留候选 + return true; + }; + let param_count = get_param_count_range(db, func, &facts.call_expr); + // 计算实参范围和形参范围是否有交集, 若有则保留候选. + // 不在交集范围的的函数的诊断我们交由参数数量诊断器诊断. + count_ranges_overlap(call_count, param_count) + }) + .cloned() + .collect::>(); + if candidates.is_empty() { + if count_diagnostics_enabled { + return Some(()); + } + // 参数数量诊断器在未启用时我们需要对所有函数候选都进行参数类型检查. + candidates = facts.funcs().to_vec(); + } + + let (arg_types, arg_ranges) = facts.arg_types_and_ranges(semantic_model); + + let source_type = semantic_model.infer_call_receiver_type(&facts.call_expr); + let colon_range = facts + .call_expr + .get_colon_token() + .map(|token| token.get_range()) + .or_else(|| { + facts + .call_expr + .get_prefix_expr() + .map(|expr| expr.get_range()) + }); + let mut arg_index = 0; + loop { + let mut has_arg = false; + let mut next_candidates = Vec::with_capacity(candidates.len()); + let mut failed_param_types = Vec::with_capacity(candidates.len()); + let mut failed_arg = None; + + // 按参数位置逐步收窄候选, 第一处全体失败的位置就是本次诊断的位置. + for func in &candidates { + let Some(arg) = get_diagnostic_arg( + &facts.call_expr, + func, + &arg_types, + &arg_ranges, + source_type.as_ref(), + colon_range, + arg_index, + ) else { + next_candidates.push(func.clone()); + continue; + }; + has_arg = true; + + let Some(param_type) = + get_diagnostic_param_type(func, &facts.call_expr, source_type.as_ref(), arg_index) + else { + if failed_arg.is_none() { + failed_arg = Some(arg); + } + continue; + }; + + if param_accepts_arg(semantic_model, ¶m_type, &arg.typ) { + next_candidates.push(func.clone()); + } else { + failed_param_types.push(param_type); + if failed_arg.is_none() { + failed_arg = Some(arg); + } + } + } + + if !has_arg { + break; + } + + if next_candidates.is_empty() { + let failed_arg = failed_arg?; + if failed_param_types.is_empty() { + return Some(()); + } + let param_type = LuaType::from_vec(failed_param_types); + let result = semantic_model.type_check_detail(¶m_type, &failed_arg.typ); + if result.is_ok() { + return Some(()); + } + + // 表字段已经报错了, 则不添加参数不匹配的诊断避免干扰. + if failed_arg.typ.is_table() + && let Some(arg_expr_idx) = failed_arg.expr_index + && let Some(arg_expr) = facts.arg_exprs.get(arg_expr_idx) + && let Some(add_diagnostic) = check_table_expr( + context, + semantic_model, + NodeOrToken::Node(arg_expr.syntax().clone()), + arg_expr, + Some(¶m_type), + ) + && add_diagnostic + { + return Some(()); + } + + add_diagnostic( + context, + semantic_model, + failed_arg.range, + ¶m_type, + &failed_arg.typ, + result, + ); + break; + } + + candidates = next_candidates; + arg_index += 1; + } + + Some(()) +} + +#[derive(Clone)] +struct DiagnosticArg { + typ: LuaType, + range: TextRange, + expr_index: Option, +} + +fn get_diagnostic_arg( + call_expr: &LuaCallExpr, + func: &LuaFunctionType, + arg_types: &[LuaType], + arg_ranges: &[TextRange], + source_type: Option<&LuaType>, + colon_range: Option, + arg_index: usize, +) -> Option { + // 冒号调用到非冒号定义时, 隐式 receiver 作为第 0 个实参参与类型检查. + if call_expr.is_colon_call() && !func.is_colon_define() { + if arg_index == 0 { + return Some(DiagnosticArg { + typ: source_type.cloned()?, + range: colon_range?, + expr_index: None, + }); + } + + let index = arg_index - 1; + return Some(DiagnosticArg { + typ: arg_types.get(index)?.clone(), + range: *arg_ranges.get(index)?, + expr_index: Some(index), + }); + } + + let typ = arg_types.get(arg_index)?.clone(); + Some(DiagnosticArg { + typ, + range: *arg_ranges.get(arg_index)?, + expr_index: Some(arg_index), + }) +} + +fn get_diagnostic_param_type( + func: &LuaFunctionType, + call_expr: &LuaCallExpr, + source_type: Option<&LuaType>, + arg_index: usize, +) -> Option { + // 点调用到冒号定义时, self 是第 0 个形参, 后续形参整体右移. + if !call_expr.is_colon_call() && func.is_colon_define() { + if arg_index == 0 { + return source_type.cloned().or(Some(LuaType::SelfInfer)); + } + + return get_func_param_type(func, arg_index - 1); + } + + get_func_param_type(func, arg_index) +} + +fn param_accepts_arg( + semantic_model: &SemanticModel, + param_type: &LuaType, + arg_type: &LuaType, +) -> bool { + if param_type.is_any() + || matches!((param_type, arg_type), (LuaType::Integer, LuaType::FloatConst(f)) if f.fract() == 0.0) + { + return true; + } + + semantic_model + .type_check_detail(param_type, arg_type) + .is_ok() +} + +fn add_diagnostic( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + range: TextRange, + param_type: &LuaType, + expr_type: &LuaType, + result: TypeCheckResult, +) { + if let (LuaType::Integer, LuaType::FloatConst(f)) = (param_type, expr_type) + && f.fract() == 0.0 + { + return; + } + let db = semantic_model.get_db(); + match result { + Ok(_) => (), + Err(reason) => { + let reason_message = match reason { + TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, + TypeCheckFailReason::TypeNotMatch | TypeCheckFailReason::DonotCheck => { + "".to_string() + } + TypeCheckFailReason::TypeRecursion => "type recursion".to_string(), + }; + context.add_diagnostic( + DiagnosticCode::ParamTypeMismatch, + range, + t!( + "expected `%{source}` but found `%{found}`. %{reason}", + source = humanize_type(db, param_type, RenderLevel::Simple), + found = humanize_type(db, expr_type, RenderLevel::Simple), + reason = reason_message + ) + .to_string(), + None, + ); + } + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs deleted file mode 100644 index 5c4783a81..000000000 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs +++ /dev/null @@ -1,280 +0,0 @@ -use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; -use rowan::TextRange; - -use crate::{ - DiagnosticCode, LuaSemanticDeclId, LuaType, RenderLevel, SemanticDeclLevel, SemanticModel, - TypeCheckFailReason, TypeCheckResult, - diagnostic::checker::assign_type_mismatch::check_table_expr, humanize_type, -}; - -use super::{Checker, DiagnosticContext}; - -pub struct ParamTypeCheckChecker; - -impl Checker for ParamTypeCheckChecker { - const CODES: &[DiagnosticCode] = &[ - DiagnosticCode::ParamTypeMismatch, - DiagnosticCode::AssignTypeMismatch, - ]; - - /// a simple implementation of param type check, later we will do better - fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { - let root = semantic_model.get_root().clone(); - for node in root.descendants::() { - if let LuaAst::LuaCallExpr(call_expr) = node { - check_call_expr(context, semantic_model, call_expr); - } - } - } -} - -fn check_call_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - call_expr: LuaCallExpr, -) -> Option<()> { - let func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; - let mut params = func.get_params().to_vec(); - let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); - let (mut arg_types, mut arg_ranges): (Vec, Vec) = semantic_model - .infer_expr_list_types(&arg_exprs, None) - .into_iter() - .unzip(); - - let colon_call = call_expr.is_colon_call(); - let colon_define = func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) => {} - (false, true) => { - // 插入 self 参数 - params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); - } - (true, false) => { - // 往调用参数插入插入调用者类型 - arg_types.insert(0, get_call_source_type(semantic_model, &call_expr)?); - arg_ranges.insert(0, call_expr.get_colon_token()?.get_range()); - } - } - - for (idx, param) in params.iter().enumerate() { - if param.0 == "..." { - if arg_types.len() < idx { - break; - } - - if let Some(variadic_type) = param.1.clone() { - check_variadic_param_match_args( - context, - semantic_model, - &variadic_type, - &arg_types[idx..], - &arg_ranges[idx..], - ); - } - - break; - } - - if let Some(param_type) = param.1.clone() { - let arg_type = arg_types.get(idx).unwrap_or(&LuaType::Any); - let mut check_type = param_type.clone(); - // 对于第一个参数, 他有可能是`:`调用, 所以需要特殊处理 - if idx == 0 - && param_type.is_self_infer() - && let Some(result) = get_call_source_type(semantic_model, &call_expr) - { - check_type = result; - } - let result = semantic_model.type_check_detail(&check_type, arg_type); - if result.is_err() { - // 这里执行了`AssignTypeMismatch`的检查 - if arg_type.is_table() { - let arg_expr_idx = match (colon_call, colon_define) { - (true, false) => { - if idx == 0 { - continue; - } else { - idx - 1 - } - } - _ => idx, - }; - - // 表字段已经报错了, 则不添加参数不匹配的诊断避免干扰 - if let Some(arg_expr) = arg_exprs.get(arg_expr_idx) - && let Some(add_diagnostic) = check_table_expr( - context, - semantic_model, - rowan::NodeOrToken::Node(arg_expr.syntax().clone()), - arg_expr, - Some(¶m_type), - ) - && add_diagnostic - { - continue; - } - } - - try_add_diagnostic( - context, - semantic_model, - *arg_ranges.get(idx)?, - ¶m_type, - arg_type, - result, - ); - } - } - } - - Some(()) -} - -fn check_variadic_param_match_args( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - variadic_type: &LuaType, - arg_types: &[LuaType], - arg_ranges: &[TextRange], -) { - for (arg_type, arg_range) in arg_types.iter().zip(arg_ranges.iter()) { - let result = semantic_model.type_check_detail(variadic_type, arg_type); - if result.is_err() { - try_add_diagnostic( - context, - semantic_model, - *arg_range, - variadic_type, - arg_type, - result, - ); - } - } -} - -fn try_add_diagnostic( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - range: TextRange, - param_type: &LuaType, - expr_type: &LuaType, - result: TypeCheckResult, -) { - if let (LuaType::Integer, LuaType::FloatConst(f)) = (param_type, expr_type) - && f.fract() == 0.0 - { - return; - } - - add_type_check_diagnostic( - context, - semantic_model, - range, - param_type, - expr_type, - result, - ); -} - -fn add_type_check_diagnostic( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - range: TextRange, - param_type: &LuaType, - expr_type: &LuaType, - result: TypeCheckResult, -) { - let db = semantic_model.get_db(); - match result { - Ok(_) => (), - Err(reason) => { - let reason_message = match reason { - TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, - TypeCheckFailReason::TypeNotMatch | TypeCheckFailReason::DonotCheck => { - "".to_string() - } - TypeCheckFailReason::TypeRecursion => "type recursion".to_string(), - }; - context.add_diagnostic( - DiagnosticCode::ParamTypeMismatch, - range, - t!( - "expected `%{source}` but found `%{found}`. %{reason}", - source = humanize_type(db, param_type, RenderLevel::Simple), - found = humanize_type(db, expr_type, RenderLevel::Simple), - reason = reason_message - ) - .to_string(), - None, - ); - } - } -} - -pub fn get_call_source_type( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option { - match call_expr.get_prefix_expr()? { - LuaExpr::IndexExpr(index_expr) => { - let decl = semantic_model.find_decl( - index_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - if let LuaSemanticDeclId::Member(member_id) = decl - && let Some(LuaSemanticDeclId::Member(member_id)) = - semantic_model.get_member_origin_owner(member_id) - { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return if let Some(prefix_expr) = index_expr.get_prefix_expr() { - let expr_type = semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer); - Some(expr_type) - } else { - None - }; - } - LuaExpr::NameExpr(name_expr) => { - let decl = semantic_model.find_decl( - name_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - if let LuaSemanticDeclId::Member(member_id) = decl { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return None; - } - _ => {} - } - - None -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs new file mode 100644 index 000000000..5dfb44715 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs @@ -0,0 +1,174 @@ +#[cfg(test)] +mod test { + use emmylua_parser::{LuaAstNode, LuaLocalName}; + + use crate::{DiagnosticCode, LuaDeclId, LuaSemanticDeclId, VirtualWorkspace}; + + fn assert_type_decl_deprecated(content: &str, name: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let type_decl = db + .get_type_index() + .find_type_decl(file_id, name, db.resolve_workspace_id(file_id)) + .expect("type declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::TypeDecl(type_decl.get_id())) + .expect("type declaration property must exist"); + + assert!(property.deprecated().is_some()); + } + + fn assert_lua_decl_deprecated(content: &str, name: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let local_name = ws.get_node::(file_id); + assert_eq!(local_name.get_text(), name); + let decl = db + .get_decl_index() + .get_decl(&LuaDeclId::new(file_id, local_name.get_position())) + .expect("declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::LuaDecl(decl.get_id())) + .expect("declaration property must exist"); + + assert!(property.deprecated().is_some()); + } + + #[test] + fn test_deprecated_alias_use() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated test + ---@alias std.ConstTpl unknown + "# + )); + } + + #[test] + fn test_deprecated_alias_no_usage_error() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AnnotationUsageError, + r#" + ---@deprecated test + ---@alias std.ConstTpl unknown + "# + )); + } + + #[test] + fn test_deprecated_alias_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@deprecated test + ---@alias ConstTpl unknown + "#, + "ConstTpl", + ); + } + + #[test] + fn test_deprecated_alias_after_alias_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@alias ConstTpl unknown + ---@deprecated test + "#, + "ConstTpl", + ); + } + + #[test] + fn test_deprecated_class_no_usage_error() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AnnotationUsageError, + r#" + ---@deprecated test + ---@class Foo + "# + )); + } + + #[test] + fn test_deprecated_class_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@deprecated test + ---@class Foo + local Foo = {} + "#, + "Foo", + ); + } + + #[test] + fn test_deprecated_class_usage_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated test + ---@class Foo + local Foo = {} + + local x = Foo + "# + )); + } + + #[test] + fn test_deprecated_class_type_annotation_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated + ---@class A + + ---@type A + local a + "# + )); + } + + #[test] + fn test_deprecated_class_param_annotation_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated + ---@class A + + ---@param a A + local function f(a) + end + "# + )); + } + + #[test] + fn test_deprecated_class_after_class_attaches_to_decl() { + assert_lua_decl_deprecated( + r#" + ---@class Foo + ---@deprecated test + local Foo = {} + "#, + "Foo", + ); + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs index 501fd975d..9229e4727 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs @@ -41,6 +41,24 @@ mod test { )); } + #[test] + fn test_overload_param_count_gap_reports_missing_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@class Callable + ---@overload fun(a: string) + ---@overload fun(a: string, b: string, c: string) + ---@type Callable + local callable + + callable("a", "b") + "# + )); + } + #[test] fn test_1() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs index 97bcd11c9..ad4e896c1 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs @@ -5,6 +5,7 @@ mod call_non_callable_test; mod cast_type_mismatch_test; mod check_return_count_test; mod code_style; +mod deprecated_test; mod disable_line_test; mod duplicate_field_test; mod duplicate_index_test; diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs index 5ef80e9c4..bcb04c85b 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs @@ -2,8 +2,50 @@ mod test { use std::{ops::Deref, sync::Arc}; + use lsp_types::{Diagnostic, NumberOrString}; + use tokio_util::sync::CancellationToken; + use crate::{DiagnosticCode, VirtualWorkspace}; + fn param_type_diagnostics(ws: &mut VirtualWorkspace, block_str: &str) -> Vec { + ws.analysis + .diagnostic + .enable_only(DiagnosticCode::ParamTypeMismatch); + let file_id = ws.def(block_str); + let code = Some(NumberOrString::String( + DiagnosticCode::ParamTypeMismatch.get_name().to_string(), + )); + ws.analysis + .diagnose_file(file_id, CancellationToken::new()) + .unwrap_or_default() + .into_iter() + .filter(|diagnostic| diagnostic.code == code) + .collect() + } + + #[test] + fn test_param_type_mismatch_still_runs_when_count_diagnostics_disabled() { + let mut ws = VirtualWorkspace::new(); + let diagnostics = param_type_diagnostics( + &mut ws, + r#" + ---@param a string + ---@param b string + local function test(a, b) + end + + test(1) + "#, + ); + + assert_eq!(diagnostics.len(), 1); + assert!( + diagnostics[0].message.contains("string"), + "{}", + diagnostics[0].message + ); + } + #[test] fn test_issue_216() { let mut ws = VirtualWorkspace::new(); @@ -41,6 +83,27 @@ mod test { )); } + #[test] + fn test_overload_param_type_mismatch_unions_failed_position() { + let mut ws = VirtualWorkspace::new(); + let diagnostics = param_type_diagnostics( + &mut ws, + r#" + ---@type fun(name: "游戏-初始化") | fun(name: "游戏-开始") + local event + local bad ---@type boolean + + event(bad) + "#, + ); + + assert_eq!(diagnostics.len(), 1); + let message = &diagnostics[0].message; + assert!(message.contains("boolean"), "{message}"); + assert!(message.contains("游戏-初始化"), "{message}"); + assert!(message.contains("游戏-开始"), "{message}"); + } + #[test] fn test_issue_75() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -825,8 +888,9 @@ mod test { ---@class (partial) D21.A ---@field event fun(self: self, event: "游戏-初始化") + ---@field event fun(self: self, event: "游戏-开始") - ---@param p string + ---@param p boolean local function test(p) M:event(p) end diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs index f2b5c0027..284599892 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs @@ -108,20 +108,19 @@ mod test { #[test] fn test_issue_360() { let mut ws = VirtualWorkspace::new(); + let source = r#" + ---@alias buz number - assert!(!ws.has_no_diagnostic( - DiagnosticCode::RedundantParameter, - r#" - ---@alias buz number + ---@param a buz + ---@overload fun(): number + function test(a) + end - ---@param a buz - ---@overload fun(): number - function test(a) - end + local c = test({'test'}) + "#; - local c = test({'test'}) - "# - )); + assert!(ws.has_no_diagnostic(DiagnosticCode::RedundantParameter, source)); + assert!(!ws.has_no_diagnostic(DiagnosticCode::ParamTypeMismatch, source)); } #[test] @@ -130,16 +129,11 @@ mod test { assert!(!ws.has_no_diagnostic( DiagnosticCode::RedundantParameter, r#" - ---@class D30 - local M = {} - ---@param callback fun() local function with_local(callback) end - function M:add_local_event() - with_local(function(local_player) end) - end + with_local(function(local_player) end) "# )); } diff --git a/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs b/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs index 124422ffe..9aa6a1416 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CacheOptions { pub analysis_phase: LuaAnalysisPhase, } @@ -11,7 +11,7 @@ impl Default for CacheOptions { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum LuaAnalysisPhase { // Ordered phase Ordered, diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index c6dc256fa..8df24d384 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -94,6 +94,12 @@ impl LuaInferCache { self.file_id } + pub(in crate::semantic) fn fork_for_file(&self, file_id: FileId) -> Self { + let mut cache = Self::new(file_id, self.config.clone()); + cache.no_flow_mode = self.no_flow_mode; + cache + } + pub(in crate::semantic) fn is_no_flow(&self) -> bool { self.no_flow_mode } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..59d1d1f01 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -1,15 +1,16 @@ use std::{ops::Deref, sync::Arc}; -use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; +use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr}; use hashbrown::HashSet; use rowan::TextRange; use crate::{ - DbIndex, DocTypeInferContext, GenericTpl, GenericTplId, LuaFunctionType, LuaSemanticDeclId, - LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, - infer_doc_type, + DbIndex, DocTypeInferContext, GenericTplId, LuaFunctionType, LuaType, SemanticModel, TypeOps, + TypeSubstitutor, VariadicType, infer_doc_type, }; +use super::{TplContext, tpl_pattern_match_args}; + // 泛型约束上下文 pub struct CallConstraintContext { pub params: Vec<(String, Option)>, @@ -31,7 +32,12 @@ pub fn build_call_constraint_context( let mut params = doc_func.get_params().to_vec(); let mut args = get_arg_infos(semantic_model, call_expr)?; let mut substitutor = TypeSubstitutor::new(); - let generic_tpls = collect_func_tpl_ids(¶ms); + let generic_tpls = doc_func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); if !generic_tpls.is_empty() { substitutor.add_need_infer_tpls(generic_tpls); } @@ -53,7 +59,7 @@ pub fn build_call_constraint_context( params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); } (true, false) => { - let source_type = infer_call_source_type(semantic_model, call_expr)?; + let source_type = semantic_model.infer_call_receiver_type(call_expr)?; args.insert( 0, CallConstraintArg { @@ -65,7 +71,22 @@ pub fn build_call_constraint_context( } } - collect_generic_assignments(&mut substitutor, ¶ms, &args); + // 使用模式匹配推导泛型 + let mut cache = semantic_model.get_cache().borrow_mut(); + let mut context = TplContext { + db: semantic_model.get_db(), + cache: &mut cache, + substitutor: &mut substitutor, + call_expr: Some(call_expr.clone()), + }; + + let param_types: Vec = params + .iter() + .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) + .collect(); + let arg_types: Vec = args.iter().map(|arg| arg.check_type.clone()).collect(); + + let _ = tpl_pattern_match_args(&mut context, ¶m_types, &arg_types); Some(CallConstraintContext { params, @@ -82,275 +103,7 @@ pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { } } -// 收集各个参数对应的泛型推导 -fn collect_generic_assignments( - substitutor: &mut TypeSubstitutor, - params: &[(String, Option)], - args: &[CallConstraintArg], -) { - for (idx, (_, param_type)) in params.iter().enumerate() { - let Some(param_type) = param_type else { - continue; - }; - let Some(arg) = args.get(idx) else { - continue; - }; - record_generic_assignment(param_type, &arg.check_type, substitutor); - } -} - -fn collect_func_tpl_ids(params: &[(String, Option)]) -> HashSet { - let mut generic_tpls = HashSet::new(); - for (_, param_type) in params { - let Some(param_type) = param_type else { - continue; - }; - collect_func_tpls_from_param_type(param_type, &mut generic_tpls); - } - - generic_tpls -} - -fn collect_func_tpls_from_param_type(ty: &LuaType, generic_tpls: &mut HashSet) { - collect_func_tpl_from_param_node(ty, generic_tpls); - ty.visit_nested_types(&mut |ty| { - collect_func_tpl_from_param_node(ty, generic_tpls); - }); -} - -fn collect_func_tpl_from_param_node(ty: &LuaType, generic_tpls: &mut HashSet) { - match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, generic_tpls); - } - LuaType::StrTplRef(str_tpl) => { - let tpl_id = str_tpl.get_tpl_id(); - if tpl_id.is_func() { - generic_tpls.insert(tpl_id); - if let Some(constraint) = str_tpl.get_constraint() { - let mut constraint_deps = HashSet::new(); - if collect_func_tpl_deps_from_fallback_type( - constraint, - &mut constraint_deps, - &mut HashSet::new(), - ) { - generic_tpls.extend(constraint_deps); - } - } - } - } - _ => {} - } -} - -fn collect_func_tpl_with_fallback_deps( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, -) { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return; - } - - generic_tpls.insert(tpl_id); - - let Some(fallback_type) = generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - else { - return; - }; - - let mut fallback_deps = HashSet::new(); - let mut visiting_fallbacks = HashSet::new(); - visiting_fallbacks.insert(tpl_id); - if collect_func_tpl_deps_from_fallback_type( - fallback_type, - &mut fallback_deps, - &mut visiting_fallbacks, - ) { - generic_tpls.extend(fallback_deps); - } -} - -fn collect_func_tpl_deps_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let mut no_fallback_cycle = - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - ty.visit_nested_types(&mut |ty| { - no_fallback_cycle &= - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - }); - no_fallback_cycle -} - -fn collect_func_tpl_dep_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_generic_tpl_from_fallback(generic_tpl, generic_tpls, visiting_fallbacks) - } - LuaType::StrTplRef(str_tpl) => { - let tpl_id = str_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return true; - } - - if !visiting_fallbacks.insert(tpl_id) { - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match str_tpl.get_constraint() { - Some(constraint) => collect_func_tpl_deps_from_fallback_type( - constraint, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle - } - _ => true, - } -} - -fn collect_generic_tpl_from_fallback( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return true; - } - - if !visiting_fallbacks.insert(tpl_id) { - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - { - Some(fallback_type) => collect_func_tpl_deps_from_fallback_type( - fallback_type, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle -} - -// 实际写入泛型替换表 -fn record_generic_assignment( - param_type: &LuaType, - arg_type: &LuaType, - substitutor: &mut TypeSubstitutor, -) { - match param_type { - LuaType::TplRef(tpl_ref) => { - if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true); - } - } - LuaType::ConstTplRef(tpl_ref) => { - if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false); - } - } - LuaType::StrTplRef(str_tpl_ref) => { - substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true); - } - LuaType::Variadic(variadic) => { - if let Some(inner) = variadic.get_type(0) { - record_generic_assignment(inner, arg_type, substitutor); - } - } - _ => {} - } -} - -// 解析冒号调用时调用者的具体类型 -fn infer_call_source_type( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option { - match call_expr.get_prefix_expr()? { - LuaExpr::IndexExpr(index_expr) => { - let decl = semantic_model.find_decl( - index_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - if let LuaSemanticDeclId::Member(member_id) = decl - && let Some(LuaSemanticDeclId::Member(member_id)) = - semantic_model.get_member_origin_owner(member_id) - { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return if let Some(prefix_expr) = index_expr.get_prefix_expr() { - let expr_type = semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer); - Some(expr_type) - } else { - None - }; - } - LuaExpr::NameExpr(name_expr) => { - let decl = semantic_model.find_decl( - name_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - if let LuaSemanticDeclId::Member(member_id) = decl { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return None; - } - _ => {} - } - - None -} - -// 推导每个实参类型 +// 推推导每个实参类型 fn get_arg_infos( semantic_model: &SemanticModel, call_expr: &LuaCallExpr, @@ -407,9 +160,7 @@ fn get_constraint_type( depth: usize, ) -> Option { match arg_type { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { - tpl_ref.get_constraint().cloned() - } + LuaType::TplRef(tpl_ref) => tpl_ref.get_constraint().cloned(), LuaType::StrTplRef(str_tpl_ref) => str_tpl_ref.get_constraint().cloned(), LuaType::Union(union_type) => { if depth > 1 { @@ -453,40 +204,3 @@ fn infer_expr_list_types( } value_types } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use hashbrown::HashSet; - use smol_str::SmolStr; - - use super::*; - - fn func_tpl(idx: u32, default_type: Option) -> Arc { - Arc::new(GenericTpl::new( - GenericTplId::Func(idx), - SmolStr::new(format!("T{}", idx)).into(), - None, - default_type, - )) - } - - #[test] - fn test_collect_func_tpl_with_fallback_deps_skips_cyclic_fallback_deps() { - let t0 = func_tpl(0, None); - let t1 = func_tpl(1, Some(LuaType::TplRef(t0.clone()))); - let t0 = GenericTpl::new( - GenericTplId::Func(0), - SmolStr::new("T0").into(), - None, - Some(LuaType::TplRef(t1)), - ); - - let mut generic_tpls = HashSet::new(); - collect_func_tpl_with_fallback_deps(&t0, &mut generic_tpls); - - assert!(generic_tpls.contains(&GenericTplId::Func(0))); - assert!(!generic_tpls.contains(&GenericTplId::Func(1))); - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs similarity index 73% rename from crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs rename to crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs index 4a8eaff7c..7110111ed 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs @@ -12,7 +12,6 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, @@ -25,33 +24,21 @@ use crate::{ }, }; use crate::{ - GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, + LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, collect_callable_overload_groups, infer_node_semantic_decl, tpl_pattern_match_args_skip_unknown, }; -use super::{TypeSubstitutor, instantiate_type_generic}; +use crate::semantic::generic::{TypeSubstitutor, instantiate_type::instantiate_type_generic}; -pub fn instantiate_func_generic( +pub fn infer_call_generic( db: &DbIndex, cache: &mut LuaInferCache, func: &LuaFunctionType, call_expr: LuaCallExpr, ) -> Result { let file_id = cache.get_file_id().clone(); - let (generic_tpls, contain_self) = collect_func_tpl_ids(func); - let origin_params = func.get_params(); - let mut func_params: Vec<_> = origin_params - .iter() - .map(|(name, t)| (name.clone(), t.clone().unwrap_or(LuaType::Unknown))) - .collect(); - - let arg_exprs = call_expr - .get_args_list() - .ok_or(InferFailReason::None)? - .get_args() - .collect::>(); let mut substitutor = TypeSubstitutor::new(); let mut context = TplContext { db, @@ -59,7 +46,20 @@ pub fn instantiate_func_generic( substitutor: &mut substitutor, call_expr: Some(call_expr.clone()), }; - if !generic_tpls.is_empty() { + // 填充前缀类型可能存在的泛型 + fill_call_prefix_substitutor(&mut context, &call_expr); + + let has_func_generic = func + .get_generic_params() + .iter() + .any(|generic_tpl| generic_tpl.get_tpl_id().is_func()); + if has_func_generic { + let generic_tpls = func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); context.substitutor.add_need_infer_tpls(generic_tpls); if let Some(type_list) = call_expr.get_call_generic_type_list() { @@ -67,22 +67,22 @@ pub fn instantiate_func_generic( apply_call_generic_type_list(db, file_id, &mut context, &type_list); } else { // 如果没有指定泛型, 则需要从调用参数中推断 - infer_generic_types_from_call( - db, - &mut context, - func, - &call_expr, - &mut func_params, - &arg_exprs, - )?; + let origin_params = func.get_params(); + let mut func_params: Vec = origin_params + .iter() + .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) + .collect(); + infer_generic_types_from_call(db, &mut context, func, &call_expr, &mut func_params)?; } } + let contain_self = func.any_nested_type(|ty| matches!(ty, LuaType::SelfInfer)); if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { substitutor.add_self_type(self_type); } - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + let func_type = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_type, &substitutor) { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -249,16 +249,21 @@ fn instantiate_callable_from_arg_types( return None; } - let mut callable_tpls = HashSet::new(); - callable.visit_nested_types(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - }); - if callable_tpls.is_empty() { + let has_callable_tpls = callable + .get_generic_params() + .iter() + .any(|generic_tpl| generic_tpl.get_tpl_id().is_func()); + if !has_callable_tpls { return Some(callable.clone()); } + let callable_tpls = callable + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); + let callable_param_types = callable .get_params() .iter() @@ -282,14 +287,16 @@ fn instantiate_callable_from_arg_types( return None; } - let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { - LuaType::DocFunction(func) => func, - _ => callable.clone(), - }; + let callable_type = LuaType::DocFunction(callable.clone()); + let instantiated = + match instantiate_type_generic(context.db, &callable_type, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); instantiated.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + if let LuaType::TplRef(generic_tpl) = ty && callable_tpls.contains(&generic_tpl.get_tpl_id()) { tpl_ids.insert(generic_tpl.get_tpl_id()); @@ -314,7 +321,7 @@ fn instantiate_callable_from_arg_types( for tpl_id in callback_return_tpls { callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); } - match instantiate_doc_function(context.db, callable, &callable_substitutor) { + match instantiate_type_generic(context.db, &callable_type, &callable_substitutor) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -360,7 +367,7 @@ fn collect_callback_return_tpls( continue; }; param_func.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + if let LuaType::TplRef(generic_tpl) = ty { let tpl_id = generic_tpl.get_tpl_id(); if unresolved_return_tpls.contains(&tpl_id) { callback_return_tpls.insert(tpl_id); @@ -372,120 +379,18 @@ fn collect_callback_return_tpls( callback_return_tpls } -fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) { - let mut generic_tpls = HashSet::new(); - let mut contain_self = false; - - func.visit_nested_types(&mut |ty| match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, &mut generic_tpls); - } - LuaType::StrTplRef(str_tpl) => { - generic_tpls.insert(str_tpl.get_tpl_id()); - } - LuaType::SelfInfer => contain_self = true, - _ => {} - }); - - (generic_tpls, contain_self) -} - -fn collect_func_tpl_with_fallback_deps( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, -) { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return; - } - - generic_tpls.insert(tpl_id); - - let Some(fallback_type) = generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - else { - return; - }; - - // 只有提前加入的泛型才有 None 占位, fallback 展开时才能继续使用它自己的 default/constraint. - // 例如 `U = T[]` 或 `U: T[]` 中, 即使函数返回值只直接引用了 `U`, 也需要把 `T` 一并加入. - let mut fallback_deps = HashSet::new(); - let mut visiting_fallbacks = HashSet::new(); - visiting_fallbacks.insert(tpl_id); - if collect_func_tpl_deps_from_fallback_type( - fallback_type, - &mut fallback_deps, - &mut visiting_fallbacks, - ) { - generic_tpls.extend(fallback_deps); - } -} - -fn collect_func_tpl_deps_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - // 返回 false 表示 fallback 依赖链里发现循环. - // visit_nested_types 只访问子节点, 所以这里先处理类型自身, 再处理嵌套类型. - let mut no_fallback_cycle = - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - ty.visit_nested_types(&mut |ty| { - no_fallback_cycle &= - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - }); - no_fallback_cycle -} - -fn collect_func_tpl_dep_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let (LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl)) = ty else { - return true; - }; - - if !generic_tpl.get_tpl_id().is_func() { - return true; - } - - let tpl_id = generic_tpl.get_tpl_id(); - if !visiting_fallbacks.insert(tpl_id) { - // 遇到 `T = U, U = T` 这类循环 fallback 时, 放弃合并本轮依赖避免递归展开. - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - { - Some(fallback_type) => collect_func_tpl_deps_from_fallback_type( - fallback_type, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle -} - fn infer_generic_types_from_call( db: &DbIndex, context: &mut TplContext, func: &LuaFunctionType, call_expr: &LuaCallExpr, - func_params: &mut Vec<(String, LuaType)>, - arg_exprs: &[LuaExpr], + func_params: &mut Vec, ) -> Result<(), InferFailReason> { let colon_call = call_expr.is_colon_call(); let colon_define = func.is_colon_define(); match (colon_define, colon_call) { (true, false) => { - func_params.insert(0, ("self".to_string(), LuaType::Any)); + func_params.insert(0, LuaType::Any); } (false, true) => { if !func_params.is_empty() { @@ -496,9 +401,14 @@ fn infer_generic_types_from_call( } let mut unresolve_tpls = vec![]; + let arg_exprs = call_expr + .get_args_list() + .ok_or(InferFailReason::None)? + .get_args() + .collect::>(); for i in 0..func_params.len() { if i >= arg_exprs.len() { - if let LuaType::Variadic(variadic) = &func_params[i].1 { + if let LuaType::Variadic(variadic) = &func_params[i] { variadic_tpl_pattern_match(context, variadic, &[])?; } break; @@ -508,14 +418,16 @@ fn infer_generic_types_from_call( break; } - let (_, func_param_type) = &func_params[i]; + let func_param_type = &func_params[i]; let call_arg_expr = &arg_exprs[i]; if !func_param_type.contains_tpl_node() { continue; } + let doc_param_func = as_doc_function_type(db, func_param_type)?; + if !func_param_type.is_variadic() - && check_expr_can_later_infer(context, func_param_type, call_arg_expr)? + && check_expr_can_later_infer_with_doc_func(doc_param_func.as_deref(), call_arg_expr) { // 如果参数不能被后续推断, 那么我们先不处理 unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); @@ -528,19 +440,18 @@ fn infer_generic_types_from_call( Err(e) => return Err(e), }; - if let Some(return_pattern) = - as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) - { + if let Some(doc_func) = &doc_param_func { + let return_pattern = doc_func.get_ret(); if let Some(inferred_return_type) = infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { return_type_pattern_match_target_type( context, - &return_pattern, + return_pattern, &inferred_return_type, )?; } else if arg_type.is_any() || arg_type.is_unknown() { - return_type_pattern_match_target_type(context, &return_pattern, &LuaType::Unknown)?; + return_type_pattern_match_target_type(context, return_pattern, &LuaType::Unknown)?; } } @@ -555,11 +466,7 @@ fn infer_generic_types_from_call( break; } (_, LuaType::Variadic(variadic)) => { - let func_param_types = func_params[i..] - .iter() - .map(|(_, t)| t) - .cloned() - .collect::>(); + let func_param_types = func_params[i..].to_vec(); multi_param_tpl_pattern_match_multi_return(context, &func_param_types, variadic)?; break; } @@ -607,9 +514,9 @@ fn build_self_generic_arg( substitutor: &TypeSubstitutor, ) -> LuaType { let Some(arg) = generic_param - .default_type + .default .as_ref() - .or(generic_param.type_constraint.as_ref()) + .or(generic_param.constraint.as_ref()) else { return LuaType::Unknown; }; @@ -666,30 +573,40 @@ pub fn infer_self_type( None } -fn check_expr_can_later_infer( - context: &mut TplContext, - func_param_type: &LuaType, +fn check_expr_can_later_infer_with_doc_func( + doc_function: Option<&LuaFunctionType>, call_arg_expr: &LuaExpr, -) -> Result { - let Some(doc_function) = as_doc_function_type(context.db, func_param_type)? else { - return Ok(false); +) -> bool { + let Some(doc_function) = doc_function else { + return false; }; if let LuaExpr::ClosureExpr(_) = call_arg_expr { - return Ok(true); + return true; } let doc_params = doc_function.get_params(); let variadic_count = doc_params .iter() - .filter_map(|(_, t)| { - if let Some(LuaType::Variadic(_)) = t { - Some(()) - } else { - None - } - }) + .filter(|(_, t)| matches!(t, Some(LuaType::Variadic(_)))) .count(); - Ok(variadic_count > 1) + variadic_count > 1 +} + +fn fill_call_prefix_substitutor(context: &mut TplContext, call_expr: &LuaCallExpr) -> Option<()> { + let prefix_expr = call_expr.get_prefix_expr()?; + if let LuaExpr::IndexExpr(index_expr) = prefix_expr { + let self_expr = index_expr.get_prefix_expr()?; + let self_type = infer_expr(context.db, context.cache, self_expr).ok()?; + if let LuaType::Generic(generic) = self_type { + for (i, param) in generic.get_params().iter().enumerate() { + context + .substitutor + .insert_type(GenericTplId::Type(i as u32), param.clone(), true); + } + return Some(()); + } + } + None } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..6ce777dad 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -1,7 +1,7 @@ use hashbrown::HashSet; use crate::{ - DbIndex, GenericParam, GenericTplId, LuaAliasCallType, LuaArrayType, LuaAttributeType, + DbIndex, GenericParam, GenericTpl, GenericTplId, LuaAliasCallType, LuaArrayType, LuaConditionalType, LuaMappedType, LuaMultiLineUnion, LuaTypeDeclId, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, @@ -101,7 +101,7 @@ fn complete_type_generic_args_inner( continue; } - if let Some(default_type) = &generic_param.default_type { + if let Some(default_type) = &generic_param.default { if missing_required_count != 0 { continue; } @@ -223,7 +223,6 @@ fn complete_type_generic_args_in_type_inner( let guard = complete_type_generic_args_in_type_inner(db, guard, visiting); CompletedType::new(LuaType::TypeGuard(guard.ty.into()), guard.cycled) } - LuaType::DocAttribute(attribute) => complete_attribute_type(db, attribute, visiting), LuaType::Conditional(conditional) => complete_conditional_type(db, conditional, visiting), LuaType::Mapped(mapped) => complete_mapped_type(db, mapped, visiting), _ => CompletedType::unchanged(ty), @@ -295,6 +294,7 @@ fn complete_doc_function( visiting: &mut HashSet, ) -> CompletedType { let mut cycled = false; + let generic_params = complete_function_generic_params(db, func, visiting, &mut cycled); let params = func .get_params() .iter() @@ -315,6 +315,7 @@ fn complete_doc_function( func.is_variadic(), params, ret.ty, + Some(generic_params), ) .into(), ), @@ -322,6 +323,31 @@ fn complete_doc_function( ) } +fn complete_function_generic_params( + db: &DbIndex, + func: &LuaFunctionType, + visiting: &mut HashSet, + cycled: &mut bool, +) -> Vec { + func.get_generic_params() + .iter() + .map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + let completed = complete_generic_param(db, param, visiting); + *cycled |= completed.cycled; + GenericTpl::new( + tpl_id, + completed.param.name, + completed.param.constraint, + completed.param.default, + completed.param.is_const, + completed.param.attributes, + ) + }) + .collect() +} + fn complete_object_type( db: &DbIndex, object: &LuaObjectType, @@ -397,29 +423,6 @@ fn complete_multi_line_union( ) } -fn complete_attribute_type( - db: &DbIndex, - attribute: &LuaAttributeType, - visiting: &mut HashSet, -) -> CompletedType { - let mut cycled = false; - let params = attribute - .get_params() - .iter() - .map(|(name, ty)| { - let completed = ty - .as_ref() - .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); - cycled |= completed.as_ref().is_some_and(|completed| completed.cycled); - (name.clone(), completed.map(|completed| completed.ty)) - }) - .collect(); - CompletedType::new( - LuaType::DocAttribute(LuaAttributeType::new(params).into()), - cycled, - ) -} - fn complete_conditional_type( db: &DbIndex, conditional: &LuaConditionalType, @@ -529,11 +532,11 @@ fn complete_generic_param( visiting: &mut HashSet, ) -> CompletedGenericParam { let constraint = param - .type_constraint + .constraint .as_ref() .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); let default_type = param - .default_type + .default .as_ref() .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); let cycled = constraint.as_ref().is_some_and(|ty| ty.cycled) @@ -543,6 +546,7 @@ fn complete_generic_param( param.name.clone(), constraint.map(|ty| ty.ty), default_type.map(|ty| ty.ty), + param.is_const, param.attributes.clone(), ), cycled, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index c0c64bbdf..5ea42ce91 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -2,13 +2,13 @@ use hashbrown::{HashMap, HashSet}; use std::ops::Deref; use crate::{ - DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, + DbIndex, GenericTpl, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, check_type_compact, db_index::{LuaObjectType, LuaTupleType, LuaType}, semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; -use super::{get_default_constructor, instantiate_type_generic_with_context}; +use super::{get_default_constructor, instantiate_type_generic_inner}; use crate::semantic::generic::type_substitutor::GenericInstantiateContext; #[derive(Debug, Clone, Copy)] @@ -80,19 +80,18 @@ fn instantiate_conditional_once( finalize_infer_assignments(infer_assignments), ) } else { - instantiate_type_generic_with_context(context, conditional.get_false_type()) + instantiate_type_generic_inner(context, conditional.get_false_type()) }; } match check_conditional_extends(context.db, &left_type, &right_type) { ConditionalCheck::True => instantiate_true_branch(context, conditional, HashMap::new()), ConditionalCheck::False => { - instantiate_type_generic_with_context(context, conditional.get_false_type()) + instantiate_type_generic_inner(context, conditional.get_false_type()) } ConditionalCheck::Both => { let true_type = instantiate_true_branch(context, conditional, HashMap::new()); - let false_type = - instantiate_type_generic_with_context(context, conditional.get_false_type()); + let false_type = instantiate_type_generic_inner(context, conditional.get_false_type()); TypeOps::Union.apply(context.db, &true_type, &false_type) } } @@ -125,9 +124,7 @@ fn instantiate_distributed_conditional( fn naked_checked_type_tpl_id(checked_type: &LuaType) -> Option { match checked_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { - Some(tpl.get_tpl_id()) - } + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_type() => Some(tpl.get_tpl_id()), _ => None, } } @@ -152,7 +149,7 @@ fn instantiate_true_branch( infer_assignments: HashMap, ) -> LuaType { if infer_assignments.is_empty() { - return instantiate_type_generic_with_context(context, conditional.get_true_type()); + return instantiate_type_generic_inner(context, conditional.get_true_type()); } let mut true_substitutor = context.substitutor.clone(); @@ -160,7 +157,7 @@ fn instantiate_true_branch( true_substitutor.insert_conditional_infer_type(tpl_id, ty); } let true_context = context.with_substitutor(&true_substitutor); - instantiate_type_generic_with_context(&true_context, conditional.get_true_type()) + instantiate_type_generic_inner(&true_context, conditional.get_true_type()) } fn contains_conditional_infer(ty: &LuaType) -> bool { @@ -170,7 +167,7 @@ fn contains_conditional_infer(ty: &LuaType) -> bool { fn conditional_infer_tpl_id(ty: &LuaType) -> bool { matches!( ty, - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_conditional_infer() ) } @@ -257,9 +254,7 @@ fn collect_infer_assignments( variance: InferVariance, ) -> bool { match pattern { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if tpl.get_tpl_id().is_conditional_infer() => - { + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_conditional_infer() => { insert_infer_assignment(db, assignments, tpl.get_tpl_id(), source, variance) } LuaType::Generic(pattern_generic) => { @@ -650,8 +645,8 @@ fn instantiate_conditional_operand( checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_with_context(context, operand); - if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { + let mut result = instantiate_type_generic_inner(context, operand); + if let LuaType::TplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { result = raw.clone(); @@ -678,7 +673,7 @@ fn instantiate_conditional_operand( // `infer` pattern 也以模板引用表示, 必须保留下来供后续结构匹配绑定. fn actualize_unresolved_templates(ty: LuaType) -> LuaType { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + LuaType::TplRef(tpl) => { if tpl.get_tpl_id().is_conditional_infer() { // Conditional infer 是右侧 pattern 的占位孔, 不能像普通未解模板一样抹成 unknown. LuaType::TplRef(tpl) @@ -718,6 +713,7 @@ fn actualize_unresolved_templates(ty: LuaType) -> LuaType { }) .collect(), actualize_unresolved_templates(func.get_ret().clone()), + Some(actualize_function_generic_params(&func)), ) .into(), ), @@ -818,3 +814,21 @@ fn actualize_unresolved_templates(ty: LuaType) -> LuaType { ty => ty, } } + +fn actualize_function_generic_params(func: &crate::LuaFunctionType) -> Vec { + func.get_generic_params() + .iter() + .map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + GenericTpl::new( + tpl_id, + param.name.clone(), + param.constraint.clone().map(actualize_unresolved_templates), + param.default.clone().map(actualize_unresolved_templates), + param.is_const, + param.attributes.clone(), + ) + }) + .collect() +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 57fae183d..02d367f55 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -10,7 +10,7 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_with_context}; +use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_inner}; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, @@ -19,7 +19,7 @@ pub(super) fn instantiate_alias_call( let operand_exprs = alias_call.get_operands(); let operands = operand_exprs .iter() - .map(|it| instantiate_type_generic_with_context(context, it)) + .map(|it| instantiate_type_generic_inner(context, it)) .collect::>(); match alias_call.get_call_kind() { @@ -135,9 +135,7 @@ fn resolve_literal_operand( substitutor: &TypeSubstitutor, ) -> Option { match operand { - Some(LuaType::TplRef(tpl_ref)) | Some(LuaType::ConstTplRef(tpl_ref)) => { - substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned() - } + Some(LuaType::TplRef(tpl_ref)) => substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned(), _ => None, } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 420c0842e..e71843e81 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,10 +1,11 @@ mod complete_generic_args; mod instantiate_conditional_generic; -mod instantiate_func_generic; mod instantiate_special_generic; use hashbrown::{HashMap, HashSet}; -use std::{ops::Deref, sync::Arc}; +use std::ops::Deref; + +use smol_str::SmolStr; use crate::{ DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, @@ -14,123 +15,42 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, }, - semantic::infer::InferFailReason, }; use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, UninferredTplPolicy, + GenericInstantiateContext, SubstitutorTypeValue, SubstitutorValue, TypeSubstitutor, }; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; pub use instantiate_special_generic::get_keyof_members; -pub(crate) fn collect_callable_overload_groups( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, -) -> Result<(), InferFailReason> { - let mut visiting_aliases = HashSet::new(); - collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) -} - -fn collect_callable_overload_groups_inner( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, - visiting_aliases: &mut HashSet, -) -> Result<(), InferFailReason> { - match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { - return Ok(()); - }; - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - - let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(type_id); - result?; - } - LuaType::Generic(generic) => { - let type_id = generic.get_base_type_id(); - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { - visiting_aliases.remove(&type_id); - return Ok(()); - }; - - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(&type_id); - result?; - } - LuaType::Union(union) => { - for member in union.into_vec() { - collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; - } - } - LuaType::Intersection(intersection) => { - for member in intersection.get_types() { - collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; - } - } - LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), - LuaType::Signature(sig_id) => { - let Some(signature) = db.get_signature_index().get(sig_id) else { - return Ok(()); - }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); - groups.push(overloads); - } - _ => {} - } - - Ok(()) -} - pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, substitutor: &TypeSubstitutor, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); - instantiate_type_generic_with_context(&context, ty) + match ty { + LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context(&context, doc_func), + _ => instantiate_type_generic_inner(&context, ty), + } } -pub(super) fn instantiate_type_generic_with_context( +pub(super) fn instantiate_type_generic_inner( context: &GenericInstantiateContext, ty: &LuaType, ) -> LuaType { match ty { LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context( - &context.with_policy(UninferredTplPolicy::PreserveTplRef), - doc_func, - ), + LuaType::DocFunction(doc_func) => instantiate_nested_doc_function(context, doc_func), LuaType::Object(object) => instantiate_object(context, object), LuaType::Union(union) => instantiate_union(context, union), LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), - LuaType::Generic(generic) => instantiate_generic_with_context(context, generic), + LuaType::Generic(generic) => instantiate_generic_type(context, generic), LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context), LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), LuaType::Call(alias_call) => { instantiate_special_generic::instantiate_alias_call(context, alias_call) @@ -144,7 +64,7 @@ pub(super) fn instantiate_type_generic_with_context( } } LuaType::TypeGuard(guard) => { - let inner = instantiate_type_generic_with_context(context, guard.deref()); + let inner = instantiate_type_generic_inner(context, guard.deref()); LuaType::TypeGuard(inner.into()) } LuaType::Conditional(conditional) => { @@ -161,7 +81,7 @@ where { types .into_iter() - .map(|ty| instantiate_type_generic_with_context(context, ty)) + .map(|ty| instantiate_type_generic_inner(context, ty)) .collect() } @@ -176,15 +96,15 @@ where .into_iter() .map(|(key, value)| { ( - instantiate_type_generic_with_context(context, key), - instantiate_type_generic_with_context(context, value), + instantiate_type_generic_inner(context, key), + instantiate_type_generic_inner(context, value), ) }) .collect() } fn instantiate_array(context: &GenericInstantiateContext, base: &LuaType) -> LuaType { - let base = instantiate_type_generic_with_context(context, base); + let base = instantiate_type_generic_inner(context, base); LuaType::Array(LuaArrayType::from_base_type(base).into()) } @@ -209,7 +129,9 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()), + SubstitutorValue::Type(ty) => { + new_types.push(substitutor_type_for_tpl(tpl, ty).clone()) + } SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -223,21 +145,12 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) break; } - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, t); new_types.push(t); } LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -pub fn instantiate_doc_function( - db: &DbIndex, - doc_func: &LuaFunctionType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_doc_function_with_context(&context, doc_func) -} - fn instantiate_doc_function_with_context( context: &GenericInstantiateContext, doc_func: &LuaFunctionType, @@ -246,6 +159,7 @@ fn instantiate_doc_function_with_context( let tpl_ret = doc_func.get_ret(); let async_state = doc_func.get_async_state(); let colon_define = doc_func.is_colon_define(); + let generic_params = instantiate_function_generic_params(context, doc_func); let mut new_params = Vec::new(); for origin_param in tpl_func_params.iter() { @@ -266,7 +180,7 @@ fn instantiate_doc_function_with_context( new_params.push((origin_param.0.clone(), Some(ty))); } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = substitutor_type_for_tpl(tpl, ty); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple @@ -318,7 +232,7 @@ fn instantiate_doc_function_with_context( } } LuaType::Generic(generic) => { - let new_type = instantiate_generic_with_context(context, generic); + let new_type = instantiate_generic_type(context, generic); // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple if let LuaType::Tuple(tuple_type) = &new_type { let base_index = new_params.len(); @@ -336,13 +250,13 @@ fn instantiate_doc_function_with_context( VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic_with_context(context, origin_param_type); + let new_type = instantiate_type_generic_inner(context, origin_param_type); new_params.push((origin_param.0.clone(), Some(new_type))); } } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); + let mut inst_ret_type = instantiate_type_generic_inner(context, tpl_ret); // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple if let LuaType::Variadic(_) = &&tpl_ret && let LuaType::Tuple(tuple) = &inst_ret_type @@ -375,21 +289,140 @@ fn instantiate_doc_function_with_context( is_variadic, new_params, inst_ret_type, + Some(generic_params), ) .into(), ) } +fn instantiate_nested_doc_function( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, +) -> LuaType { + let mut transferred_params = Vec::new(); + let mut transferred_tpls = HashSet::new(); + collect_pending_function_generic_params( + context, + doc_func, + &mut transferred_params, + &mut transferred_tpls, + ); + + if transferred_tpls.is_empty() { + return instantiate_doc_function_with_context(context, doc_func); + } + + let mut generic_params = doc_func.get_generic_params().to_vec(); + for generic_param in transferred_params { + if generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + continue; + } + + generic_params.push(generic_param); + } + + let nested_substitutor = context.substitutor.without_pending_tpls(&transferred_tpls); + let nested_context = context.with_substitutor(&nested_substitutor); + let doc_func = LuaFunctionType::new( + doc_func.get_async_state(), + doc_func.is_colon_define(), + doc_func.is_variadic(), + doc_func.get_params().to_vec(), + doc_func.get_ret().clone(), + Some(generic_params), + ); + instantiate_doc_function_with_context(&nested_context, &doc_func) +} + +fn collect_pending_function_generic_params( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, + generic_params: &mut Vec, + generic_tpls: &mut HashSet, +) { + for generic_tpl in doc_func.get_generic_params() { + let tpl_id = generic_tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(generic_tpl.clone()); + } + } + + doc_func.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(tpl) => { + let tpl_id = tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(tpl.as_ref().clone()); + } + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(GenericTpl::new( + tpl_id, + SmolStr::new(str_tpl.get_name()), + str_tpl.get_constraint().cloned(), + None, + false, + None, + )); + } + } + _ => {} + }); +} + +fn is_pending_tpl(context: &GenericInstantiateContext, tpl_id: GenericTplId) -> bool { + matches!( + context.substitutor.get(tpl_id), + Some(SubstitutorValue::None) + ) +} + +fn instantiate_function_generic_params( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, +) -> Vec { + doc_func + .get_generic_params() + .iter() + .filter_map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + // A pending entry means this generic belongs to the current instantiation boundary + // and has been finalized into the function params/return. Foreign nested generics + // are absent from the substitutor and remain owned by the nested function. + if context.substitutor.get(tpl_id).is_some() { + return None; + } + + let constraint = param + .constraint + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, ty)); + let default_type = param + .default + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, ty)); + Some(GenericTpl::new( + tpl_id, + param.name.clone(), + constraint, + default_type, + param.is_const, + param.attributes.clone(), + )) + }) + .collect() +} + fn instantiate_object(context: &GenericInstantiateContext, object: &LuaObjectType) -> LuaType { let new_fields = object .get_fields() .iter() - .map(|(key, field)| { - ( - key.clone(), - instantiate_type_generic_with_context(context, field), - ) - }) + .map(|(key, field)| (key.clone(), instantiate_type_generic_inner(context, field))) .collect::>(); let new_index_access = instantiate_type_pairs(context, object.get_index_access().iter()); @@ -411,16 +444,7 @@ fn instantiate_intersection( ) } -pub fn instantiate_generic( - db: &DbIndex, - generic: &LuaGenericType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_generic_with_context(&context, generic) -} - -fn instantiate_generic_with_context( +fn instantiate_generic_type( context: &GenericInstantiateContext, generic: &LuaGenericType, ) -> LuaType { @@ -458,18 +482,13 @@ fn instantiate_uninferred_tpl_fallback( tpl: &GenericTpl, context: &GenericInstantiateContext, ) -> LuaType { - // 一些情况下需要保留 TplRef, 例如高阶函数调用 - if context.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { - return LuaType::TplRef(tpl.clone().into()); - } - // 显式默认值优先, 然后是 extends 约束, 最后才是 unknown. if let Some(default_type) = tpl.get_default_type() { - return instantiate_type_generic_with_context(context, default_type); + return instantiate_type_generic_inner(context, default_type); } if let Some(constraint) = tpl.get_constraint() { - return instantiate_type_generic_with_context(context, constraint); + return instantiate_type_generic_inner(context, constraint); } LuaType::Unknown @@ -481,7 +500,7 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> SubstitutorValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context); } - SubstitutorValue::Type(ty) => return ty.default().clone(), + SubstitutorValue::Type(ty) => return substitutor_type_for_tpl(tpl, ty).clone(), SubstitutorValue::MultiTypes(types) => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } @@ -500,29 +519,12 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType::TplRef(tpl.clone().into()) } -fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); - } - SubstitutorValue::Type(ty) => return ty.raw().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); - } - SubstitutorValue::Params(params) => { - return params - .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() - .unwrap_or(LuaType::Unknown); - } - SubstitutorValue::MultiBase(base) => return base.clone(), - } +fn substitutor_type_for_tpl<'a>(tpl: &GenericTpl, value: &'a SubstitutorTypeValue) -> &'a LuaType { + if tpl.is_const() { + value.raw() + } else { + value.default() } - - LuaType::ConstTplRef(tpl.clone().into()) } fn instantiate_signature( @@ -577,7 +579,7 @@ fn instantiate_variadic_type( }; } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = substitutor_type_for_tpl(tpl, ty); if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never @@ -607,7 +609,7 @@ fn instantiate_variadic_type( } } LuaType::Generic(generic) => { - return instantiate_generic_with_context(context, generic); + return instantiate_generic_type(context, generic); } _ => {} }, @@ -615,7 +617,7 @@ fn instantiate_variadic_type( if types.iter().any(LuaTypeNode::contains_tpl_node) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, t); match t { LuaType::Never => {} LuaType::Variadic(variadic) => match variadic.deref() { @@ -641,9 +643,9 @@ fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMapp let constraint = mapped .param .1 - .type_constraint + .constraint .as_ref() - .map(|ty| instantiate_type_generic_with_context(context, ty)); + .map(|ty| instantiate_type_generic_inner(context, ty)); if let Some(constraint) = constraint { let mut key_types = Vec::new(); @@ -701,7 +703,7 @@ fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMapp } } - instantiate_type_generic_with_context(context, &mapped.value) + instantiate_type_generic_inner(context, &mapped.value) } fn instantiate_mapped_value( @@ -713,7 +715,7 @@ fn instantiate_mapped_value( let mut local_substitutor = context.substitutor.clone(); local_substitutor.insert_type(tpl_id, replacement.clone(), true); let local_context = context.with_substitutor(&local_substitutor); - let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); + let mut result = instantiate_type_generic_inner(&local_context, &mapped.value); // 根据 readonly 和 optional 属性进行处理 if mapped.is_optional { result = TypeOps::Union.apply(context.db, &result, &LuaType::Nil); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 90e34baa3..a322f181e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -1,4 +1,5 @@ mod call_constraint; +mod infer_call_generic; mod instantiate_type; mod test; mod tpl_context; @@ -9,118 +10,10 @@ pub use call_constraint::{ CallConstraintArg, CallConstraintContext, build_call_constraint_context, normalize_constraint_type, }; -use emmylua_parser::LuaAstNode; -use emmylua_parser::LuaExpr; -pub(crate) use instantiate_type::collect_callable_overload_groups; +pub use infer_call_generic::{build_self_type, infer_call_generic, infer_self_type}; +pub use instantiate_type::get_keyof_members; pub use instantiate_type::*; -use rowan::NodeOrToken; pub use tpl_context::TplContext; pub use tpl_pattern::tpl_pattern_match_args; pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; pub use type_substitutor::TypeSubstitutor; - -use crate::DbIndex; -use crate::GenericTplId; -use crate::LuaDeclExtra; -use crate::LuaInferCache; -use crate::LuaMemberOwner; -use crate::LuaSemanticDeclId; -use crate::LuaType; -use crate::SemanticDeclLevel; -use crate::TypeOps; -use crate::infer_node_semantic_decl; -use crate::semantic::semantic_info::infer_token_semantic_decl; -pub use instantiate_type::get_keyof_members; - -pub fn get_tpl_ref_extend_type( - db: &DbIndex, - cache: &mut LuaInferCache, - arg_type: &LuaType, - arg_expr: LuaExpr, - depth: usize, -) -> Option { - match arg_type { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { - if let Some(extend) = tpl_ref.get_constraint().cloned() { - return Some(extend); - } - let node_or_token = arg_expr.syntax().clone().into(); - let semantic_decl = match node_or_token { - NodeOrToken::Node(node) => { - infer_node_semantic_decl(db, cache, node, SemanticDeclLevel::default()) - } - NodeOrToken::Token(token) => { - infer_token_semantic_decl(db, cache, token, SemanticDeclLevel::default()) - } - }?; - - match tpl_ref.get_tpl_id() { - GenericTplId::Func(tpl_id) => { - if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl { - let decl = db.get_decl_index().get_decl(&decl_id)?; - match decl.extra { - LuaDeclExtra::Param { signature_id, .. } => { - let signature = db.get_signature_index().get(&signature_id)?; - if let Some(generic_param) = - signature.generic_params.get(tpl_id as usize) - { - return generic_param.constraint.clone(); - } - } - _ => return None, - } - } - None - } - GenericTplId::Type(tpl_id) => { - if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl { - let decl = db.get_decl_index().get_decl(&decl_id)?; - match decl.extra { - LuaDeclExtra::Param { - owner_member_id, .. - } => { - let owner_member_id = owner_member_id?; - let parent_owner = - db.get_member_index().get_current_owner(&owner_member_id)?; - match parent_owner { - LuaMemberOwner::Type(type_id) => { - let generic_params = - db.get_type_index().get_generic_params(type_id)?; - return generic_params - .get(tpl_id as usize)? - .type_constraint - .clone(); - } - _ => return None, - } - } - _ => return None, - } - } - None - } - GenericTplId::ConditionalInfer(_) => None, - } - } - LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint().cloned(), - LuaType::Union(union_type) => { - if depth > 1 { - return None; - } - let mut result = LuaType::Never; - for union_member_type in union_type.into_vec().iter() { - let extend_type = get_tpl_ref_extend_type( - db, - cache, - union_member_type, - arg_expr.clone(), - depth + 1, - ) - .unwrap_or(union_member_type.clone()); - result = TypeOps::Union.apply(db, &result, &extend_type); - } - Some(result) - } - _ => None, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index aa556ca88..7e6949446 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,8 +1,11 @@ use crate::{ - InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_generic, instantiate_type_generic, - semantic::generic::tpl_pattern::{ - TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, + InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, LuaType, + LuaTypeNode, SignatureReturnStatus, TplContext, TypeSubstitutor, instantiate_type_generic, + semantic::{ + generic::tpl_pattern::{ + TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, + }, + member::{find_members_with_key, get_member_map}, }, }; @@ -122,11 +125,15 @@ fn generic_tpl_pattern_match_inner( )?; } } + LuaType::TableConst(_) => { + match_generic_members_with_table_literal(context, source_generic, target)?; + } _ => { // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 let substitutor = TypeSubstitutor::new(); - let typ = instantiate_generic(context.db, source_generic, &substitutor); - if LuaType::from(source_generic.clone()) != typ { + let source_type = LuaType::from(source_generic.clone()); + let typ = instantiate_type_generic(context.db, &source_type, &substitutor); + if source_type != typ { tpl_pattern_match(context, &typ, target)?; } } @@ -134,3 +141,129 @@ fn generic_tpl_pattern_match_inner( Ok(()) } + +fn match_generic_members_with_table_literal( + context: &mut TplContext, + source_generic: &LuaGenericType, + table_type: &LuaType, +) -> TplPatternMatchResult { + if context.substitutor.is_infer_all_tpl() { + return Ok(()); + } + + let Some(target_member_map) = get_member_map(context.db, table_type) else { + return Ok(()); + }; + + let source_type = LuaType::Generic(source_generic.clone().into()); + for (member_key, target_members) in target_member_map { + if context.substitutor.is_infer_all_tpl() { + break; + } + + let Some(source_members) = + find_members_with_key(context.db, &source_type, member_key, true) + else { + continue; + }; + + for source_member in source_members { + if !source_member.typ.contain_tpl() { + continue; + } + + for target_member in &target_members { + let target_type = erase_implicit_signature_types(context, &target_member.typ); + tpl_pattern_match_ignoring_unknown_target( + context, + &source_member.typ, + &target_type, + )?; + if context.substitutor.is_infer_all_tpl() { + break; + } + } + + if context.substitutor.is_infer_all_tpl() { + break; + } + } + } + + Ok(()) +} + +fn erase_implicit_signature_types(context: &TplContext, target: &LuaType) -> LuaType { + let LuaType::Signature(signature_id) = target else { + return target.clone(); + }; + let Some(signature) = context.db.get_signature_index().get(signature_id) else { + return target.clone(); + }; + + let params = signature + .params + .iter() + .enumerate() + .map(|(idx, name)| { + ( + name.clone(), + Some( + signature + .param_docs + .get(&idx) + .map(|param| param.type_ref.clone()) + .unwrap_or(LuaType::Unknown), + ), + ) + }) + .collect(); + let ret = if signature.resolve_return == SignatureReturnStatus::DocResolve { + signature.get_return_type() + } else { + LuaType::Unknown + }; + + LuaType::DocFunction( + LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + params, + ret, + Some(signature.get_function_generic_params()), + ) + .into(), + ) +} + +fn tpl_pattern_match_ignoring_unknown_target( + context: &mut TplContext, + pattern: &LuaType, + target: &LuaType, +) -> TplPatternMatchResult { + if pattern.contain_tpl() && (target.is_any() || target.is_unknown()) { + return Ok(()); + } + + match (pattern, target) { + (LuaType::DocFunction(pattern_func), LuaType::DocFunction(target_func)) => { + for ((_, pattern_param), (_, target_param)) in pattern_func + .get_params() + .iter() + .zip(target_func.get_params().iter()) + { + let pattern_param = pattern_param.clone().unwrap_or(LuaType::Any); + let target_param = target_param.clone().unwrap_or(LuaType::Unknown); + tpl_pattern_match_ignoring_unknown_target(context, &pattern_param, &target_param)?; + } + + tpl_pattern_match_ignoring_unknown_target( + context, + pattern_func.get_ret(), + target_func.get_ret(), + ) + } + _ => tpl_pattern_match(context, pattern, target), + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 467da24e6..c21bff656 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -161,14 +161,7 @@ pub fn tpl_pattern_match( if tpl.get_tpl_id().is_func() { context .substitutor - .insert_type(tpl.get_tpl_id(), target.clone(), true); - } - } - LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target, false); + .infer_type(tpl.get_tpl_id(), target.clone(), !tpl.is_const()); } } LuaType::StrTplRef(str_tpl) => { @@ -176,7 +169,7 @@ pub fn tpl_pattern_match( let prefix = str_tpl.get_prefix(); let suffix = str_tpl.get_suffix(); let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.insert_type( + context.substitutor.infer_type( str_tpl.get_tpl_id(), get_str_tpl_infer_type(&type_name), true, @@ -220,6 +213,14 @@ pub fn constant_decay(typ: LuaType) -> LuaType { } } +fn maybe_decay_type(typ: &LuaType, decay: bool) -> LuaType { + if decay { + constant_decay(typ.clone()) + } else { + typ.clone() + } +} + fn object_tpl_pattern_match( context: &mut TplContext, origin_obj: &LuaObjectType, @@ -716,7 +717,7 @@ pub(crate) fn return_type_pattern_match_target_type( let tpl_id = type_ref.get_tpl_id(); context .substitutor - .insert_type(tpl_id, target_base.clone(), true); + .infer_type(tpl_id, target_base.clone(), true); } } VariadicType::Multi(source_multi) => { @@ -727,7 +728,7 @@ pub(crate) fn return_type_pattern_match_target_type( && let LuaType::TplRef(type_ref) = base { let tpl_id = type_ref.get_tpl_id(); - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, target_base.clone(), true, @@ -738,7 +739,7 @@ pub(crate) fn return_type_pattern_match_target_type( } LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, target_base.clone(), true, @@ -781,7 +782,7 @@ fn func_varargs_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - substitutor.insert_params( + substitutor.infer_params( tpl_id, target_rest_params .iter() @@ -802,13 +803,14 @@ pub fn variadic_tpl_pattern_match( target_rest_types: &[LuaType], ) -> TplPatternMatchResult { match tpl { - VariadicType::Base(base) => match base { - LuaType::TplRef(tpl_ref) => { + VariadicType::Base(base) => { + if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); + let decay = !tpl_ref.is_const(); match target_rest_types.len() { 0 => { // Zero varargs are an empty sequence, not one nil return slot. - context.substitutor.insert_multi_types(tpl_id, Vec::new()); + context.substitutor.infer_multi_types(tpl_id, Vec::new()); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -818,67 +820,46 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_multi_types(tpl_id, Vec::new()); + context.substitutor.infer_multi_types(tpl_id, Vec::new()); } 1 => { - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, types[0].clone(), - true, + decay, ); } _ => { - context.substitutor.insert_multi_types( + context.substitutor.infer_multi_types( tpl_id, types .iter() - .map(|t| constant_decay(t.clone())) + .map(|t| maybe_decay_type(t, decay)) .collect(), ); } }, VariadicType::Base(base) => { - context.substitutor.insert_multi_base(tpl_id, base.clone()); + context.substitutor.infer_multi_base(tpl_id, base.clone()); } }, arg => { - context.substitutor.insert_type(tpl_id, arg.clone(), true); + context.substitutor.infer_type(tpl_id, arg.clone(), decay); } } } _ => { - context.substitutor.insert_multi_types( + context.substitutor.infer_multi_types( tpl_id, target_rest_types .iter() - .map(|t| constant_decay(t.clone())) + .map(|t| maybe_decay_type(t, decay)) .collect(), ); } } } - LuaType::ConstTplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_multi_types(tpl_id, Vec::new()); - } - 1 => { - context.substitutor.insert_type( - tpl_id, - target_rest_types[0].clone(), - false, - ); - } - _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); - } - } - } - _ => {} - }, + } VariadicType::Multi(multi) => { for (i, ret_type) in multi.iter().enumerate() { match ret_type { @@ -893,7 +874,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.get(i) { Some(t) => { - context.substitutor.insert_type(tpl_id, t.clone(), true); + context.substitutor.infer_type(tpl_id, t.clone(), true); } None => { break; @@ -946,7 +927,7 @@ fn tuple_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); context .substitutor - .insert_multi_base(tpl_id, target_array_base.get_base().clone()); + .infer_multi_base(tpl_id, target_array_base.get_base().clone()); } } VariadicType::Multi(_) => {} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 10a1733ba..1891820d5 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -4,19 +4,10 @@ use std::{cell::RefCell, rc::Rc}; use super::tpl_pattern::constant_decay; use crate::{DbIndex, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum UninferredTplPolicy { - /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. - Fallback, - /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. - PreserveTplRef, -} - #[derive(Debug)] pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, pub substitutor: &'a TypeSubstitutor, - policy: UninferredTplPolicy, instantiating_signatures: Rc>>, } @@ -25,20 +16,10 @@ impl<'a> GenericInstantiateContext<'a> { Self { db, substitutor, - policy: UninferredTplPolicy::Fallback, instantiating_signatures: Rc::new(RefCell::new(HashSet::new())), } } - pub(super) fn with_policy(&self, policy: UninferredTplPolicy) -> GenericInstantiateContext<'a> { - GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - policy, - instantiating_signatures: self.instantiating_signatures.clone(), - } - } - pub fn with_substitutor<'b>( &'b self, substitutor: &'b TypeSubstitutor, @@ -46,15 +27,10 @@ impl<'a> GenericInstantiateContext<'a> { GenericInstantiateContext { db: self.db, substitutor, - policy: self.policy, instantiating_signatures: self.instantiating_signatures.clone(), } } - pub fn should_preserve_tpl_ref(&self) -> bool { - self.policy == UninferredTplPolicy::PreserveTplRef - } - pub(super) fn enter_signature( &self, signature_id: LuaSignatureId, @@ -169,6 +145,17 @@ impl TypeSubstitutor { self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); } + pub fn infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), + ); + } + pub(super) fn replace_type( &mut self, tpl_id: GenericTplId, @@ -214,6 +201,12 @@ impl TypeSubstitutor { true } + fn can_infer_type(&self, tpl_id: GenericTplId) -> bool { + self.tpl_replace_map + .get(&tpl_id) + .is_some_and(SubstitutorValue::is_none) + } + pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { if tpl_id.is_conditional_infer() { return; @@ -232,6 +225,20 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::Params(params)); } + pub fn infer_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + let params = params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(); + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::Params(params)); + } + pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { if tpl_id.is_conditional_infer() { return; @@ -245,6 +252,15 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::MultiTypes(types)); } + pub fn infer_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::MultiTypes(types)); + } + pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { if tpl_id.is_conditional_infer() { return; @@ -258,10 +274,34 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); } + pub fn infer_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); + } + pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { self.tpl_replace_map.get(&tpl_id) } + pub(super) fn without_pending_tpls(&self, tpl_ids: &HashSet) -> Self { + let mut substitutor = self.clone(); + for tpl_id in tpl_ids { + if substitutor + .tpl_replace_map + .get(tpl_id) + .is_some_and(SubstitutorValue::is_none) + { + substitutor.tpl_replace_map.remove(tpl_id); + } + } + + substitutor + } + pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { match self.tpl_replace_map.get(&tpl_id) { Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..735bf5d4b 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -1,30 +1,29 @@ use std::sync::Arc; -use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind}; -use hashbrown::HashSet; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaSyntaxKind}; use rowan::TextRange; use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; -use crate::semantic::overload_resolve::callable_accepts_args; use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, - LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, - LuaType, LuaTypeDeclId, LuaUnionType, TypeVisitTrait, VariadicType, + LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignature, + LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType, SemanticDeclLevel, TypeVisitTrait, + VariadicType, }; use crate::{ InferGuardRef, semantic::{ - generic::{ - TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, - instantiate_doc_function, - }, + generic::TypeSubstitutor, infer::narrow::get_type_at_call_expr_inline_cast, + infer_node_semantic_decl, + member::find_member_origin_owner, + overload_resolve::{collect_callable_overload_groups, match_callable_by_arg_types}, }, }; -use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; +use crate::{build_self_type, infer_call_generic, infer_self_type, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -33,6 +32,7 @@ mod infer_setmetatable; pub type InferCallFuncResult = Result, InferFailReason>; +// TODO: 如果没有完全匹配的签名也会返回一个不精确的类型, 考虑返回`None` pub fn infer_call_expr_func( db: &DbIndex, cache: &mut LuaInferCache, @@ -99,7 +99,7 @@ pub fn infer_call_expr_func( ), LuaType::Instance(inst) => infer_instance_type_doc_function(db, inst), LuaType::TableConst(meta_table) => infer_table_type_doc_function(db, meta_table.clone()), - LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) => infer_tpl_ref_call( + LuaType::TplRef(_) | LuaType::StrTplRef(_) => infer_tpl_ref_call( db, cache, call_expr.clone(), @@ -113,6 +113,7 @@ pub fn infer_call_expr_func( true, vec![("...".to_string(), Some(LuaType::Unknown))], LuaType::Variadic(VariadicType::Base(LuaType::Unknown).into()), + None, ))), LuaType::Intersection(intersection) => infer_intersection( db, @@ -128,6 +129,7 @@ pub fn infer_call_expr_func( true, vec![], LuaType::Any, + None, ))), LuaType::Union(union) => infer_union(db, cache, union, call_expr.clone(), args_count), _ => Err(InferFailReason::None), @@ -136,7 +138,7 @@ pub fn infer_call_expr_func( let result = if let Ok(func_ty) = result { let func_ty = match func_ty.get_ret() { LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + match infer_call_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { Ok(func_ty) => Arc::new(func_ty), Err(_) => func_ty, } @@ -154,6 +156,7 @@ pub fn infer_call_expr_func( func_ty.is_variadic(), func_ty.get_params().to_vec(), new_ret, + Some(func_ty.get_generic_params().to_vec()), ) .into() }), @@ -207,9 +210,12 @@ fn infer_tpl_ref_call( infer_guard: &InferGuardRef, args_count: Option, ) -> InferCallFuncResult { - let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?; - let extend_type = get_tpl_ref_extend_type(db, cache, call_expr_type, prefix_expr, 0) - .ok_or(InferFailReason::None)?; + let extend_type = match call_expr_type { + LuaType::TplRef(tpl) => tpl.get_constraint().cloned(), + LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint().cloned(), + _ => None, + } + .ok_or(InferFailReason::None)?; if &extend_type == call_expr_type { return Err(InferFailReason::None); } @@ -223,20 +229,19 @@ fn infer_doc_function( call_expr: LuaCallExpr, ) -> InferCallFuncResult { if func.contain_tpl() { - let result = instantiate_func_generic(db, cache, func, call_expr)?; + let result = infer_call_generic(db, cache, func, call_expr)?; return Ok(Arc::new(result)); } Ok(func.clone().into()) } -fn filter_callable_overloads_by_call_args( +fn filter_callable_overloads_by_args( db: &DbIndex, cache: &mut LuaInferCache, overloads: Vec>, call_expr: &LuaCallExpr, args_count: Option, - strict_arg_filter: bool, ) -> Result>, InferFailReason> { let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; let expr_types = super::infer_expr_list_types( @@ -249,39 +254,11 @@ fn filter_callable_overloads_by_call_args( .into_iter() .map(|(ty, _)| ty) .collect::>(); - let is_colon_call = call_expr.is_colon_call(); Ok(overloads .into_iter() - .filter(|func| { - let mut callable_tpls = HashSet::new(); - func.visit_type(&mut |ty| match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - LuaType::StrTplRef(str_tpl) => { - callable_tpls.insert(str_tpl.get_tpl_id()); - } - _ => {} - }); - - if callable_tpls.is_empty() && !strict_arg_filter { - return true; - } - - let has_tpls = !callable_tpls.is_empty(); - let mut substitutor = TypeSubstitutor::new(); - substitutor.add_need_infer_tpls(callable_tpls); - let match_func = if has_tpls { - match instantiate_doc_function(db, func, &substitutor) { - LuaType::DocFunction(doc_func) => doc_func, - _ => func.clone(), - } - } else { - func.clone() - }; - - callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count) + .filter_map(|func| { + match_callable_by_arg_types(db, cache, func, &expr_types, call_expr, args_count, true) }) .collect()) } @@ -353,22 +330,21 @@ fn infer_type_doc_function( let has_generic_tpl = { let mut has_generic_tpl = false; f.visit_type(&mut |t| { - has_generic_tpl |= matches!( - t, - LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) - ); + has_generic_tpl |= matches!(t, LuaType::TplRef(_) | LuaType::StrTplRef(_)); }); has_generic_tpl }; if has_generic_tpl { - let result = instantiate_func_generic(db, cache, &f, call_expr.clone())?; + let result = infer_call_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { let mut substitutor = TypeSubstitutor::new(); let self_type = build_self_type(db, call_expr_type); substitutor.add_self_type(self_type); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, &f, &substitutor) + let func_type = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = + instantiate_type_generic(db, &func_type, &substitutor) { overloads.push(f); } @@ -541,13 +517,12 @@ fn infer_union( let mut overload_groups = Vec::new(); collect_callable_overload_groups(db, &ty, &mut overload_groups)?; for overloads in overload_groups { - let compatible_overloads = filter_callable_overloads_by_call_args( + let compatible_overloads = filter_callable_overloads_by_args( db, cache, overloads.clone(), &call_expr, args_count, - true, )?; if compatible_overloads.is_empty() { fallback_overloads.extend(overloads); @@ -585,14 +560,6 @@ fn infer_union( let Some(first_func) = first_func else { if !fallback_overloads.is_empty() { let contains_tpl = fallback_overloads.iter().any(|func| func.contain_tpl()); - let fallback_overloads = filter_callable_overloads_by_call_args( - db, - cache, - fallback_overloads, - &call_expr, - args_count, - false, - )?; return resolve_signature( db, cache, @@ -612,6 +579,7 @@ fn infer_union( first_func.is_variadic(), first_func.get_params().to_vec(), LuaType::from_vec(returns), + Some(first_func.get_generic_params().to_vec()), ))) } @@ -838,6 +806,67 @@ fn signature_is_generic( } } +/// 推断调用者的具体类型. +pub fn infer_call_receiver_type( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: &LuaCallExpr, +) -> Option { + match call_expr.get_prefix_expr()? { + LuaExpr::IndexExpr(index_expr) => { + let decl = infer_node_semantic_decl( + db, + cache, + index_expr.syntax().clone(), + SemanticDeclLevel::default(), + )?; + + if let LuaSemanticDeclId::Member(member_id) = decl + && let Some(LuaSemanticDeclId::Member(member_id)) = + find_member_origin_owner(db, cache, member_id) + { + let root = db + .get_vfs() + .get_syntax_tree(&member_id.file_id)? + .get_red_root(); + let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; + let index_expr = LuaIndexExpr::cast(cur_node)?; + + return index_expr.get_prefix_expr().map(|prefix_expr| { + infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer) + }); + } + + index_expr + .get_prefix_expr() + .map(|prefix_expr| infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer)) + } + LuaExpr::NameExpr(name_expr) => { + let decl = infer_node_semantic_decl( + db, + cache, + name_expr.syntax().clone(), + SemanticDeclLevel::default(), + )?; + if let LuaSemanticDeclId::Member(member_id) = decl { + let root = db + .get_vfs() + .get_syntax_tree(&member_id.file_id)? + .get_red_root(); + let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; + let index_expr = LuaIndexExpr::cast(cur_node)?; + + return index_expr.get_prefix_expr().map(|prefix_expr| { + infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer) + }); + } + + None + } + _ => None, + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index 288617c82..3e1d07aae 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -1,19 +1,19 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAstNode, LuaDocAttributeType, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, - LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, - LuaDocStrTplType, LuaDocType, LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, - LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, NumberResult, + LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericType, + LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, + LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaTypeUnaryOperator, NumberResult, }; use rowan::TextRange; use smol_str::SmolStr; use crate::{ AsyncState, DbIndex, FileId, InFiled, LuaAliasCallKind, LuaAliasCallType, LuaArrayLen, - LuaArrayType, LuaAttributeType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, - LuaIntersectionType, LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, - LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, + LuaArrayType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, LuaIntersectionType, + LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, LuaTupleType, LuaType, + LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, }; #[derive(Clone, Copy)] @@ -115,9 +115,6 @@ pub fn infer_doc_type(ctx: DocTypeInferContext<'_>, node: &LuaDocType) -> LuaTyp LuaDocType::MultiLineUnion(multi_union) => { return infer_multi_line_union_type(ctx, multi_union); } - LuaDocType::Attribute(attribute_type) => { - return infer_attribute_type(ctx, attribute_type); - } _ => {} } LuaType::Unknown @@ -497,6 +494,7 @@ fn infer_func_type(ctx: DocTypeInferContext<'_>, func: &LuaDocFuncType) -> LuaTy is_variadic, params_result, return_type, + None, ) .into(), ) @@ -612,35 +610,3 @@ fn infer_multi_line_union_type( LuaType::MultiLineUnion(LuaMultiLineUnion::new(union_members).into()) } - -fn infer_attribute_type( - ctx: DocTypeInferContext<'_>, - attribute_type: &LuaDocAttributeType, -) -> LuaType { - let mut params_result = Vec::new(); - for param in attribute_type.get_params() { - let name = if let Some(param) = param.get_name_token() { - param.get_name_text().to_string() - } else if param.is_dots() { - "...".to_string() - } else { - continue; - }; - - let nullable = param.is_nullable(); - - let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_doc_type(ctx, &type_ref); - if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(ctx.db, &typ, &LuaType::Nil); - } - Some(typ) - } else { - None - }; - - params_result.push((name, type_ref)); - } - - LuaType::DocAttribute(LuaAttributeType::new(params_result).into()) -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs index 85fd20a77..aadd064da 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs @@ -68,11 +68,7 @@ pub(super) fn infer_array_member_by_key( fn key_type_matches(db: &DbIndex, expected: &LuaType, actual: &LuaType) -> bool { !matches!( actual, - LuaType::Any - | LuaType::Unknown - | LuaType::TplRef(_) - | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) + LuaType::Any | LuaType::Unknown | LuaType::TplRef(_) | LuaType::StrTplRef(_) ) && check_type_compact(db, expected, actual).is_ok() } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 1f3ad09ce..b972d2f7c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -15,7 +15,7 @@ use crate::{ DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, }, - enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type, + enum_variable_is_param, get_keyof_members, semantic::{ InferGuard, generic::{TypeSubstitutor, instantiate_type_generic}, @@ -1183,18 +1183,9 @@ fn infer_tpl_ref_member( lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { - let extend_type = get_tpl_ref_extend_type( - db, - cache, - &LuaType::TplRef(generic.clone().into()), - lookup - .index_expr - .get_index_expr() - .ok_or(InferFailReason::None)? - .get_prefix_expr() - .ok_or(InferFailReason::None)?, - 0, - ) - .ok_or(InferFailReason::None)?; + let extend_type = generic + .get_constraint() + .cloned() + .ok_or(InferFailReason::None)?; infer_member_by_lookup(db, cache, &extend_type, lookup, infer_guard) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index eed7743ff..b20467513 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -17,7 +17,7 @@ use emmylua_parser::{ }; use infer_binary::infer_binary_expr; use infer_call::infer_call_expr; -pub use infer_call::infer_call_expr_func; +pub use infer_call::{infer_call_expr_func, infer_call_receiver_type}; pub use infer_doc_type::{DocTypeInferContext, infer_doc_type}; pub use infer_fail_reason::InferFailReason; pub use infer_index::infer_index_expr; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..47dd87510 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -16,7 +16,7 @@ use crate::{ var_ref_id::get_var_expr_var_ref_id, }, }, - semantic::instantiate_func_generic, + semantic::infer_call_generic, }; pub fn get_type_at_call_expr( @@ -225,9 +225,9 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { - let Ok(inst_func) = cache.with_no_flow(|cache| { - instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) - }) else { + let Ok(inst_func) = cache + .with_no_flow(|cache| infer_call_generic(db, cache, func_type.as_ref(), call_expr)) + else { return Ok(None); }; return_type = inst_func.get_ret().clone(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index a2d06dd5f..324498e47 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaSignature, LuaType, TypeOps, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, - instantiate_func_generic, + infer_call_generic, }, }; @@ -575,10 +575,9 @@ fn instantiate_return_rows( signature.is_vararg, signature.get_type_params(), return_type.clone(), + Some(signature.get_function_generic_params()), ); - match cache - .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) - { + match cache.with_no_flow(|cache| infer_call_generic(db, cache, &func, call_expr.clone())) { Ok(instantiated) => instantiated.get_ret().clone(), Err(_) => return_type, } diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 606ba27b5..e38709fab 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -4,8 +4,8 @@ use smol_str::SmolStr; use crate::{ DbIndex, FileId, InferGuardRef, LuaGenericType, LuaInstanceType, LuaIntersectionType, - LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaSemanticDeclId, LuaTupleType, LuaType, - LuaTypeDeclId, LuaUnionType, + LuaMemberFeature, LuaMemberIndexItem, LuaMemberKey, LuaMemberOwner, LuaObjectType, + LuaSemanticDeclId, LuaTupleType, LuaType, LuaTypeDeclId, LuaTypeOwner, LuaUnionType, semantic::{ InferGuard, generic::{TypeSubstitutor, instantiate_type_generic}, @@ -233,6 +233,17 @@ fn find_normal_members( member_owner: LuaMemberOwner, filter: &FindMemberFilter, ) -> FindMembersResult { + if let FindMemberFilter::ByKey { + member_key, + find_all, + } = filter + { + let member_item = db + .get_member_index() + .get_member_item(&member_owner, member_key)?; + return collect_member_infos_from_item(db, ctx, member_item, *find_all); + } + let mut members = Vec::new(); let member_index = db.get_member_index(); let owner_members = member_index.get_members(&member_owner)?; @@ -241,18 +252,14 @@ fn find_normal_members( let member_key = member.get_key().clone(); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::Member(member.get_id())), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: Some(member.get_feature()), - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member_key, + Some(member.get_feature()), + )); if should_stop_collecting(members.len(), filter) { break; @@ -263,6 +270,55 @@ fn find_normal_members( Some(members) } +fn collect_member_infos_from_item( + db: &DbIndex, + ctx: &FindMembersContext, + member_item: &LuaMemberIndexItem, + find_all: bool, +) -> FindMembersResult { + let mut members = Vec::new(); + for member_id in member_item.get_member_ids() { + let member = db.get_member_index().get_member(&member_id)?; + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member.get_key().clone(), + Some(member.get_feature()), + )); + + if !find_all { + break; + } + } + + Some(members) +} + +fn semantic_decl_to_member_info( + db: &DbIndex, + ctx: &FindMembersContext, + type_owner: LuaTypeOwner, + property_owner_id: LuaSemanticDeclId, + key: LuaMemberKey, + feature: Option, +) -> LuaMemberInfo { + let raw_type = db + .get_type_index() + .get_type_cache(&type_owner) + .map(|t| t.as_type().clone()) + .unwrap_or(LuaType::Unknown); + + LuaMemberInfo { + property_owner_id: Some(property_owner_id), + key, + typ: ctx.instantiate_type(db, &raw_type), + feature, + overload_index: None, + } +} + fn find_custom_type_members( db: &DbIndex, ctx: &FindMembersContext, @@ -282,25 +338,37 @@ fn find_custom_type_members( let mut members = Vec::new(); let member_index = db.get_member_index(); - if let Some(type_members) = - member_index.get_members(&LuaMemberOwner::Type(type_decl_id.clone())) + let type_owner = LuaMemberOwner::Type(type_decl_id.clone()); + if let FindMemberFilter::ByKey { + member_key, + find_all, + } = filter { + if let Some(member_item) = member_index.get_member_item(&type_owner, member_key) { + members.extend(collect_member_infos_from_item( + db, + ctx, + member_item, + *find_all, + )?); + + if should_stop_collecting(members.len(), filter) { + return Some(members); + } + } + } else if let Some(type_members) = member_index.get_members(&type_owner) { for member in type_members { let member_key = member.get_key().clone(); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::Member(member.get_id())), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: Some(member.get_feature()), - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member_key, + Some(member.get_feature()), + )); if should_stop_collecting(members.len(), filter) { return Some(members); @@ -522,18 +590,14 @@ fn find_global_members( let member_key = LuaMemberKey::Name(decl.get_name().to_string().into()); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&decl_id.into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::LuaDecl(decl_id)), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: None, - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Decl(decl_id), + LuaSemanticDeclId::LuaDecl(decl_id), + member_key, + None, + )); if should_stop_collecting(members.len(), filter) { break; diff --git a/crates/emmylua_code_analysis/src/semantic/member/mod.rs b/crates/emmylua_code_analysis/src/semantic/member/mod.rs index 5025ea4ed..d01e21830 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/mod.rs @@ -6,7 +6,7 @@ mod infer_raw_member; use std::collections::HashSet; use crate::{ - DbIndex, LuaMemberFeature, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, TypeOps, + DbIndex, LuaMemberFeature, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaUnionType, TypeOps, db_index::{LuaType, LuaTypeDeclId}, }; use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaTableExpr, LuaTableField}; @@ -61,7 +61,7 @@ pub fn find_member_origin_owner( const MAX_ITERATIONS: usize = 50; let mut visited_members = HashSet::new(); - let mut current_owner = resolve_member_owner(db, infer_config, &member_id); + let mut current_owner = resolve_member_owner_with_file_cache(db, infer_config, &member_id); let mut final_owner = current_owner.clone(); let mut iteration_count = 0; @@ -73,7 +73,7 @@ pub fn find_member_origin_owner( visited_members.insert(*current_member_id); iteration_count += 1; - match resolve_member_owner(db, infer_config, current_member_id) { + match resolve_member_owner_with_file_cache(db, infer_config, current_member_id) { Some(next_owner) => { final_owner = Some(next_owner.clone()); current_owner = Some(next_owner); @@ -85,6 +85,19 @@ pub fn find_member_origin_owner( final_owner } +fn resolve_member_owner_with_file_cache( + db: &DbIndex, + infer_config: &mut LuaInferCache, + member_id: &LuaMemberId, +) -> Option { + if infer_config.get_file_id() == member_id.file_id { + return resolve_member_owner(db, infer_config, member_id); + } + + let mut member_file_cache = infer_config.fork_for_file(member_id.file_id); + resolve_member_owner(db, &mut member_file_cache, member_id) +} + fn resolve_member_owner( db: &DbIndex, infer_config: &mut LuaInferCache, @@ -144,7 +157,7 @@ fn resolve_table_field_through_type_inference( let table_expr = LuaTableExpr::cast(parent)?; let table_type = infer_table_should_be(db, infer_config, table_expr).ok()?; - if !matches!(table_type, LuaType::Ref(_) | LuaType::Def(_)) { + if !table_is_class(&table_type, 0) { return None; } @@ -157,3 +170,19 @@ fn resolve_table_field_through_type_inference( .cloned() .and_then(|m| m.property_owner_id) } + +fn table_is_class(table_type: &LuaType, depth: usize) -> bool { + if depth > 10 { + return false; + } + + match table_type { + LuaType::Ref(_) | LuaType::Def(_) | LuaType::Generic(_) => true, + LuaType::Union(union) => match union.as_ref() { + LuaUnionType::Basic(_) => false, + LuaUnionType::Nullable(typ) => table_is_class(typ, depth + 1), + LuaUnionType::Multi(types) => types.iter().any(|typ| table_is_class(typ, depth + 1)), + }, + _ => false, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index b9cd692f9..e15272981 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -21,7 +21,7 @@ use emmylua_parser::{ LuaSyntaxToken, LuaTableExpr, }; pub use infer::infer_index_expr; -use infer::{infer_bind_value_type, infer_expr_list_types}; +use infer::{infer_bind_value_type, infer_call_receiver_type, infer_expr_list_types}; pub use infer::{infer_table_field_value_should_be, infer_table_should_be}; use lsp_types::Uri; pub use member::LuaMemberInfo; @@ -39,7 +39,7 @@ use semantic_info::{ infer_node_semantic_info, infer_token_semantic_decl, infer_token_semantic_info, }; pub(crate) use type_check::check_type_compact; -use type_check::is_sub_type_of; +pub(crate) use type_check::is_sub_type_of; pub use visibility::check_module_visibility; use visibility::check_visibility; @@ -59,6 +59,12 @@ pub use infer::infer_param; pub(crate) use infer::try_infer_expr_for_index; pub(crate) use infer::{infer_expr, try_infer_expr_no_flow}; use overload_resolve::resolve_signature; +pub(crate) use overload_resolve::{ + callable_accepts_args, get_func_param_type, is_func_last_param_variadic, +}; +pub use overload_resolve::{ + collect_callable_overload_groups, filter_callable_overloads, find_callable_overload, +}; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; @@ -182,6 +188,16 @@ impl<'a> SemanticModel<'a> { .ok() } + pub fn callable_accepts_args( + &self, + func: &LuaFunctionType, + expr_types: &[LuaType], + is_colon_call: bool, + arg_count: Option, + ) -> bool { + callable_accepts_args(self.db, func, expr_types, is_colon_call, arg_count) + } + /// 推断表达式列表类型, 位于最后的表达式会触发多值推断 pub fn infer_expr_list_types( &self, @@ -318,6 +334,10 @@ impl<'a> SemanticModel<'a> { find_member_origin_owner(self.db, &mut self.infer_cache.borrow_mut(), member_id) } + pub fn infer_call_receiver_type(&self, call_expr: &LuaCallExpr) -> Option { + infer_call_receiver_type(self.db, &mut self.infer_cache.borrow_mut(), call_expr) + } + pub fn get_index_decl_type(&self, index_expr: LuaIndexExpr) -> Option { let cache = &mut self.infer_cache.borrow_mut(); infer_index_expr(self.db, cache, index_expr, false).ok() diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs new file mode 100644 index 000000000..7ebd97b28 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs @@ -0,0 +1,167 @@ +use hashbrown::HashSet; +use std::sync::Arc; + +use crate::{ + DbIndex, LuaOperatorMetaMethod, LuaOperatorOwner, LuaTypeDeclId, + db_index::{LuaFunctionType, LuaType}, + semantic::{ + generic::{TypeSubstitutor, instantiate_type_generic}, + infer::InferFailReason, + }, +}; + +pub fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + // alias 的可调用性来自 origin, 非 alias 类型再补充自身的 __call 候选 + if !type_decl.is_alias() && !type_decl.is_enum() { + push_call_operator_overload_group(db, &type_id.clone().into(), groups, None); + } + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + // 泛型类型的 __call 需要先替换类型模板, 否则候选会保留未实例化的 T + if !type_decl.is_alias() && !type_decl.is_enum() { + push_call_operator_overload_group( + db, + &type_id.clone().into(), + groups, + Some(&substitutor), + ); + } + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + LuaType::Instance(instance) => { + // instance 的可调用性由它的 base 决定. + collect_callable_overload_groups_inner( + db, + instance.get_base(), + groups, + visiting_aliases, + )?; + } + LuaType::TableConst(table) => { + // setmetatable 产生的 __call 挂在 metatable owner 上. + if let Some(meta_table) = db.get_metatable_index().get(table) { + push_call_operator_overload_group( + db, + &LuaOperatorOwner::Table(meta_table.clone()), + groups, + None, + ); + } + } + _ => {} + } + + Ok(()) +} + +fn push_call_operator_overload_group( + db: &DbIndex, + owner: &LuaOperatorOwner, + groups: &mut Vec>>, + substitutor: Option<&TypeSubstitutor>, +) { + let Some(operator_ids) = db + .get_operator_index() + .get_operators(owner, LuaOperatorMetaMethod::Call) + else { + return; + }; + + // 同一个 owner 的 call operators 作为一个 overload group, 由调用方再做参数匹配. + let mut overloads = Vec::new(); + for operator_id in operator_ids { + let Some(operator) = db.get_operator_index().get_operator(operator_id) else { + continue; + }; + + let mut func_type = operator.get_operator_func(db); + if let Some(substitutor) = substitutor { + func_type = instantiate_type_generic(db, &func_type, substitutor); + } + + match func_type { + LuaType::DocFunction(func) => overloads.push(func), + LuaType::Signature(signature_id) => { + let Some(signature) = db.get_signature_index().get(&signature_id) else { + continue; + }; + // 未解析返回的 signature 不能安全转换成候选, 这里先跳过. + if signature.is_resolve_return() { + overloads.push(signature.to_call_operator_func_type()); + } + } + _ => {} + } + } + + if !overloads.is_empty() { + groups.push(overloads); + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs new file mode 100644 index 000000000..7563d8106 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use emmylua_parser::LuaCallExpr; + +use crate::{ + DbIndex, LuaFunctionType, LuaType, + semantic::{LuaInferCache, generic::infer_call_generic, infer::InferFailReason}, +}; + +use super::{ + collect_overloads::collect_callable_overload_groups, + resolve_signature_by_args::callable_accepts_args, +}; + +pub fn filter_callable_overloads( + db: &DbIndex, + cache: &mut LuaInferCache, + callable_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + + Ok(overload_groups + .into_iter() + .flatten() + .filter_map(|func| { + match_callable_by_arg_types( + db, + cache, + func, + call_arg_types, + call_expr, + args_count, + return_instantiated_generic, + ) + }) + .collect()) +} + +pub fn find_callable_overload( + db: &DbIndex, + cache: &mut LuaInferCache, + callable_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + + Ok(overload_groups.into_iter().flatten().find_map(|func| { + match_callable_by_arg_types( + db, + cache, + func, + call_arg_types, + call_expr, + args_count, + return_instantiated_generic, + ) + })) +} + +pub(crate) fn match_callable_by_arg_types( + db: &DbIndex, + cache: &mut LuaInferCache, + func: Arc, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Option> { + let has_tpls = func.contain_tpl(); + let match_func = if has_tpls { + infer_call_generic(db, cache, func.as_ref(), call_expr.clone()) + .map(Arc::new) + .unwrap_or_else(|_| func.clone()) + } else { + func.clone() + }; + + if !callable_accepts_args( + db, + &match_func, + call_arg_types, + call_expr.is_colon_call(), + args_count, + ) { + return None; + } + + if has_tpls && return_instantiated_generic { + Some(match_func) + } else { + Some(func) + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index a6447a91c..bd7b0cf12 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,3 +1,5 @@ +mod collect_overloads; +mod filter_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -8,11 +10,17 @@ use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, - generic::instantiate_func_generic, + generic::infer_call_generic, infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; -pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; +pub use collect_overloads::collect_callable_overload_groups; +pub(crate) use filter_overloads::match_callable_by_arg_types; +pub use filter_overloads::{filter_callable_overloads, find_callable_overload}; +pub(crate) use resolve_signature_by_args::{ + callable_accepts_args, get_func_param_type, is_func_last_param_variadic, + resolve_signature_by_args, +}; pub fn resolve_signature( db: &DbIndex, @@ -78,7 +86,7 @@ fn resolve_signature_by_generic( ) -> InferCallFuncResult { let mut instantiate_funcs = Vec::new(); for func in overloads { - let instantiate_func = instantiate_func_generic(db, cache, &func, call_expr.clone())?; + let instantiate_func = infer_call_generic(db, cache, &func, call_expr.clone())?; instantiate_funcs.push(Arc::new(instantiate_func)); } resolve_signature_by_args( diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 7e2217f27..28537e3cc 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -22,7 +22,7 @@ pub(crate) fn callable_accepts_args( let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { continue; }; - let Some(param_type) = get_call_arg_param_type(func, param_index) else { + let Some(param_type) = get_func_param_type(func, param_index) else { return false; }; @@ -86,7 +86,7 @@ pub fn resolve_signature_by_args( let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { continue; }; - let Some(param_type) = get_call_arg_param_type(func, param_index) else { + let Some(param_type) = get_func_param_type(func, param_index) else { *opt_func = None; continue; }; @@ -236,7 +236,7 @@ pub fn resolve_signature_by_args( } } -fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { +pub(crate) fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { if let Some(last_param) = func.get_params().last() { last_param.0 == "..." } else { @@ -244,7 +244,7 @@ fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { } } -fn get_call_param_index( +pub(crate) fn get_call_param_index( func: &LuaFunctionType, arg_index: usize, is_colon_call: bool, @@ -265,7 +265,7 @@ fn get_call_param_index( Some(param_index) } -fn get_call_arg_param_type(func: &LuaFunctionType, param_index: usize) -> Option { +pub(crate) fn get_func_param_type(func: &LuaFunctionType, param_index: usize) -> Option { if let Some(param_info) = func.get_params().get(param_index) { return Some(param_info.1.clone().unwrap_or(LuaType::Any)); } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs index 215d71842..35e737587 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs @@ -124,7 +124,6 @@ fn check_general_type_compact( | LuaType::DocBooleanConst(_) | LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::Namespace(_) | LuaType::Variadic(_) | LuaType::Language(_) => { @@ -195,11 +194,7 @@ fn check_general_type_compact( fn is_like_any(ty: &LuaType) -> bool { matches!( ty, - LuaType::Any - | LuaType::Unknown - | LuaType::TplRef(_) - | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) + LuaType::Any | LuaType::Unknown | LuaType::TplRef(_) | LuaType::StrTplRef(_) ) } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 6f994cb0c..123279d46 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -263,7 +263,7 @@ pub fn check_simple_type_compact( return Ok(()); } } - LuaType::TplRef(_) | LuaType::ConstTplRef(_) => return Ok(()), + LuaType::TplRef(_) => return Ok(()), LuaType::Namespace(source_namespace) => { if let LuaType::Namespace(compact_namespace) = compact_type && source_namespace == compact_namespace diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs index 9fe6344d0..411dfcb6d 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use hashbrown::HashSet; -use crate::{DbIndex, LuaType, LuaTypeDeclId}; +use crate::{DbIndex, LuaType, LuaTypeDeclId, LuaTypeIdentifier}; /// 检查子类型关系. /// @@ -10,16 +10,16 @@ pub fn is_sub_type_of( sub_type_ref_id: &LuaTypeDeclId, super_type_ref_id: &LuaTypeDeclId, ) -> bool { - check_sub_type_of_iterative(db, sub_type_ref_id, super_type_ref_id).unwrap_or(false) + check_sub_type_of_iterative(db, sub_type_ref_id, super_type_ref_id) } fn check_sub_type_of_iterative( db: &DbIndex, sub_type_ref_id: &LuaTypeDeclId, super_type_ref_id: &LuaTypeDeclId, -) -> Option { +) -> bool { if sub_type_ref_id == super_type_ref_id { - return Some(true); + return true; } let type_index = db.get_type_index(); @@ -27,11 +27,8 @@ fn check_sub_type_of_iterative( let mut visited = HashSet::with_capacity(4); stack.push(sub_type_ref_id); + visited.insert(sub_type_ref_id); while let Some(current_id) = stack.pop() { - if !visited.insert(current_id) { - continue; - } - let supers_iter = match type_index.get_super_types_iter(current_id) { Some(iter) => iter, None => continue, @@ -42,9 +39,9 @@ fn check_sub_type_of_iterative( LuaType::Ref(super_id) => { // TODO: 不相等时可以判断必要字段是否全部匹配, 如果匹配则认为相等 if super_id == super_type_ref_id { - return Some(true); + return true; } - if !visited.contains(super_id) { + if visited.insert(super_id) { stack.push(super_id); } } @@ -52,56 +49,63 @@ fn check_sub_type_of_iterative( LuaType::Generic(generic) => { let base_type_id = generic.get_base_type_id_ref(); if base_type_id == super_type_ref_id { - return Some(true); + return true; } - if !visited.contains(&base_type_id) { + if visited.insert(base_type_id) { stack.push(base_type_id); } } _ => { - if let Some(base_id) = get_base_type_id(super_type) - && base_id == *super_type_ref_id - { - return Some(true); + if is_base_type_id(super_type, super_type_ref_id) { + return true; } } } } } - Some(false) + false } pub fn get_base_type_id(typ: &LuaType) -> Option { + base_type_name(typ).map(LuaTypeDeclId::global) +} + +fn is_base_type_id(typ: &LuaType, type_id: &LuaTypeDeclId) -> bool { + let LuaTypeIdentifier::Global(type_name) = type_id.get_id() else { + return false; + }; + let type_name: &str = type_name.as_ref(); + + base_type_name(typ).is_some_and(|base_name| base_name == type_name) +} + +fn base_type_name(typ: &LuaType) -> Option<&'static str> { match typ { LuaType::Integer | LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) => { - Some(LuaTypeDeclId::global("integer")) + Some("integer") } - LuaType::Number | LuaType::FloatConst(_) => Some(LuaTypeDeclId::global("number")), + LuaType::Number | LuaType::FloatConst(_) => Some("number"), LuaType::Boolean | LuaType::BooleanConst(_) | LuaType::DocBooleanConst(_) => { - Some(LuaTypeDeclId::global("boolean")) - } - LuaType::String | LuaType::StringConst(_) | LuaType::DocStringConst(_) => { - Some(LuaTypeDeclId::global("string")) + Some("boolean") } + LuaType::String | LuaType::StringConst(_) | LuaType::DocStringConst(_) => Some("string"), LuaType::Table | LuaType::TableGeneric(_) | LuaType::TableConst(_) | LuaType::Tuple(_) | LuaType::Array(_) - | LuaType::Object(_) => Some(LuaTypeDeclId::global("table")), + | LuaType::Object(_) => Some("table"), LuaType::Intersection(intersection) => { - intersection.get_types().iter().find_map(get_base_type_id) - } - LuaType::DocFunction(_) | LuaType::Function | LuaType::Signature(_) => { - Some(LuaTypeDeclId::global("function")) + intersection.get_types().iter().find_map(base_type_name) } - LuaType::Thread => Some(LuaTypeDeclId::global("thread")), - LuaType::Userdata => Some(LuaTypeDeclId::global("userdata")), - LuaType::Io => Some(LuaTypeDeclId::global("io")), - LuaType::Global => Some(LuaTypeDeclId::global("global")), - LuaType::SelfInfer => Some(LuaTypeDeclId::global("self")), - LuaType::Nil => Some(LuaTypeDeclId::global("nil")), + LuaType::DocFunction(_) | LuaType::Function | LuaType::Signature(_) => Some("function"), + LuaType::Thread => Some("thread"), + LuaType::Userdata => Some("userdata"), + LuaType::Io => Some("io"), + LuaType::Global => Some("global"), + LuaType::SelfInfer => Some("self"), + LuaType::Nil => Some("nil"), _ => None, } } diff --git a/crates/emmylua_doc_cli/src/json_generator/export.rs b/crates/emmylua_doc_cli/src/json_generator/export.rs index 8e4284d5e..5ac500acc 100644 --- a/crates/emmylua_doc_cli/src/json_generator/export.rs +++ b/crates/emmylua_doc_cli/src/json_generator/export.rs @@ -194,7 +194,7 @@ fn export_generics(db: &DbIndex, type_decl_id: &LuaTypeDeclId) -> Vec { .map(|it| TypeVar { name: it.name.to_string(), base: it - .type_constraint + .constraint .as_ref() .map(|typ| render_typ(db, typ, RenderLevel::Simple)), }) diff --git a/crates/emmylua_ls/locales/tags/en.yaml b/crates/emmylua_ls/locales/tags/en.yaml index 08fc4f23c..ad4ced881 100644 --- a/crates/emmylua_ls/locales/tags/en.yaml +++ b/crates/emmylua_ls/locales/tags/en.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` tag defines an attribute. Attribute is used to attach extra information to a definition. - Example: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] # `b` field is marked as deprecated - ---@field b string - ---@[deprecated] # If `attribute` allows no parameters, the parentheses can be omitted - ---@field c string - ``` diff --git a/crates/emmylua_ls/locales/tags/zh_CN.yaml b/crates/emmylua_ls/locales/tags/zh_CN.yaml index d786206a8..919809404 100644 --- a/crates/emmylua_ls/locales/tags/zh_CN.yaml +++ b/crates/emmylua_ls/locales/tags/zh_CN.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` 标签定义一个特性。特性用于附加额外信息到定义。 - 示例: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] - ---@field b string # `b` 字段被标记为已弃用 - ---@[deprecated] # 如果`attribute`允许无参数,则可以省略括号 - ---@field c string - ``` diff --git a/crates/emmylua_ls/locales/tags/zh_HK.yaml b/crates/emmylua_ls/locales/tags/zh_HK.yaml index 9e6118c45..91d61f19c 100644 --- a/crates/emmylua_ls/locales/tags/zh_HK.yaml +++ b/crates/emmylua_ls/locales/tags/zh_HK.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` 標籤定義一個特性。特性用於附加額外信息到定義。 - 示例: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] # `b` 字段被標記為已棄用 - ---@field b string - ---@[deprecated] # 如果`attribute`允許無參數,則可以省略括號 - ---@field c string - ``` diff --git a/crates/emmylua_ls/src/handlers/common/find_origin.rs b/crates/emmylua_ls/src/handlers/common/find_origin.rs new file mode 100644 index 000000000..009519644 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/common/find_origin.rs @@ -0,0 +1,174 @@ +use emmylua_code_analysis::{ + LuaDeclExtra, LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, SemanticDeclLevel, + SemanticModel, +}; + +#[derive(Debug, Clone)] +pub enum DeclOriginResult { + Single(LuaSemanticDeclId), + Multiple(Vec), +} + +impl DeclOriginResult { + pub fn get_first(&self) -> Option { + match self { + DeclOriginResult::Single(decl) => Some(decl.clone()), + DeclOriginResult::Multiple(decls) => decls.first().cloned(), + } + } + + pub fn get_types(&self, semantic_model: &SemanticModel) -> Vec<(LuaSemanticDeclId, LuaType)> { + let get_type = |decl: &LuaSemanticDeclId| -> Option<(LuaSemanticDeclId, LuaType)> { + match decl { + LuaSemanticDeclId::Member(member_id) => { + let typ = semantic_model.get_type((*member_id).into()); + Some((decl.clone(), typ)) + } + LuaSemanticDeclId::LuaDecl(decl_id) => { + let db = semantic_model.get_db(); + let decl_info = db.get_decl_index().get_decl(decl_id)?; + let typ = if let LuaDeclExtra::Param { + idx, signature_id, .. + } = &decl_info.extra + { + db.get_signature_index() + .get(signature_id)? + .get_param_info_by_id(*idx)? + .type_ref + .clone() + } else { + semantic_model.get_type((*decl_id).into()) + }; + Some((decl.clone(), typ)) + } + _ => None, + } + }; + + match self { + DeclOriginResult::Single(decl) => get_type(decl).into_iter().collect(), + DeclOriginResult::Multiple(decls) => decls.iter().filter_map(get_type).collect(), + } + } +} + +pub fn find_decl_origin_owners( + semantic_model: &SemanticModel, + decl_id: LuaDeclId, +) -> DeclOriginResult { + let node = semantic_model + .get_db() + .get_vfs() + .get_syntax_tree(&decl_id.file_id) + .and_then(|tree| { + let root = tree.get_red_root(); + semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id) + .and_then(|decl| decl.get_value_syntax_id()) + .and_then(|syntax_id| syntax_id.to_node_from_root(&root)) + }); + + if let Some(node) = node { + let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default()); + match semantic_decl { + Some(LuaSemanticDeclId::Member(member_id)) => { + find_member_origin_owners(semantic_model, member_id, true) + } + Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { + DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) + } + _ => DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)), + } + } else { + DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) + } +} + +pub fn find_member_origin_owners( + semantic_model: &SemanticModel, + member_id: LuaMemberId, + find_all: bool, +) -> DeclOriginResult { + let final_owner = semantic_model + .get_member_origin_owner(member_id) + .and_then(|origin| reject_param_origin(semantic_model, origin)) + .unwrap_or_else(|| LuaSemanticDeclId::Member(member_id)); + + if !find_all { + return DeclOriginResult::Single(final_owner); + } + + // 如果存在多个同名成员, 则返回多个成员 + let final_owner_result = Some(final_owner.clone()); + if let Some(same_named_members) = + find_all_same_named_members(semantic_model, &final_owner_result) + && same_named_members.len() > 1 + { + return DeclOriginResult::Multiple(same_named_members); + } + // 否则返回单个成员 + DeclOriginResult::Single(final_owner) +} + +pub fn find_member_origin_owner( + semantic_model: &SemanticModel, + member_id: LuaMemberId, +) -> Option { + find_member_origin_owners(semantic_model, member_id, false).get_first() +} + +pub fn find_all_same_named_members( + semantic_model: &SemanticModel, + final_owner: &Option, +) -> Option> { + let final_owner = final_owner.as_ref()?; + let member_id = match final_owner { + LuaSemanticDeclId::Member(id) => id, + _ => return None, + }; + + let original_member = semantic_model + .get_db() + .get_member_index() + .get_member(member_id)?; + + let target_key = original_member.get_key(); + let current_owner = semantic_model + .get_db() + .get_member_index() + .get_current_owner(member_id)?; + + let all_members = semantic_model + .get_db() + .get_member_index() + .get_members(current_owner)?; + let same_named: Vec = all_members + .iter() + .filter(|member| member.get_key() == target_key) + .map(|member| LuaSemanticDeclId::Member(member.get_id())) + .collect(); + + if same_named.is_empty() { + None + } else { + Some(same_named) + } +} + +fn reject_param_origin( + semantic_model: &SemanticModel, + result: LuaSemanticDeclId, +) -> Option { + match &result { + LuaSemanticDeclId::LuaDecl(decl_id) => { + let decl = semantic_model.get_db().get_decl_index().get_decl(decl_id)?; + if decl.is_param() { + return None; + } + Some(result) + } + _ => Some(result), + } +} diff --git a/crates/emmylua_ls/src/handlers/common/mod.rs b/crates/emmylua_ls/src/handlers/common/mod.rs new file mode 100644 index 000000000..c05342c04 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/common/mod.rs @@ -0,0 +1,6 @@ +mod find_origin; + +pub(crate) use find_origin::{ + find_all_same_named_members, find_decl_origin_owners, find_member_origin_owner, + find_member_origin_owners, +}; diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs index b116c70e6..8d80f4b74 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{DbIndex, LuaDeclId, LuaSemanticDeclId, LuaType}; +use emmylua_code_analysis::{LuaDeclId, LuaSemanticDeclId, LuaType}; use lsp_types::CompletionItem; use crate::handlers::completion::{ @@ -19,12 +19,10 @@ pub fn add_decl_completion( let property_owner = LuaSemanticDeclId::LuaDecl(decl_id); check_visibility(builder, property_owner.clone())?; - let overload_count = count_function_overloads(builder.semantic_model.get_db(), typ); - let mut completion_item = CompletionItem { label: name.to_string(), kind: Some(get_completion_kind(typ)), - data: CompletionData::from_property_owner_id(builder, decl_id.into(), overload_count), + data: CompletionData::from_property_owner_id(builder, decl_id.into()), label_details: Some(lsp_types::CompletionItemLabelDetails { detail: get_detail(builder, typ, CallDisplay::None), description: get_description(builder, typ), @@ -46,23 +44,3 @@ pub fn add_decl_completion( builder.add_completion_item(completion_item)?; Some(()) } - -fn count_function_overloads(db: &DbIndex, typ: &LuaType) -> Option { - let mut count = 0; - match typ { - LuaType::DocFunction(_) => { - count += 1; - } - LuaType::Signature(id) => { - count += 1; - if let Some(signature) = db.get_signature_index().get(id) { - count += signature.overloads.len(); - } - } - _ => {} - } - if count > 1 { - count -= 1; - } - if count == 0 { None } else { Some(count) } -} diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs index e282ee7f1..c23f1cb72 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs @@ -30,7 +30,6 @@ pub fn add_member_completion( builder: &mut CompletionBuilder, member_info: LuaMemberInfo, status: CompletionTriggerStatus, - overload_count: Option, ) -> Option<()> { if builder.is_cancelled() { return None; @@ -60,7 +59,7 @@ pub fn add_member_completion( for key in member_keys { let mut member_info = member_info.clone(); member_info.key = key; - add_member_completion(builder, member_info, status, None); + add_member_completion(builder, member_info, status); } } } @@ -99,9 +98,9 @@ pub fn add_member_completion( // 附加数据, 用于在`resolve`时进一步处理 let completion_data = if let Some(id) = &property_owner { if let Some(index) = member_info.overload_index { - CompletionData::from_overload(builder, id.clone(), index, overload_count) + CompletionData::from_overload(builder, id.clone(), index) } else { - CompletionData::from_property_owner_id(builder, id.clone(), overload_count) + CompletionData::from_property_owner_id(builder, id.clone()) } } else { None @@ -179,7 +178,6 @@ pub fn add_member_completion( call_display, deprecated, label, - overload_count, ); Some(()) @@ -192,7 +190,6 @@ fn add_signature_overloads( call_display: CallDisplay, deprecated: Option, label: String, - overload_count: Option, ) -> Option<()> { let signature_id = match typ { LuaType::Signature(signature_id) => signature_id, @@ -215,7 +212,7 @@ fn add_signature_overloads( let description = get_description(builder, &typ); let detail = get_detail(builder, &typ, call_display); let data = if let Some(id) = &property_owner { - CompletionData::from_overload(builder, id.clone(), index, overload_count) + CompletionData::from_overload(builder, id.clone(), index) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/completion_data.rs b/crates/emmylua_ls/src/handlers/completion/completion_data.rs index 0609257a4..95c91002c 100644 --- a/crates/emmylua_ls/src/handlers/completion/completion_data.rs +++ b/crates/emmylua_ls/src/handlers/completion/completion_data.rs @@ -9,8 +9,6 @@ pub struct CompletionData { pub field_id: FileId, pub trigger_offset: Option, pub typ: CompletionDataType, - /// Total count of function overloads - pub overload_count: Option, } #[allow(unused)] @@ -18,13 +16,11 @@ impl CompletionData { pub fn from_property_owner_id( builder: &CompletionBuilder, id: LuaSemanticDeclId, - overload_count: Option, ) -> Option { let data = Self { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::PropertyOwnerId(id), - overload_count, }; Some(serde_json::to_value(data).unwrap()) } @@ -33,13 +29,11 @@ impl CompletionData { builder: &CompletionBuilder, id: LuaSemanticDeclId, index: usize, - overload_count: Option, ) -> Option { let data = Self { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::Overload((id, index)), - overload_count, }; Some(serde_json::to_value(data).unwrap()) } @@ -49,7 +43,6 @@ impl CompletionData { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::Module(module), - overload_count: None, }; Some(serde_json::to_value(data).unwrap()) } @@ -61,226 +54,3 @@ pub enum CompletionDataType { Module(String), Overload((LuaSemanticDeclId, usize)), } - -// // Custom serialization implementation -// impl Serialize for CompletionData { -// fn serialize(&self, serializer: S) -> Result -// where -// S: Serializer, -// { -// // Compact format: "field_id|type_flag:type_data|overload_count" -// // type_flag: P=PropertyOwnerId, M=Module, O=Overload -// let type_part = match &self.typ { -// CompletionDataType::PropertyOwnerId(id) => { -// format!("P:{}", serde_json::to_string(id).map_err(serde::ser::Error::custom)?) -// }, -// CompletionDataType::Module(module) => { -// format!("M:{}", module) -// }, -// CompletionDataType::Overload((id, index)) => { -// format!("O:{}#{}", -// serde_json::to_string(id).map_err(serde::ser::Error::custom)?, -// index -// ) -// }, -// }; - -// let overload_part = match self.overload_count { -// Some(count) => format!("|{}", count), -// None => String::new(), -// }; - -// let compact = format!("{}|{}{}", self.field_id.id, type_part, overload_part); -// serializer.serialize_str(&compact) -// } -// } - -// impl<'de> Deserialize<'de> for CompletionData { -// fn deserialize(deserializer: D) -> Result -// where -// D: Deserializer<'de>, -// { -// struct CompletionDataVisitor; - -// impl<'de> Visitor<'de> for CompletionDataVisitor { -// type Value = CompletionData; - -// fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { -// formatter.write_str("a string with format 'field_id|type_flag:type_data|overload_count'") -// } - -// fn visit_str(self, value: &str) -> Result -// where -// E: de::Error, -// { -// let parts: Vec<&str> = value.split('|').collect(); -// if parts.len() < 2 || parts.len() > 3 { -// return Err(E::custom("expected format 'field_id|type_flag:type_data|overload_count'")); -// } - -// // Parse field_id -// let field_id = FileId::new( -// parts[0] -// .parse() -// .map_err(|e| E::custom(format!("invalid field_id: {}", e)))? -// ); - -// // Parse type -// let type_part = parts[1]; -// let typ = if let Some(colon_pos) = type_part.find(':') { -// let type_flag = &type_part[..colon_pos]; -// let type_data = &type_part[colon_pos + 1..]; - -// match type_flag { -// "P" => { -// let id: LuaSemanticDeclId = serde_json::from_str(type_data) -// .map_err(|e| E::custom(format!("invalid PropertyOwnerId: {}", e)))?; -// CompletionDataType::PropertyOwnerId(id) -// }, -// "M" => { -// CompletionDataType::Module(type_data.to_string()) -// }, -// "O" => { -// if let Some(hash_pos) = type_data.find('#') { -// let id_part = &type_data[..hash_pos]; -// let index_part = &type_data[hash_pos + 1..]; - -// let id: LuaSemanticDeclId = serde_json::from_str(id_part) -// .map_err(|e| E::custom(format!("invalid Overload id: {}", e)))?; -// let index: usize = index_part -// .parse() -// .map_err(|e| E::custom(format!("invalid Overload index: {}", e)))?; - -// CompletionDataType::Overload((id, index)) -// } else { -// return Err(E::custom("expected '#' separator in Overload type")); -// } -// }, -// _ => { -// return Err(E::custom(format!("unknown type flag: {}", type_flag))); -// } -// } -// } else { -// return Err(E::custom("expected ':' separator in type part")); -// }; - -// // Parse overload_count -// let overload_count = if parts.len() == 3 { -// if parts[2].is_empty() { -// None -// } else { -// Some( -// parts[2] -// .parse() -// .map_err(|e| E::custom(format!("invalid overload count: {}", e)))? -// ) -// } -// } else { -// None -// }; - -// Ok(CompletionData { -// field_id, -// typ, -// overload_count, -// }) -// } -// } - -// deserializer.deserialize_str(CompletionDataVisitor) -// } -// } - -// #[cfg(test)] -// mod tests { -// use emmylua_code_analysis::{FileId, LuaSemanticDeclId, LuaTypeDeclId}; - -// use super::{CompletionData, CompletionDataType}; - -// #[test] -// fn test_compact_serialization() { -// let type_id = LuaTypeDeclId::new("hello world"); -// let data = CompletionData { -// field_id: FileId::new(1), -// typ: CompletionDataType::PropertyOwnerId(LuaSemanticDeclId::TypeDecl(type_id)), -// overload_count: Some(3), -// }; - -// // Test serialization -// let json = serde_json::to_string(&data).unwrap(); -// println!("Compact serialized: {}", json); - -// // Test deserialization -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); - -// // Verify the compactness of serialization format -// assert!(json.len() < 200); // Should be more compact than default JSON serialization -// } - -// #[test] -// fn test_module_serialization() { -// let data = CompletionData { -// field_id: FileId::new(42), -// typ: CompletionDataType::Module("socket.core".to_string()), -// overload_count: None, -// }; - -// let json = serde_json::to_string(&data).unwrap(); -// println!("Module serialized: {}", json); - -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); -// } - -// #[test] -// fn test_overload_serialization() { -// let type_id = LuaTypeDeclId::new("test_function"); -// let data = CompletionData { -// field_id: FileId::new(10), -// typ: CompletionDataType::Overload((LuaSemanticDeclId::TypeDecl(type_id), 2)), -// overload_count: Some(5), -// }; - -// let json = serde_json::to_string(&data).unwrap(); -// println!("Overload serialized: {}", json); - -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); -// } - -// #[test] -// fn test_size_comparison() { -// let type_id = LuaTypeDeclId::new("comparison_test"); -// let data = CompletionData { -// field_id: FileId::new(999), -// typ: CompletionDataType::PropertyOwnerId(LuaSemanticDeclId::TypeDecl(type_id.clone())), -// overload_count: Some(10), -// }; - -// // Our compact serialization -// let compact_json = serde_json::to_string(&data).unwrap(); - -// // Create a struct using default serialization to compare sizes -// #[derive(serde::Serialize)] -// struct DefaultSerialized { -// field_id: u32, -// typ: CompletionDataType, -// overload_count: Option, -// } - -// let default_data = DefaultSerialized { -// field_id: data.field_id.id, -// typ: data.typ.clone(), -// overload_count: data.overload_count, -// }; - -// let default_json = serde_json::to_string(&default_data).unwrap(); - -// println!("Compact size: {} bytes", compact_json.len()); -// println!("Default size: {} bytes", default_json.len()); - -// // Compact serialization should be smaller -// assert!(compact_json.len() <= default_json.len()); -// } -// } diff --git a/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs b/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs index ebdd18b9b..13cc7842f 100644 --- a/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs +++ b/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs @@ -31,5 +31,4 @@ pub const DOC_TAGS: &[&str] = &[ "readonly", "return_cast", "language", - "attribute", ]; diff --git a/crates/emmylua_ls/src/handlers/completion/mod.rs b/crates/emmylua_ls/src/handlers/completion/mod.rs index 5d9b3c3c4..dd25d8481 100644 --- a/crates/emmylua_ls/src/handlers/completion/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/mod.rs @@ -123,7 +123,6 @@ pub fn completion_resolve( .get_semantic_model(completion_data.field_id); if let Some(semantic_model) = semantic_model { resolve_completion( - &analysis.compilation, &semantic_model, db, &mut completion_item, diff --git a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs index 64aa08293..affdb0392 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs @@ -121,7 +121,7 @@ fn add_module_completion_item( } let data = if let Some(property_id) = &module_info.semantic_id { - CompletionData::from_property_owner_id(builder, property_id.clone(), None) + CompletionData::from_property_owner_id(builder, property_id.clone()) } else { None }; @@ -197,7 +197,6 @@ fn add_completion_item_by_type( CompletionData::from_property_owner_id( builder, property_owner_id.clone(), - None, ) } else { None diff --git a/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs index 3f876c09d..2e3be1885 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::LuaTypeDeclId; +use emmylua_code_analysis::{LuaTypeDeclId, is_attribute_class}; use emmylua_parser::{LuaAstNode, LuaDocAttributeUse, LuaDocNameType, LuaSyntaxKind, LuaTokenKind}; use lsp_types::CompletionItem; use std::collections::HashSet; @@ -76,24 +76,14 @@ pub fn complete_types_by_prefix( match completion_type { CompletionType::AttributeUse => { if let Some(decl_id) = type_decl { - let type_decl = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_decl(&decl_id)?; - if type_decl.is_attribute() { + if is_attribute_class(builder.semantic_model.get_db(), &decl_id) { add_type_completion_item(builder, &name, Some(decl_id)); } } } CompletionType::Type => { if let Some(decl_id) = &type_decl { - let type_decl = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_decl(decl_id)?; - if type_decl.is_attribute() { + if is_attribute_class(builder.semantic_model.get_db(), decl_id) { continue; } } @@ -172,7 +162,7 @@ fn add_type_completion_item( }; let data = if let Some(id) = type_decl { - CompletionData::from_property_owner_id(builder, id.into(), None) + CompletionData::from_property_owner_id(builder, id.into()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 0e9050f4b..f0814ba95 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -1,9 +1,8 @@ use emmylua_code_analysis::{ - DbIndex, GenericTplId, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, - LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, - LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, - RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type, - instantiate_type_generic, normalize_constraint_type, + DbIndex, GenericTpl, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, + LuaDeclLocation, LuaFunctionType, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, + LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, RenderLevel, + filter_callable_overloads, get_real_type, normalize_constraint_type, }; use emmylua_parser::{ LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaClosureExpr, @@ -12,7 +11,6 @@ use emmylua_parser::{ }; use itertools::Itertools; use lsp_types::{CompletionItem, Documentation}; -use std::sync::Arc; use crate::handlers::{ completion::{ @@ -151,11 +149,8 @@ pub fn dispatch_type( LuaType::StrTplRef(key) => { add_str_tpl_ref_completion(builder, &key); } - LuaType::ConstTplRef(tpl) => { - return add_const_tpl_ref_completion(builder, &tpl.get_tpl_id(), infer_guard); - } LuaType::TplRef(tpl) => { - return add_tpl_ref_completion(builder, &tpl.get_tpl_id(), infer_guard); + return add_tpl_ref_completion(builder, &tpl, infer_guard); } LuaType::Call(special_call) => { add_special_call_completion(builder, &special_call); @@ -328,236 +323,66 @@ fn infer_call_arg_list( token: LuaSyntaxToken, ) -> Option> { let call_expr = call_arg_list.get_parent::()?; - let mut param_idx = get_current_param_index(&call_expr, &token)?; - let call_expr_func = builder - .semantic_model - .infer_call_expr_func(call_expr.clone(), Some(param_idx + 1))?; - let colon_call = call_expr.is_colon_call(); - let colon_define = call_expr_func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) | (false, true) => {} - (true, false) => { - param_idx += 1; - } - } - let constraint_substitutor = build_call_constraint_context(&builder.semantic_model, &call_expr) - .map(|ctx| ctx.substitutor); - let substitutor = constraint_substitutor.as_ref(); - let typ = call_expr_func - .get_params() - .get(param_idx)? - .1 - .clone() - .unwrap_or(LuaType::Unknown); - let typ = resolve_param_type(builder, typ, substitutor); - let mut types = Vec::new(); - types.push(typ); - push_function_overloads_param( - builder, - &call_expr, - call_expr_func.get_params(), - param_idx, - substitutor, - &mut types, - ); - Some(types.into_iter().unique().collect()) // 需要去重 -} - -fn resolve_param_type( - builder: &CompletionBuilder, - mut typ: LuaType, - substitutor: Option<&TypeSubstitutor>, -) -> LuaType { - let db = builder.semantic_model.get_db(); - if let Some(substitutor) = substitutor { - typ = apply_substitutor_to_type(db, typ, substitutor); - } - normalize_constraint_type(db, typ) -} - -fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubstitutor) -> LuaType { - if let LuaType::Call(alias_call) = &typ { - if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf { - let operands = alias_call - .get_operands() - .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) - .collect::>(); - return LuaType::Call(Arc::new(LuaAliasCallType::new( - alias_call.get_call_kind(), - operands, - ))); - } - } - if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, substitutor) { - return alias_call; - } - instantiate_type_generic(db, &typ, substitutor) -} - -fn rebuild_keyof_alias_call( - db: &DbIndex, - original_type: &LuaType, - substitutor: &TypeSubstitutor, -) -> Option { - let tpl = match original_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl, - _ => return None, - }; - let constraint = tpl.get_constraint()?; - let LuaType::Call(alias_call) = constraint else { - return None; - }; - if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf { - return None; - } - - let operands = alias_call - .get_operands() - .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) - .collect::>(); - Some(LuaType::Call(Arc::new(LuaAliasCallType::new( - alias_call.get_call_kind(), - operands, - )))) -} - -fn push_function_overloads_param( - builder: &mut CompletionBuilder, - call_expr: &LuaCallExpr, - call_params: &[(String, Option)], - param_idx: usize, - substitutor: Option<&TypeSubstitutor>, - types: &mut Vec, -) -> Option<()> { - let member_index = builder.semantic_model.get_db().get_member_index(); + let param_idx = get_current_param_index(&call_expr, &token)?; let prefix_expr = call_expr.get_prefix_expr()?; - let semantic_decl = builder.semantic_model.find_decl( - prefix_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - // 收集函数类型 - let functions = match semantic_decl { - LuaSemanticDeclId::Member(member_id) => { - let member = member_index.get_member(&member_id)?; - let key = member.get_key().to_path(); - let owner = member_index.get_current_owner(&member_id)?; - let members = member_index.get_members(owner)?; - let functions = filter_function_members(builder.semantic_model.get_db(), members, key); - Some(functions) - } - LuaSemanticDeclId::LuaDecl(decl_id) => { - let decl = builder - .semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; + let prefix_type = builder.semantic_model.infer_expr(prefix_expr).ok()?; + let call_arg_types = infer_call_arg_types(builder, &call_expr, Some(param_idx))?; + let call_expr_funcs = filter_callable_overloads( + builder.semantic_model.get_db(), + &mut builder.semantic_model.get_cache().borrow_mut(), + &prefix_type, + &call_arg_types, + &call_expr, + Some(param_idx), + true, + ) + .ok()?; - let typ = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_cache(&decl_id.into()) - .map(|cache| cache.as_type().clone()) - .unwrap_or(LuaType::Unknown); - match typ { - LuaType::Signature(_) | LuaType::DocFunction(_) => Some(vec![typ.clone()]), - _ => { - let key = decl.get_name(); - let type_id = LuaTypeDeclId::global(decl.get_name()); - let members = member_index.get_members(&LuaMemberOwner::Type(type_id))?; - let functions = filter_function_members( - builder.semantic_model.get_db(), - members, - key.to_string(), - ); - Some(functions) - } - } - } - _ => None, - }?; - - // 获取重载函数列表 - let signature_index = builder.semantic_model.get_db().get_signature_index(); - let mut overloads = Vec::new(); - for function in functions { - match function { - LuaType::Signature(signature_id) => { - if let Some(signature) = signature_index.get(&signature_id) { - overloads.extend(signature.overloads.iter().cloned()); - } - } - LuaType::DocFunction(doc_function) => { - overloads.push(doc_function); + let mut types = Vec::new(); + for call_expr_func in call_expr_funcs { + let mut param_idx = param_idx; + let colon_call = call_expr.is_colon_call(); + let colon_define = call_expr_func.is_colon_define(); + match (colon_call, colon_define) { + (true, true) | (false, false) | (false, true) => {} + (true, false) => { + param_idx += 1; } - _ => {} } - } - - // 筛选匹配的参数类型并添加到结果中 - for overload in overloads.iter() { - let overload_params = overload.get_params(); - - // 检查前面的参数是否匹配 - if !params_match_prefix(call_params, overload_params, param_idx) { - continue; - } - - // 添加匹配的参数类型 - if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) { - let param_type = resolve_param_type(builder, param_type, substitutor); - types.push(param_type); - } - } - /// 过滤出函数类型的成员 - fn filter_function_members( - db: &DbIndex, - members: Vec<&LuaMember>, - key: String, - ) -> Vec { - let mut result_members = vec![]; - for member in members { - if member.get_key().to_path() == key { - let member_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .unwrap_or(&LuaTypeCache::InferType(LuaType::Unknown)); - if let LuaType::Signature(_) | LuaType::DocFunction(_) = member_type.as_type() { - result_members.push(member_type.as_type().clone()); - } - } + if let Some(typ) = call_expr_func + .get_params() + .get(param_idx) + .and_then(|param| param.1.clone()) + { + types.push(normalize_constraint_type( + builder.semantic_model.get_db(), + typ, + )); } - - result_members } - /// 判断前面的参数是否匹配 - fn params_match_prefix( - call_params: &[(String, Option)], - overload_params: &[(String, Option)], - param_idx: usize, - ) -> bool { - if param_idx == 0 { - return true; - } - - for i in 0..param_idx { - if let (Some(call_param), Some(overload_param)) = - (call_params.get(i), overload_params.get(i)) - && call_param.1 != overload_param.1 - { - return false; - } - } - - true + if types.is_empty() { + None + } else { + Some(types.into_iter().unique().collect()) } +} - Some(()) +fn infer_call_arg_types( + builder: &CompletionBuilder, + call_expr: &LuaCallExpr, + arg_count: Option, +) -> Option> { + let args = call_expr.get_args_list()?.get_args().collect::>(); + Some( + builder + .semantic_model + .infer_expr_list_types(&args, arg_count) + .into_iter() + .map(|(typ, _)| typ) + .collect(), + ) } fn add_multi_line_union_member_completion( @@ -809,9 +634,8 @@ fn add_str_tpl_ref_completion( let db = builder.semantic_model.get_db(); let module_index = db.get_module_index(); let types = db.get_type_index().get_all_types(); - let tpl_id = str_tpl.get_tpl_id(); // 泛型约束 - let extend_type = get_tpl_ref_extend_type(builder, &tpl_id).unwrap_or(LuaType::Any); + let extend_type = str_tpl.get_constraint().cloned().unwrap_or(LuaType::Any); let mut completion_items: Vec<_> = types .into_iter() @@ -863,16 +687,6 @@ fn add_str_tpl_ref_completion( Some(()) } -fn add_const_tpl_ref_completion( - builder: &mut CompletionBuilder, - tpl_id: &GenericTplId, - infer_guard: &InferGuardRef, -) -> Option { - // 泛型约束 - let extend_type = get_tpl_ref_extend_type(builder, tpl_id)?; - dispatch_type(builder, extend_type, infer_guard) -} - fn add_special_call_completion( builder: &mut CompletionBuilder, alias_call: &LuaAliasCallType, @@ -896,36 +710,6 @@ fn add_special_call_completion( Some(()) } -fn get_tpl_ref_extend_type(builder: &CompletionBuilder, tpl_id: &GenericTplId) -> Option { - let token = builder.trigger_token.clone(); - let mut parent_node = token.parent()?; - if LuaLiteralExpr::can_cast(parent_node.kind().into()) { - parent_node = parent_node.parent()?; - } - match parent_node.kind().into() { - LuaSyntaxKind::CallArgList => { - let call_expr = LuaCallArgList::cast(parent_node)?.get_parent::()?; - let function = builder - .semantic_model - .infer_expr(call_expr.get_prefix_expr()?.clone()) - .ok()?; - if let LuaType::Signature(signature_id) = function { - let signature = builder - .semantic_model - .get_db() - .get_signature_index() - .get(&signature_id)?; - let generic_param = signature.generic_params.get(tpl_id.get_idx()); - if let Some(generic_param) = generic_param { - return Some(generic_param.constraint.clone().unwrap_or(LuaType::Any)); - } - } - None - } - _ => None, - } -} - /// 确保所有成员均为 function 或者 nil, 然后返回 function 的联合类型, 如果非 function 则返回 None pub fn get_function_remove_nil(db: &DbIndex, typ: &LuaType) -> Option { match typ { @@ -964,9 +748,9 @@ pub fn get_function_remove_nil(db: &DbIndex, typ: &LuaType) -> Option { fn add_tpl_ref_completion( builder: &mut CompletionBuilder, - tpl_id: &GenericTplId, + tpl: &GenericTpl, infer_guard: &InferGuardRef, ) -> Option { - let extend_type = get_tpl_ref_extend_type(builder, tpl_id)?; + let extend_type = tpl.get_constraint().cloned()?; dispatch_type(builder, extend_type, infer_guard) } diff --git a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs index 373134ddb..1fcbca85f 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ DbIndex, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaTypeDeclId, SemanticModel, - enum_variable_is_param, get_tpl_ref_extend_type, + enum_variable_is_param, }; use emmylua_parser::{LuaAstNode, LuaAstToken, LuaIndexExpr, LuaStringToken, LuaSyntaxToken}; use std::collections::HashMap; @@ -59,13 +59,7 @@ fn complete_provider(builder: &mut CompletionBuilder) -> Option<()> { .infer_expr(prefix_expr.clone()) .ok()? { - LuaType::TplRef(tpl) => get_tpl_ref_extend_type( - builder.semantic_model.get_db(), - &mut builder.semantic_model.get_cache().borrow_mut(), - &LuaType::TplRef(tpl.clone()), - prefix_expr.clone(), - 0, - )?, + LuaType::TplRef(tpl) => tpl.get_constraint().cloned()?, prefix_type => prefix_type, }; // 如果是枚举类型且为函数参数, 则不进行补全 @@ -108,33 +102,11 @@ fn add_resolve_member_infos( ) -> Option<()> { if member_infos.len() == 1 { let member_info = &member_infos[0]; - let overload_count = match &member_info.typ { - LuaType::DocFunction(_) => None, - LuaType::Signature(id) => { - if let Some(signature) = builder - .semantic_model - .get_db() - .get_signature_index() - .get(id) - { - let count = signature.overloads.len(); - if count == 0 { None } else { Some(count) } - } else { - None - } - } - _ => None, - }; - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); return Some(()); } - let (filtered_member_infos, overload_count) = filter_member_infos( + let filtered_member_infos = filter_member_infos( &builder.semantic_model, &builder.trigger_token, member_infos, @@ -145,35 +117,20 @@ fn add_resolve_member_infos( for member_info in filtered_member_infos { match resolve_state { MemberResolveState::All => { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } MemberResolveState::Meta => { if let Some(feature) = member_info.feature && feature.is_meta_decl() { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } } MemberResolveState::FileDecl => { if let Some(feature) = member_info.feature && feature.is_file_decl() { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } } } @@ -182,12 +139,12 @@ fn add_resolve_member_infos( Some(()) } -/// 过滤成员信息,返回需要的成员列表和重载数量 +/// 过滤成员信息,返回需要的成员列表 fn filter_member_infos<'a>( semantic_model: &SemanticModel, trigger_token: &LuaSyntaxToken, member_infos: &'a [LuaMemberInfo], -) -> Option<(Vec<&'a LuaMemberInfo>, Option)> { +) -> Option> { if member_infos.is_empty() { return None; } @@ -208,7 +165,6 @@ fn filter_member_infos<'a>( let mut member_with_owners: Vec<(&LuaMemberInfo, Option)> = Vec::with_capacity(visible_member_infos.len()); let mut all_doc_function = true; - let mut overload_count = 0; // 一次遍历收集所有信息 for member_info in visible_member_infos { @@ -223,18 +179,9 @@ fn filter_member_infos<'a>( file_decl_member = Some(member_info); } - // 检查是否全为 DocFunction,同时计算重载数量 + // 检查是否全为 DocFunction match &member_info.typ { - LuaType::DocFunction(_) => { - overload_count += 1; - } - LuaType::Signature(id) => { - all_doc_function = false; - overload_count += 1; - if let Some(signature) = semantic_model.get_db().get_signature_index().get(id) { - overload_count += signature.overloads.len(); - } - } + LuaType::DocFunction(_) => {} _ => { all_doc_function = false; } @@ -274,20 +221,12 @@ fn filter_member_infos<'a>( }) .collect(); - // 处理重载计数 - let final_overload_count = if overload_count >= 1 { - let count = overload_count - 1; - if count == 0 { None } else { Some(count) } - } else { - None - }; - // 如果全为 DocFunction, 只保留第一个 if all_doc_function && !filtered_member_infos.is_empty() { filtered_member_infos.truncate(1); } - Some((filtered_member_infos, final_overload_count)) + Some(filtered_member_infos) } enum MemberResolveState { diff --git a/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs index 653e69110..c5b2f42b0 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs @@ -101,7 +101,7 @@ pub fn add_modules( if let Some(child_file_id) = child_module_node.file_ids.first() { let child_module_info = db.get_module_index().get_module(*child_file_id)?; let data = if let Some(property_id) = &child_module_info.semantic_id { - CompletionData::from_property_owner_id(builder, property_id.clone(), None) + CompletionData::from_property_owner_id(builder, property_id.clone()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs index 0a6b01eb4..1fb976b88 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs @@ -258,7 +258,7 @@ fn add_field_key_completion( } let data = if let Some(id) = &property_owner { - CompletionData::from_property_owner_id(builder, id.clone(), None) + CompletionData::from_property_owner_id(builder, id.clone()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs index d97490b00..60464dd8b 100644 --- a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{DbIndex, LuaCompilation, SemanticModel}; +use emmylua_code_analysis::{DbIndex, SemanticModel}; use emmylua_parser::{LuaAstNode, LuaSyntaxToken}; use lsp_types::{CompletionItem, Documentation, MarkedString, MarkupContent}; use rowan::{TextSize, TokenAtOffset}; @@ -11,7 +11,6 @@ use crate::{ use super::completion_data::{CompletionData, CompletionDataType}; pub fn resolve_completion( - compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, completion_item: &mut CompletionItem, @@ -25,14 +24,12 @@ pub fn resolve_completion( match completion_data.typ { CompletionDataType::PropertyOwnerId(property_id) => { let hover_builder = build_hover_content_for_completion( - compilation, semantic_model, db, property_id, trigger_token.clone(), ); - if let Some(mut hover_builder) = hover_builder { - update_function_signature_info(&mut hover_builder, completion_data.overload_count); + if let Some(hover_builder) = hover_builder { if client_id.is_vscode() { build_vscode_completion_item(completion_item, hover_builder, None); } else { @@ -42,14 +39,12 @@ pub fn resolve_completion( } CompletionDataType::Overload((property_id, index)) => { let hover_builder = build_hover_content_for_completion( - compilation, semantic_model, db, property_id, trigger_token.clone(), ); - if let Some(mut hover_builder) = hover_builder { - update_function_signature_info(&mut hover_builder, completion_data.overload_count); + if let Some(hover_builder) = hover_builder { if client_id.is_vscode() { build_vscode_completion_item(completion_item, hover_builder, Some(index)); } else { @@ -79,38 +74,20 @@ fn get_completion_trigger_token( } } -pub fn update_function_signature_info( - hover_builder: &mut HoverBuilder, - overload_count: Option, -) { - if let Some(overload_count) = overload_count - && overload_count > 0 - { - if let Some(signature_overload) = &mut hover_builder.signature_overload { - for signature in signature_overload.iter_mut() { - if let MarkedString::LanguageString(s) = signature { - s.value = format!("{} (+{} overloads)", s.value, overload_count); - } - } - } - if let MarkedString::LanguageString(s) = &mut hover_builder.primary { - s.value = format!("{} (+{} overloads)", s.value, overload_count); - } - } -} - fn build_vscode_completion_item( completion_item: &mut CompletionItem, hover_builder: HoverBuilder, overload_index: Option, ) -> Option<()> { - let type_description = overload_index + let (type_description, overload_comment) = overload_index .and_then(|index| { hover_builder .signature_overload + .as_ref() .and_then(|overloads| overloads.get(index).cloned()) + .map(|overload| (overload.signature, overload.comment)) }) - .unwrap_or_else(|| hover_builder.primary.clone()); + .unwrap_or_else(|| (hover_builder.primary.clone(), None)); match type_description { MarkedString::String(s) => { @@ -124,6 +101,9 @@ fn build_vscode_completion_item( let documentation = { let mut result = String::new(); let mut first_line = true; + if let Some(comment) = overload_comment { + result.push_str(&format!("\n{}\n", comment)); + } for description in hover_builder.annotation_description { match description { MarkedString::String(s) => { @@ -164,13 +144,15 @@ fn build_other_completion_item( ) -> Option<()> { let mut result = String::new(); - let type_description = overload_index + let (type_description, overload_comment) = overload_index .and_then(|index| { hover_builder .signature_overload + .as_ref() .and_then(|overloads| overloads.get(index).cloned()) + .map(|overload| (overload.signature, overload.comment)) }) - .unwrap_or_else(|| hover_builder.primary.clone()); + .unwrap_or_else(|| (hover_builder.primary.clone(), None)); match type_description { MarkedString::String(s) => { @@ -180,6 +162,9 @@ fn build_other_completion_item( result.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); } } + if let Some(comment) = overload_comment { + result.push_str(&format!("\n{}\n", comment)); + } if let Some(MarkedString::String(s)) = hover_builder.location_path { result.push_str(&format!("\n{}\n", s)); } diff --git a/crates/emmylua_ls/src/handlers/configuration/mod.rs b/crates/emmylua_ls/src/handlers/configuration/mod.rs index d1f6d229a..5ad779cf7 100644 --- a/crates/emmylua_ls/src/handlers/configuration/mod.rs +++ b/crates/emmylua_ls/src/handlers/configuration/mod.rs @@ -50,7 +50,7 @@ impl RegisterCapabilities for ConfigurationCapabilities { fn register_capabilities(_: &mut ServerCapabilities, _: &ClientCapabilities) {} } -#[cfg(test)] +#[cfg(all(test, feature = "full-test"))] mod tests { use super::*; use std::{ diff --git a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs index ed7eef3c8..583e21340 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs @@ -14,10 +14,10 @@ use lsp_types::{GotoDefinitionResponse, Location, Position, Range, Uri}; use crate::{ handlers::{ + common::{find_all_same_named_members, find_member_origin_owner}, definition::goto_function::{ find_function_call_origin, find_matching_function_definitions, }, - hover::{find_all_same_named_members, find_member_origin_owner}, }, util::{to_camel_case, to_pascal_case, to_snake_case}, }; @@ -114,7 +114,7 @@ fn handle_member_definition( trigger_token, &same_named_members, ) { - process_matched_members(semantic_model, compilation, &match_members, &mut locations); + process_matched_members(semantic_model, &match_members, &mut locations); if !locations.is_empty() { return Some(GotoDefinitionResponse::Array(locations)); } @@ -164,7 +164,6 @@ fn handle_type_decl_definition( fn process_matched_members( semantic_model: &SemanticModel, - compilation: &LuaCompilation, match_members: &[LuaSemanticDeclId], locations: &mut Vec, ) { @@ -173,7 +172,7 @@ fn process_matched_members( LuaSemanticDeclId::Member(member_id) => { if should_trace_member(semantic_model, member_id).unwrap_or(false) { // 尝试搜索这个成员最原始的定义 - match find_member_origin_owner(compilation, semantic_model, *member_id) { + match find_member_origin_owner(semantic_model, *member_id) { Some(LuaSemanticDeclId::Member(origin_member_id)) => { if let Some(location) = get_member_location(semantic_model, &origin_member_id) diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs index f60d9b395..adf505c09 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_function.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId, - LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic, + LuaType, SemanticDeclLevel, SemanticModel, infer_call_generic, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind, @@ -291,7 +291,7 @@ pub fn compare_function_types( call_expr: &LuaCallExpr, ) -> Option { if func.contain_tpl() { - let instantiated_func = instantiate_func_generic( + let instantiated_func = infer_call_generic( semantic_model.get_db(), &mut semantic_model.get_cache().borrow_mut(), func, diff --git a/crates/emmylua_ls/src/handlers/hover/build_hover.rs b/crates/emmylua_ls/src/handlers/hover/build_hover.rs index b098c444b..2ab62a144 100644 --- a/crates/emmylua_ls/src/handlers/hover/build_hover.rs +++ b/crates/emmylua_ls/src/handlers/hover/build_hover.rs @@ -2,28 +2,23 @@ use std::collections::HashSet; use emmylua_code_analysis::humanize_type; use emmylua_code_analysis::{ - DbIndex, LuaCompilation, LuaDeclExtra, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, - LuaSemanticDeclId, LuaSignatureId, LuaType, RenderLevel, SemanticInfo, SemanticModel, -}; -use emmylua_parser::{ - LuaAssignStat, LuaAstNode, LuaCallArgList, LuaExpr, LuaSyntaxKind, LuaSyntaxToken, - LuaTableExpr, LuaTableField, + DbIndex, LuaDeclExtra, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, + LuaSignatureId, LuaType, RenderLevel, SemanticInfo, SemanticModel, }; +use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaExpr, LuaSyntaxToken}; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; use rowan::TextRange; -use crate::handlers::hover::function::{build_function_hover, is_function}; +use crate::handlers::common::{find_decl_origin_owners, find_member_origin_owners}; +use crate::handlers::hover::function::{build_function_hover, has_function_candidate, is_function}; use crate::handlers::hover::humanize_type_decl::build_type_decl_hover; use crate::handlers::hover::humanize_types::hover_humanize_type; use super::{ - find_origin::{find_decl_origin_owners, find_member_origin_owners}, - hover_builder::HoverBuilder, - humanize_types::hover_const_type, + HoverDeclContext, HoverDeclInfo, hover_builder::HoverBuilder, humanize_types::hover_const_type, }; pub fn build_semantic_info_hover( - compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, document: &LuaDocument, @@ -36,7 +31,6 @@ pub fn build_semantic_info_hover( return build_hover_without_property(db, document, token, typ); } let hover_builder = build_hover_content( - compilation, semantic_model, db, Some(typ), @@ -76,7 +70,6 @@ fn build_hover_without_property( } pub fn build_hover_content_for_completion<'a>( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, property_id: LuaSemanticDeclId, @@ -91,19 +84,10 @@ pub fn build_hover_content_for_completion<'a>( } _ => None, }; - build_hover_content( - compilation, - semantic_model, - db, - typ, - property_id, - true, - token, - ) + build_hover_content(semantic_model, db, typ, property_id, true, token) } fn build_hover_content<'a>( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, typ: Option, @@ -111,7 +95,7 @@ fn build_hover_content<'a>( is_completion: bool, token: Option, ) -> Option> { - let mut builder = HoverBuilder::new(compilation, semantic_model, token, is_completion); + let mut builder = HoverBuilder::new(semantic_model, token, is_completion); match property_id { LuaSemanticDeclId::LuaDecl(decl_id) => { let typ = typ?; @@ -138,26 +122,25 @@ fn build_decl_hover( ) -> Option<()> { let decl = db.get_decl_index().get_decl(&decl_id)?; - let mut semantic_decls = - find_decl_origin_owners(builder.compilation, builder.semantic_model, decl_id) - .get_types(builder.semantic_model); + let semantic_decls = + find_decl_origin_owners(builder.semantic_model, decl_id).get_types(builder.semantic_model); // 处理类型签名 if is_function(&typ) { - adjust_semantic_decls( - builder, - &mut semantic_decls, - &LuaSemanticDeclId::LuaDecl(decl_id), - &typ, + let origin_decls = into_hover_decl_infos(semantic_decls); + let hover_decl_context = HoverDeclContext::new( + HoverDeclInfo::new(LuaSemanticDeclId::LuaDecl(decl_id), typ.clone()), + origin_decls, ); // 处理函数类型 - build_function_hover(builder, db, &semantic_decls); - // hover_function_type(builder, db, &semantic_decls); + build_function_hover(builder, db, &hover_decl_context); - if let Some((LuaSemanticDeclId::Member(member_id), _)) = semantic_decls + if let Some(decl_info) = hover_decl_context + .origin_decls() .iter() - .find(|(decl, _)| matches!(decl, LuaSemanticDeclId::Member(_))) + .find(|decl_info| matches!(decl_info.id(), LuaSemanticDeclId::Member(_))) + && let LuaSemanticDeclId::Member(member_id) = decl_info.id() { let member = db.get_member_index().get_member(member_id); builder.set_location_path(member); @@ -228,9 +211,9 @@ fn build_member_hover( is_completion: bool, ) -> Option<()> { let member = db.get_member_index().get_member(&member_id)?; - let mut semantic_decls = - find_member_origin_owners(builder.compilation, builder.semantic_model, member_id, true) - .get_types(builder.semantic_model); + let mut semantic_decls = find_member_origin_owners(builder.semantic_model, member_id, true) + .get_types(builder.semantic_model); + if let Some(token) = builder.get_trigger_token() { semantic_decls.retain(|(semantic_decl, _)| { builder @@ -245,15 +228,15 @@ fn build_member_hover( _ => return None, }; - if is_function(&typ) { - adjust_semantic_decls( - builder, - &mut semantic_decls, - &LuaSemanticDeclId::Member(member_id), - &typ, - ); + let origin_decls = into_hover_decl_infos(semantic_decls); + let hover_decl_context = HoverDeclContext::new( + HoverDeclInfo::new(LuaSemanticDeclId::Member(member_id), typ.clone()), + origin_decls, + ); - build_function_hover(builder, db, &semantic_decls); + // 当为表字段时, 如果能够追溯到该成员的定义为 function, 那么我们也需要显示方法的签名而不是当前字段的真实类型 + if has_function_candidate(&hover_decl_context) { + build_function_hover(builder, db, &hover_decl_context); builder.set_location_path(Some(member)); @@ -285,7 +268,12 @@ fn build_member_hover( let member_decl = LuaSemanticDeclId::Member(member.get_id()); semantic_decl_set.insert(&member_decl); if !is_completion { - semantic_decl_set.extend(semantic_decls.iter().map(|(decl, _)| decl)); + semantic_decl_set.extend( + hover_decl_context + .origin_decls() + .iter() + .map(|decl_info| decl_info.id()), + ); } for semantic_decl in semantic_decl_set { builder.add_description(semantic_decl); @@ -295,6 +283,13 @@ fn build_member_hover( Some(()) } +fn into_hover_decl_infos(semantic_decls: Vec<(LuaSemanticDeclId, LuaType)>) -> Vec { + semantic_decls + .into_iter() + .map(|(semantic_decl_id, typ)| HoverDeclInfo::new(semantic_decl_id, typ)) + .collect() +} + pub fn add_signature_param_description( db: &DbIndex, marked_strings: &mut Vec, @@ -390,79 +385,3 @@ pub fn get_hover_type(builder: &HoverBuilder, semantic_model: &SemanticModel) -> None } - -#[allow(unused)] -fn adjust_semantic_decls( - builder: &mut HoverBuilder, - semantic_decls: &mut Vec<(LuaSemanticDeclId, LuaType)>, - current_semantic_decl_id: &LuaSemanticDeclId, - current_type: &LuaType, -) -> Option<()> { - if let Some(pos) = semantic_decls - .iter() - .position(|(_, typ)| current_type == typ) - { - let item = semantic_decls.remove(pos); - semantic_decls.push(item); - return Some(()); - } - // semantic_decls 是追溯最初定义的结果, 不包含当前内容 - let current_len = semantic_decls.len(); - if current_len == 0 { - // 没有最初定义, 直接添加原始内容 - semantic_decls.push((current_semantic_decl_id.clone(), current_type.clone())); - return Some(()); - } - // 此时有最初定义, 证明当前内容的是派生的或者全部项实例化后联合的结果, 非常难以区分 - // 如果当前定义是 LuaDecl 且追溯到了最初定义, 那么我们不需要添加 - if let LuaSemanticDeclId::LuaDecl(_) = current_semantic_decl_id { - return Some(()); - } - - // 如果当前定义在最初定义组中存在, 那么我们也不需要添加. - // 具有一个难以解决的问题, 返回的`current_semantic_decl_id`为 member 时, 不一定是当前 token 指向的内容, 因此我们还需要再做一层判断, - // 如果是具有实际定义的, 我们仍然需要添加, 例如 signature. - if semantic_decls - .iter() - .any(|(decl, typ)| decl == current_semantic_decl_id && !typ.is_signature()) - { - return Some(()); - } - - if has_add_to_semantic_decls(builder, current_semantic_decl_id).unwrap_or(true) { - semantic_decls.push((current_semantic_decl_id.clone(), current_type.clone())); - }; - - Some(()) -} - -fn has_add_to_semantic_decls( - builder: &mut HoverBuilder, - semantic_decl_id: &LuaSemanticDeclId, -) -> Option { - if let LuaSemanticDeclId::Member(member_id) = semantic_decl_id { - let semantic_model = if member_id.file_id == builder.semantic_model.get_file_id() { - builder.semantic_model - } else { - &builder.compilation.get_semantic_model(member_id.file_id)? - }; - - let root = semantic_model.get_root().syntax(); - let current_node = member_id.get_syntax_id().to_node_from_root(root)?; - if member_id.get_syntax_id().get_kind() == LuaSyntaxKind::TableFieldAssign { - if LuaTableField::can_cast(current_node.kind().into()) { - let table_field = LuaTableField::cast(current_node.clone())?; - let parent = table_field.syntax().parent()?; - let table_expr = LuaTableExpr::cast(parent)?; - let table_type = semantic_model.infer_table_should_be(table_expr.clone())?; - if matches!(table_type, LuaType::Ref(_) | LuaType::Generic(_)) { - // 如果位于函数调用中, 则不添加 - let is_in_call = table_expr.ancestors::().next().is_some(); - return Some(!is_in_call); - } - } - }; - } - - Some(true) -} diff --git a/crates/emmylua_ls/src/handlers/hover/decl_context.rs b/crates/emmylua_ls/src/handlers/hover/decl_context.rs new file mode 100644 index 000000000..7190ed1b1 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/decl_context.rs @@ -0,0 +1,85 @@ +use emmylua_code_analysis::{LuaSemanticDeclId, LuaType}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct HoverDeclInfo { + id: LuaSemanticDeclId, + typ: LuaType, +} + +impl HoverDeclInfo { + pub(crate) fn new(id: LuaSemanticDeclId, typ: LuaType) -> Self { + Self { id, typ } + } + + pub(crate) fn id(&self) -> &LuaSemanticDeclId { + &self.id + } + + pub(crate) fn typ(&self) -> &LuaType { + &self.typ + } +} + +#[derive(Debug, Clone)] +pub(crate) struct HoverDeclContext { + current_decl: HoverDeclInfo, + origin_decls: Vec, +} + +impl HoverDeclContext { + pub(crate) fn new(current_decl: HoverDeclInfo, origin_decls: Vec) -> Self { + Self { + current_decl, + origin_decls, + } + } + + pub(crate) fn current_decl(&self) -> &HoverDeclInfo { + &self.current_decl + } + + pub(crate) fn origin_decls(&self) -> &[HoverDeclInfo] { + &self.origin_decls + } + + fn primary_decl(&self) -> &HoverDeclInfo { + self.origin_decls + .iter() + .find(|decl| decl.typ().is_signature()) + .or_else(|| { + self.origin_decls + .iter() + .find(|decl| decl.typ() == self.current_decl.typ()) + }) + .or_else(|| self.origin_decls.first()) + .unwrap_or(&self.current_decl) + } + + pub(crate) fn ordered_decl_refs(&self) -> Vec<&HoverDeclInfo> { + let mut decls = if self.origin_decls.is_empty() { + vec![&self.current_decl] + } else { + self.origin_decls.iter().collect::>() + }; + + if let Some(pos) = decls + .iter() + .position(|decl| decl.typ() == self.current_decl.typ()) + { + if pos != 0 { + let item = decls.remove(pos); + decls.insert(0, item); + } + } + + let primary_decl = self.primary_decl(); + if let Some(pos) = decls.iter().position(|decl| *decl == primary_decl) { + if pos != 0 { + let item = decls.remove(pos); + decls.insert(0, item); + } + } + + decls + } +} diff --git a/crates/emmylua_ls/src/handlers/hover/find_origin.rs b/crates/emmylua_ls/src/handlers/hover/find_origin.rs deleted file mode 100644 index b786737fc..000000000 --- a/crates/emmylua_ls/src/handlers/hover/find_origin.rs +++ /dev/null @@ -1,336 +0,0 @@ -use std::collections::HashSet; - -use emmylua_code_analysis::{ - LuaCompilation, LuaDeclExtra, LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, LuaUnionType, - SemanticDeclLevel, SemanticModel, -}; -use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaTableExpr, LuaTableField}; - -#[derive(Debug, Clone)] -pub enum DeclOriginResult { - Single(LuaSemanticDeclId), - Multiple(Vec), -} - -impl DeclOriginResult { - pub fn get_first(&self) -> Option { - match self { - DeclOriginResult::Single(decl) => Some(decl.clone()), - DeclOriginResult::Multiple(decls) => decls.first().cloned(), - } - } - - pub fn get_types(&self, semantic_model: &SemanticModel) -> Vec<(LuaSemanticDeclId, LuaType)> { - let get_type = |decl: &LuaSemanticDeclId| -> Option<(LuaSemanticDeclId, LuaType)> { - match decl { - LuaSemanticDeclId::Member(member_id) => { - let typ = semantic_model.get_type((*member_id).into()); - Some((decl.clone(), typ)) - } - LuaSemanticDeclId::LuaDecl(decl_id) => { - let db = semantic_model.get_db(); - let decl_info = db.get_decl_index().get_decl(decl_id)?; - let typ = if let LuaDeclExtra::Param { - idx, signature_id, .. - } = &decl_info.extra - { - db.get_signature_index() - .get(signature_id)? - .get_param_info_by_id(*idx)? - .type_ref - .clone() - } else { - semantic_model.get_type((*decl_id).into()) - }; - Some((decl.clone(), typ)) - } - _ => None, - } - }; - - match self { - DeclOriginResult::Single(decl) => get_type(decl).into_iter().collect(), - DeclOriginResult::Multiple(decls) => decls.iter().filter_map(get_type).collect(), - } - } -} - -pub fn find_decl_origin_owners( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - decl_id: LuaDeclId, -) -> DeclOriginResult { - let node = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&decl_id.file_id) - .and_then(|tree| { - let root = tree.get_red_root(); - semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id) - .and_then(|decl| decl.get_value_syntax_id()) - .and_then(|syntax_id| syntax_id.to_node_from_root(&root)) - }); - - if let Some(node) = node { - let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default()); - match semantic_decl { - Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owners(compilation, semantic_model, member_id, true) - } - Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { - DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) - } - _ => DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)), - } - } else { - DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) - } -} - -pub fn find_member_origin_owners( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: LuaMemberId, - find_all: bool, -) -> DeclOriginResult { - const MAX_ITERATIONS: usize = 50; - let mut visited_members = HashSet::new(); - - let mut current_owner = resolve_member_owner(compilation, semantic_model, &member_id); - let mut final_owner = current_owner.clone(); - let mut iteration_count = 0; - - while let Some(LuaSemanticDeclId::Member(current_member_id)) = ¤t_owner { - if visited_members.contains(current_member_id) || iteration_count >= MAX_ITERATIONS { - break; - } - - visited_members.insert(*current_member_id); - iteration_count += 1; - - match resolve_member_owner(compilation, semantic_model, current_member_id) { - Some(next_owner) => { - final_owner = Some(next_owner.clone()); - current_owner = Some(next_owner); - } - None => break, - } - } - - if final_owner.is_none() { - final_owner = Some(LuaSemanticDeclId::Member(member_id)); - } - - if !find_all { - return DeclOriginResult::Single( - final_owner.unwrap_or_else(|| LuaSemanticDeclId::Member(member_id)), - ); - } - - // 如果存在多个同名成员, 则返回多个成员 - if let Some(same_named_members) = find_all_same_named_members(semantic_model, &final_owner) - && same_named_members.len() > 1 - { - return DeclOriginResult::Multiple(same_named_members); - } - // 否则返回单个成员 - DeclOriginResult::Single(final_owner.unwrap_or_else(|| LuaSemanticDeclId::Member(member_id))) -} - -pub fn find_member_origin_owner( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: LuaMemberId, -) -> Option { - find_member_origin_owners(compilation, semantic_model, member_id, false).get_first() -} - -pub fn find_all_same_named_members( - semantic_model: &SemanticModel, - final_owner: &Option, -) -> Option> { - let final_owner = final_owner.as_ref()?; - let member_id = match final_owner { - LuaSemanticDeclId::Member(id) => id, - _ => return None, - }; - - let original_member = semantic_model - .get_db() - .get_member_index() - .get_member(member_id)?; - - let target_key = original_member.get_key(); - let current_owner = semantic_model - .get_db() - .get_member_index() - .get_current_owner(member_id)?; - - let all_members = semantic_model - .get_db() - .get_member_index() - .get_members(current_owner)?; - let same_named: Vec = all_members - .iter() - .filter(|member| member.get_key() == target_key) - .map(|member| LuaSemanticDeclId::Member(member.get_id())) - .collect(); - - if same_named.is_empty() { - None - } else { - Some(same_named) - } -} - -fn resolve_member_owner( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: &LuaMemberId, -) -> Option { - // 通常来说, 即使需要跨文件也一般只会跨一个文件, 所有不需要缓存 - let semantic_model = if member_id.file_id == semantic_model.get_file_id() { - semantic_model - } else { - &compilation.get_semantic_model(member_id.file_id)? - }; - - let root = semantic_model.get_root().syntax(); - let current_node = member_id.get_syntax_id().to_node_from_root(root)?; - let result = match member_id.get_syntax_id().get_kind() { - LuaSyntaxKind::TableFieldAssign => { - if LuaTableField::can_cast(current_node.kind().into()) { - let table_field = LuaTableField::cast(current_node.clone())?; - // 如果表是类, 那么通过类型推断获取 owner - if let Some(owner_id) = - resolve_table_field_through_type_inference(semantic_model, &table_field) - { - return Some(owner_id); - } - // 非类, 那么通过右值推断 - let value_expr = table_field.get_value_expr()?; - let value_node = value_expr.get_syntax_id().to_node_from_root(root)?; - semantic_model.find_decl(value_node.into(), SemanticDeclLevel::default()) - } else { - None - } - } - LuaSyntaxKind::IndexExpr => { - let assign_node = current_node.parent()?; - let assign_stat = LuaAssignStat::cast(assign_node)?; - let (vars, exprs) = assign_stat.get_var_and_expr_list(); - - let mut result = None; - for (var, expr) in vars.iter().zip(exprs.iter()) { - if var.syntax().text_range() == current_node.text_range() { - let expr_node = expr.get_syntax_id().to_node_from_root(root)?; - result = - semantic_model.find_decl(expr_node.into(), SemanticDeclLevel::default()); - break; - } - } - result - } - _ => None, - }; - - // 禁止追溯到参数 - match result { - Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { - let decl = semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; - if decl.is_param() { - return None; - } - result - } - _ => result, - } -} - -// 判断`table`是否为类 -fn table_is_class(table_type: &LuaType, depth: usize) -> bool { - if depth > 10 { - return false; - } - match table_type { - LuaType::Ref(_) | LuaType::Def(_) | LuaType::Generic(_) => true, - LuaType::Union(union) => match union.as_ref() { - LuaUnionType::Basic(_) => false, - LuaUnionType::Nullable(t) => table_is_class(t, depth + 1), - LuaUnionType::Multi(ts) => ts.iter().any(|t| table_is_class(t, depth + 1)), - }, - _ => false, - } -} - -fn resolve_table_field_through_type_inference( - semantic_model: &SemanticModel, - table_field: &LuaTableField, -) -> Option { - let parent = table_field.syntax().parent()?; - let table_expr = LuaTableExpr::cast(parent)?; - let table_type = semantic_model.infer_table_should_be(table_expr)?; - - // 必须为类我们才搜索其成员 - if !table_is_class(&table_type, 0) { - return None; - } - - let field_key = table_field.get_field_key()?; - let key = semantic_model.get_member_key(&field_key)?; - let member_infos = semantic_model.get_member_info_with_key(&table_type, key, false)?; - member_infos - .first() - .cloned() - .and_then(|m| m.property_owner_id) -} - -#[allow(unused)] -pub fn replace_semantic_type( - semantic_decls: &mut [(LuaSemanticDeclId, LuaType)], - origin_type: &LuaType, -) { - // `origin_type`不一定包含所有`semantic_decls`中的类型, 实际的推断可能非常复杂, 这里仅是临时方案. - - // 解开`origin_type` - let mut type_vec = Vec::new(); - match origin_type { - LuaType::Union(union) => { - for typ in union.into_vec() { - type_vec.push(typ); - } - } - _ => { - type_vec.push(origin_type.clone()); - } - } - if type_vec.len() != semantic_decls.len() { - return; - } - - // 判断是否存在泛型, 如果有任意类型不匹配我们就认为存在泛型 - let mut has_generic = false; - let type_set: HashSet<_> = type_vec.iter().collect(); - for (_, typ) in semantic_decls.iter() { - if !type_set.contains(&typ) { - has_generic = true; - break; - } - } - if !has_generic { - return; - } - - // 替换`semantic_decls`中的类型 - for (i, (_, typ)) in semantic_decls.iter_mut().enumerate() { - if i < type_vec.len() { - *typ = type_vec[i].clone(); - } - } -} diff --git a/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs b/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs new file mode 100644 index 000000000..1358aba57 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs @@ -0,0 +1,150 @@ +use std::sync::Arc; + +use emmylua_code_analysis::{DbIndex, LuaFunctionType, LuaType, find_callable_overload}; +use emmylua_parser::LuaCallExpr; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext, HoverDeclInfo}; + +use super::{ + define_hover::{HoverFunctionInfo, set_function_info_to_builder}, + extract_function_member, get_function_description, + render::process_function_type, +}; + +pub(super) fn build_function_call_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, + call_expr: &LuaCallExpr, +) -> Option<()> { + let ordered_decls = decl_context.ordered_decl_refs(); + let call_arg_types = infer_call_arg_types(builder, call_expr); + let mut function_infos = Vec::new(); + + let matched_decls = + find_decls_for_call(builder, db, &ordered_decls, &call_arg_types, call_expr); + if matched_decls.is_empty() { + for matched_decl in ordered_decls { + if let Some(info) = + build_unmatched_call_hover_function_info(builder, db, matched_decl, call_expr) + { + function_infos.push(info); + } + } + + return set_function_info_to_builder(builder, &mut function_infos); + } + + for matched_decl in matched_decls { + let info = build_call_hover_function_info(builder, db, matched_decl); + if let Some(info) = info { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn infer_call_arg_types(builder: &HoverBuilder, call_expr: &LuaCallExpr) -> Vec { + let Some(args) = call_expr.get_args_list() else { + return Vec::new(); + }; + let args = args.get_args().collect::>(); + builder + .semantic_model + .infer_expr_list_types(&args, None) + .into_iter() + .map(|(typ, _)| typ) + .collect() +} + +fn build_unmatched_call_hover_function_info( + builder: &mut HoverBuilder, + db: &DbIndex, + matched_decl: &HoverDeclInfo, + call_expr: &LuaCallExpr, +) -> Option { + let match_semantic_decl = matched_decl.id(); + let function_member = extract_function_member(db, match_semantic_decl); + let contents = process_function_type( + builder, + db, + matched_decl.typ(), + match_semantic_decl, + function_member, + Some(call_expr), + )?; + if contents.is_empty() { + return None; + } + + let description = get_function_description(builder, db, match_semantic_decl); + HoverFunctionInfo::from_contents(contents, description) +} + +fn build_call_hover_function_info( + builder: &mut HoverBuilder, + db: &DbIndex, + matched_decl: MatchedCallDecl<'_>, +) -> Option { + let match_semantic_decl = matched_decl.decl.id(); + let function_member = extract_function_member(db, match_semantic_decl); + let call_type = LuaType::DocFunction(matched_decl.func); + + let contents = process_function_type( + builder, + db, + &call_type, + match_semantic_decl, + function_member, + None, + )?; + + let description = get_function_description(builder, db, match_semantic_decl); + HoverFunctionInfo::from_contents(contents, description) +} + +struct MatchedCallDecl<'a> { + decl: &'a HoverDeclInfo, + func: Arc, +} + +fn find_decls_for_call<'a>( + builder: &HoverBuilder, + db: &DbIndex, + ordered_decls: &[&'a HoverDeclInfo], + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, +) -> Vec> { + let mut matched_decls = Vec::new(); + + for decl in ordered_decls.iter().copied() { + if let Some(func) = + find_callable_for_call(builder, db, decl.typ(), call_arg_types, call_expr) + { + matched_decls.push(MatchedCallDecl { decl, func }); + } + } + + matched_decls +} + +fn find_callable_for_call( + builder: &HoverBuilder, + db: &DbIndex, + decl_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, +) -> Option> { + find_callable_overload( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + decl_type, + call_arg_types, + call_expr, + None, + true, + ) + .ok() + .flatten() +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs b/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs new file mode 100644 index 000000000..9d4bbd4c8 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs @@ -0,0 +1,144 @@ +use emmylua_code_analysis::{DbIndex, TypeSubstitutor}; +use emmylua_parser::{LuaAstNode, LuaLocalName, LuaLocalStat}; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext, humanize_types::DescriptionInfo}; + +use super::{ + extract_function_member, generic::index_prefix_substitutor, + generic::instantiate_type_if_needed, get_function_description, render::process_function_type, +}; + +/// Hover 函数信息聚合 +#[derive(Debug, Clone)] +pub(super) struct HoverFunctionInfo { + pub primary: String, + pub overloads: Option>, + pub description: Option, +} + +impl HoverFunctionInfo { + /// 从渲染结果构造 HoverFunctionInfo,消除重复的构造模式 + pub fn from_contents( + contents: Vec, + description: Option, + ) -> Option { + let mut contents = contents.into_iter(); + let primary = contents.next()?; + let overloads = { + let overloads = contents.collect::>(); + (!overloads.is_empty()).then_some(overloads) + }; + Some(Self { + primary, + overloads, + description, + }) + } +} + +pub(super) fn build_function_define_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, +) -> Option<()> { + let mut function_infos = Vec::new(); + let ordered_decls = decl_context.ordered_decl_refs(); + let substitutor = ordered_decls + .iter() + .any(|decl_info| decl_info.typ().contain_tpl()) + .then(|| infer_define_substitutor(builder)) + .flatten(); + + for decl_info in ordered_decls { + let semantic_decl_id = decl_info.id(); + let function_member = extract_function_member(db, semantic_decl_id); + let instantiated_type = substitutor + .as_ref() + .and_then(|substitutor| instantiate_type_if_needed(db, decl_info.typ(), substitutor)); + let typ = instantiated_type + .as_ref() + .unwrap_or_else(|| decl_info.typ()); + + let Some(contents) = + process_function_type(builder, db, typ, semantic_decl_id, function_member, None) + else { + continue; + }; + if contents.is_empty() { + continue; + } + let description = get_function_description(builder, db, semantic_decl_id); + if let Some(info) = HoverFunctionInfo::from_contents(contents, description) { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn infer_define_substitutor(builder: &HoverBuilder) -> Option { + let token = builder.get_trigger_token()?; + let target_local_name = LuaLocalName::cast(token.parent()?)?; + let local_stat = LuaLocalStat::cast(target_local_name.syntax().parent()?)?; + + for (index, name) in local_stat.get_local_name_list().enumerate() { + if target_local_name == name { + let value_expr = local_stat.get_value_exprs().nth(index)?; + return index_prefix_substitutor(builder, &value_expr); + } + } + + None +} + +/// 统一处理文本设置 +pub(super) fn set_function_info_to_builder( + builder: &mut HoverBuilder, + function_infos: &mut Vec, +) -> Option<()> { + // 去重 + function_infos.dedup_by(|a, b| a.primary == b.primary); + if function_infos.is_empty() { + return None; + } + + let main = function_infos.remove(0); + + // 计算 overload 的总数 + let overload_count = main.overloads.as_ref().map_or(0, |o| o.len()) + + function_infos + .iter() + .map(|info| 1 + info.overloads.as_ref().map_or(0, |o| o.len())) + .sum::(); + + let main_primary = if overload_count > 0 { + format!("{} (+{} overloads)", main.primary, overload_count) + } else { + main.primary + }; + + builder.set_type_description(main_primary); + builder.add_description_from_info(main.description); + + // 添加 main 的 overloads + if let Some(overloads) = main.overloads { + for overload in overloads { + builder.add_signature_overload(overload, None); + } + } + + // 添加其余条目 + for type_desc in function_infos.drain(..) { + let comment = type_desc + .description + .and_then(|description| description.description); + builder.add_signature_overload(type_desc.primary, comment); + if let Some(overloads) = type_desc.overloads { + for overload in overloads { + builder.add_signature_overload(overload, None); + } + } + } + + Some(()) +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/generic.rs b/crates/emmylua_ls/src/handlers/hover/function/generic.rs new file mode 100644 index 000000000..048565491 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/generic.rs @@ -0,0 +1,88 @@ +use emmylua_code_analysis::{ + DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, instantiate_type_generic, +}; +use emmylua_parser::LuaExpr; + +use crate::handlers::hover::HoverBuilder; + +pub(super) fn instantiate_type_if_needed( + db: &DbIndex, + typ: &LuaType, + substitutor: &TypeSubstitutor, +) -> Option { + typ.contain_tpl() + .then(|| instantiate_type_generic(db, typ, substitutor)) +} + +pub(super) fn index_prefix_substitutor( + builder: &HoverBuilder, + expr: &LuaExpr, +) -> Option { + let LuaExpr::IndexExpr(index_expr) = expr else { + return None; + }; + let prefix_type = builder + .semantic_model + .infer_expr(index_expr.get_prefix_expr()?) + .ok()?; + match prefix_type { + LuaType::Generic(generic) => Some(TypeSubstitutor::from_type_array( + generic.get_params().clone(), + )), + _ => None, + } +} + +pub(super) fn owner_type_substitutor( + db: &DbIndex, + typ: &LuaType, + owner_type_id: &LuaTypeDeclId, +) -> Option { + match typ { + LuaType::Generic(generic) => { + if generic.get_base_type_id_ref() == owner_type_id { + Some(TypeSubstitutor::from_type_array( + generic.get_params().clone(), + )) + } else { + None + } + } + LuaType::Ref(id) | LuaType::Def(id) => { + if id == owner_type_id { + unknown_type_substitutor(db, owner_type_id) + } else { + None + } + } + LuaType::Union(union) => { + let mut substitutor = None; + for typ in union.into_vec() { + let Some(generic_substitutor) = owner_type_substitutor(db, &typ, owner_type_id) + else { + continue; + }; + if substitutor.is_some() { + return None; + } + substitutor = Some(generic_substitutor); + } + substitutor + } + _ => None, + } +} + +pub(super) fn unknown_type_substitutor( + db: &DbIndex, + owner_type_id: &LuaTypeDeclId, +) -> Option { + let generic_params = db.get_type_index().get_generic_params(owner_type_id)?; + if generic_params.is_empty() { + return None; + } + Some(TypeSubstitutor::from_type_array(vec![ + LuaType::Unknown; + generic_params.len() + ])) +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..cf127bb98 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -1,554 +1,94 @@ -use std::{collections::HashSet, sync::Arc, vec}; +mod call_hover; +mod define_hover; +mod generic; +mod render; +mod table_field; use emmylua_code_analysis::{ - AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, - LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, - instantiate_func_generic, try_extract_signature_id_from_field, + DbIndex, LuaMember, LuaSemanticDeclId, LuaType, infer_table_should_be, + try_extract_signature_id_from_field, }; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTableExpr, LuaTableField}; use crate::handlers::hover::{ - HoverBuilder, - humanize_types::{ - DescriptionInfo, extract_description_from_property_owner, extract_owner_name_from_element, - extract_parent_type_from_element, hover_humanize_type, - }, - infer_prefix_global_name, + HoverBuilder, HoverDeclContext, + humanize_types::{DescriptionInfo, extract_description_from_property_owner}, }; -pub fn build_function_hover( - builder: &mut HoverBuilder, - db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], -) -> Option<()> { - let (function_name, is_local) = { - let (semantic_decl, _) = semantic_decls.first()?; - match semantic_decl { - LuaSemanticDeclId::LuaDecl(id) => { - let decl = db.get_decl_index().get_decl(id)?; - (decl.get_name().to_string(), decl.is_local()) - } - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(id)?; - (member.get_key().to_path(), false) - } - _ => { - return None; - } - } - }; +use call_hover::build_function_call_hover; +use define_hover::build_function_define_hover; +use table_field::build_table_field_hover; - // 如果是函数调用, 那么我们需要根据上下文实例化出实际类型 - if let Some(call_expr) = builder.get_call_expr() { - build_function_call_hover( - builder, - db, - semantic_decls, - &call_expr, - &function_name, - is_local, - ); - } else { - build_function_define_hover(builder, db, semantic_decls, &function_name, is_local); - } - - Some(()) -} - -fn build_function_call_hover( +pub(crate) fn build_function_hover( builder: &mut HoverBuilder, db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], - call_expr: &emmylua_parser::LuaCallExpr, - function_name: &str, - is_local: bool, + decl_context: &HoverDeclContext, ) -> Option<()> { - let final_type = infer_call_expr_func( - db, - &mut builder.semantic_model.get_cache().borrow_mut(), - call_expr.clone(), - semantic_decls.last()?.1.clone(), - &InferGuard::new(), - None, - ) - .ok()?; - - // 根据推断出来的类型确定哪个 semantic_decl 是匹配的 - let mut matched_decl = semantic_decls.last()?; - for semantic_decl in semantic_decls.iter() { - let (_, typ) = semantic_decl; - if let LuaType::DocFunction(f) = typ { - if f == &final_type { - matched_decl = semantic_decl; - break; - } - } - } - let (match_semantic_decl, match_typ) = matched_decl; - - let function_member = match match_semantic_decl { - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; - Some(member) + if let Some(token) = builder.get_trigger_token() { + if let Some(call_expr) = get_call_expr(&token) { + return build_function_call_hover(builder, db, decl_context, &call_expr); } - _ => None, - }; - let is_field = function_member_is_field(db, semantic_decls); - let contents = if let LuaType::Signature(signature_id) = match_typ { - let signature = db.get_signature_index().get(signature_id)?; - let base_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - ); - let instantiated_signature = instantiate_func_generic( - db, - &mut builder.semantic_model.get_cache().borrow_mut(), - &base_function, - call_expr.clone(), - ) - .ok()?; - - if !signature.return_overloads.is_empty() - && final_type.get_async_state() == instantiated_signature.get_async_state() - && final_type.is_colon_define() == instantiated_signature.is_colon_define() - && final_type.is_variadic() == instantiated_signature.is_variadic() - && final_type.get_params() == instantiated_signature.get_params() - { - let return_overloads = - instantiate_call_return_overloads(builder, db, call_expr, signature); - let ret_detail = build_function_return_overload_rows(builder, &return_overloads); - vec![hover_doc_function_type( - builder, - db, - final_type.as_ref(), - function_member, - function_name, - is_local, - is_field, - Vec::new(), - Some(ret_detail), - )] - } else { - process_function_type( - builder, - db, - &LuaType::DocFunction(final_type), - function_member, - function_name, - is_local, - is_field, - )? + if let Some(parent_table_type) = infer_table_field_parent_type(builder, db, &token) { + return build_table_field_hover(builder, db, decl_context, &parent_table_type); } - } else { - process_function_type( - builder, - db, - &LuaType::DocFunction(final_type), - function_member, - function_name, - is_local, - is_field, - )? - }; - let description = get_function_description(builder, db, &match_semantic_decl); - builder.set_type_description(contents.first()?.clone()); - builder.add_description_from_info(description); + } - Some(()) + build_function_define_hover(builder, db, decl_context) } -#[derive(Debug, Clone)] -struct HoverFunctionInfo { - primary: String, - overloads: Option>, - description: Option, +pub(crate) fn has_function_candidate(decl_context: &HoverDeclContext) -> bool { + is_function(decl_context.current_decl().typ()) + || decl_context + .origin_decls() + .iter() + .any(|decl_info| is_function(decl_info.typ())) } -#[allow(unused)] -fn build_function_define_hover( - builder: &mut HoverBuilder, - db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], - function_name: &str, - is_local: bool, -) -> Option<()> { - let is_field = function_member_is_field(db, semantic_decls); - let mut function_infos = Vec::new(); - for (semantic_decl_id, typ) in semantic_decls { - let mut typ = typ.clone(); - let function_member = match semantic_decl_id { - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; - Some(member) - } - _ => None, - }; - - if let Some(substitutor) = &builder.substitutor { - if let Some(lua_func) = hover_instantiate_function_type(db, &typ, substitutor) { - typ = LuaType::DocFunction(lua_func); - } - } - - let Some(contents) = process_function_type( - builder, - db, - &typ, - function_member, - function_name, - is_local, - is_field, - ) else { - continue; - }; - if contents.is_empty() { - continue; - } - let description = get_function_description(builder, db, &semantic_decl_id); - function_infos.push(HoverFunctionInfo { - primary: contents.first()?.clone(), - overloads: if contents.len() > 1 { - Some(contents[1..].to_vec()) - } else { - None - }, - description, - }); - } - - // 去重, 这是必须的 - function_infos.dedup_by_key(|info| info.primary.clone()); - - // 需要显示重载的情况 - match function_infos.len() { - 0 => { - return None; - } - 1 => { - builder.set_type_description(function_infos[0].primary.clone()); - builder.add_description_from_info(function_infos[0].description.clone()); - } - _ => { - let main_type = function_infos.pop()?; - builder.set_type_description(main_type.primary.clone()); - builder.add_description_from_info(main_type.description.clone()); - - for type_desc in function_infos { - builder.add_signature_overload(type_desc.primary.clone()); - if let Some(overloads) = &type_desc.overloads { - for overload in overloads { - builder.add_signature_overload(overload.clone()); - } - } - builder.add_description_from_info(type_desc.description.clone()); - } - } +fn get_call_expr(token: &LuaSyntaxToken) -> Option { + let token_start = token.text_range().start(); + let call_expr = token.parent()?.ancestors().find_map(LuaCallExpr::cast)?; + let prefix_expr = call_expr.get_prefix_expr()?; + if prefix_expr.syntax().text_range().contains(token_start) { + Some(call_expr) + } else { + None } - Some(()) } -fn process_function_type( - builder: &mut HoverBuilder, - db: &DbIndex, - typ: &LuaType, - function_member: Option<&LuaMember>, - function_name: &str, - is_local: bool, - is_field: bool, -) -> Option> { - match typ { - LuaType::DocFunction(lua_func) => { - let content = hover_doc_function_type( - builder, - db, - lua_func, - function_member, - &function_name, - is_local, - is_field, - convert_function_return_to_docs(lua_func), - None, - ); - Some(vec![content]) - } - LuaType::Signature(signature_id) => { - let signature = db.get_signature_index().get(&signature_id)?; - let mut new_overloads = signature.overloads.clone(); - let fake_doc_function = Arc::new(LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - )); - new_overloads.insert(0, fake_doc_function.clone()); - let mut contents = Vec::with_capacity(new_overloads.len()); - for (i, overload) in new_overloads.iter().enumerate() { - let content = if i == 0 && !signature.return_overloads.is_empty() { - let ret_detail = - build_function_return_overload_rows(builder, &signature.return_overloads); - hover_doc_function_type( - builder, - db, - overload, - function_member, - function_name, - is_local, - is_field, - Vec::new(), - Some(ret_detail), - ) - } else { - hover_doc_function_type( - builder, - db, - overload, - function_member, - function_name, - is_local, - is_field, - if i == 0 { - if signature.return_docs.is_empty() { - convert_function_return_to_docs(fake_doc_function.as_ref()) - } else { - signature.return_docs.clone() - } - } else { - convert_function_return_to_docs(overload) - }, - None, - ) - }; - contents.push(content); - } - Some(contents) - } - LuaType::Union(union) => { - let mut contents = Vec::new(); - for typ in union.into_vec() { - if let Some(content) = process_function_type( - builder, - db, - &typ, - function_member, - function_name, - is_local, - is_field, - ) { - contents.extend(content); - } - } - Some(contents) - } - _ => None, - } +fn get_table_field_expr(token: &LuaSyntaxToken) -> Option { + token + .parent() + .and_then(LuaTableField::cast)? + .get_parent::() } -fn hover_doc_function_type( +fn infer_table_field_parent_type( builder: &mut HoverBuilder, db: &DbIndex, - func: &LuaFunctionType, - owner_member: Option<&LuaMember>, - func_name: &str, - is_local: bool, - is_field: bool, /* 是否为类字段 */ - return_docs: Vec, /* 返回值以此为准 */ - ret_detail: Option, -) -> String { - let async_label = match func.get_async_state() { - AsyncState::Async => "async ", - AsyncState::Sync => "sync ", - _ => "", - }; - let mut is_method = func.is_colon_define(); - let mut type_label = if is_local && owner_member.is_none() { - "local function " - } else { - "function " - }; - - // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 - let full_name = if let Some(owner_member) = owner_member { - if is_field { - type_label = "(field) "; - } - - let member_key = owner_member.get_key().to_path(); - let mut name = String::with_capacity(member_key.len() + 16); - - let mut push_typed_owner_prefix = |prefix: &str, type_decl_id| { - name.push_str(prefix); - let owner_ty = LuaType::Ref(type_decl_id); - is_method = func.is_method(builder.semantic_model, Some(&owner_ty)); - if is_method { - type_label = "(method) "; - } - name.push(if is_method { ':' } else { '.' }); - }; - - let parent_owner = db - .get_member_index() - .get_current_owner(&owner_member.get_id()); - if let Some(parent_owner) = parent_owner { - match parent_owner { - LuaMemberOwner::Type(type_decl_id) => { - let prefix = infer_prefix_global_name(builder.semantic_model, owner_member) - .unwrap_or_else(|| type_decl_id.get_simple_name()); - push_typed_owner_prefix(prefix, type_decl_id.clone()); - } - LuaMemberOwner::Element(element_id) => { - if let Some(LuaType::Ref(type_decl_id) | LuaType::Def(type_decl_id)) = - extract_parent_type_from_element(builder.semantic_model, element_id) - { - push_typed_owner_prefix( - type_decl_id.get_simple_name(), - type_decl_id.clone(), - ); - } else if let Some(owner_name) = - extract_owner_name_from_element(builder.semantic_model, element_id) - { - name.push_str(&owner_name); - if is_method { - type_label = "(method) "; - } - name.push(if is_method { ':' } else { '.' }); - } - } - _ => {} - } - } - - name.push_str(&member_key); - name - } else { - func_name.to_string() - }; - - let is_vararg = func.is_variadic(); - let last_idx = func.get_params().len().saturating_sub(1); - - let params = func - .get_params() - .iter() - .enumerate() - .map(|(index, param)| { - let mut name = param.0.clone(); - if is_vararg && index == last_idx && name != "..." { - name = format!("...{}", name); - } - if index == 0 && is_method && !func.is_colon_define() { - "".to_string() - } else if let Some(ty) = ¶m.1 { - format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) - } else { - name.to_string() - } - }) - .filter(|s| !s.is_empty()) - .collect::>(); - - let ret_detail = ret_detail.unwrap_or_else(|| build_function_returns(builder, return_docs)); - format_function_type( - type_label, - async_label, - full_name, - params.join(", "), - ret_detail, + token: &LuaSyntaxToken, +) -> Option { + let table_expr = get_table_field_expr(token)?; + infer_table_should_be( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + table_expr, ) + .ok() } -fn instantiate_call_return_overloads( - builder: &HoverBuilder, - db: &DbIndex, - call_expr: &emmylua_parser::LuaCallExpr, - signature: &LuaSignature, -) -> Vec { - let mut cache = builder.semantic_model.get_cache().borrow_mut(); - - signature - .return_overloads - .iter() - .map(|row| { - let row_return_type = match row.type_refs.len() { - 0 => LuaType::Nil, - 1 => row.type_refs[0].clone(), - _ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()), - }; - let row_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - row_return_type, - ); - let instantiated_row = - instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) - .ok() - .map(|func| match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Multi(types) => types.clone(), - VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())], - }, - typ => vec![typ.clone()], - }) - .unwrap_or_else(|| row.type_refs.clone()); - - LuaDocReturnOverloadInfo { - type_refs: instantiated_row, - description: row.description.clone(), - } - }) - .collect() -} - -fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { - match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Base(base) => vec![LuaDocReturnInfo { - name: None, - type_ref: base.clone(), - description: None, - attributes: None, - }], - VariadicType::Multi(types) => types - .iter() - .map(|ty| LuaDocReturnInfo { - name: None, - type_ref: ty.clone(), - description: None, - attributes: None, - }) - .collect(), - }, - _ => vec![LuaDocReturnInfo { - name: None, - type_ref: func.get_ret().clone(), - description: None, - attributes: None, - }], +/// 从 semantic_decl 中提取 function_member +pub(super) fn extract_function_member<'a>( + db: &'a DbIndex, + semantic_decl: &LuaSemanticDeclId, +) -> Option<&'a LuaMember> { + match semantic_decl { + LuaSemanticDeclId::Member(id) => db.get_member_index().get_member(id), + _ => None, } } -fn format_function_type( - type_label: &str, - async_label: &str, - full_name: String, - params: String, - rets: String, -) -> String { - let prefix = if type_label.starts_with("function") { - format!("{}{}", async_label, type_label) - } else { - format!("{}{}", type_label, async_label) - }; - format!("{}{}({}){}", prefix, full_name, params, rets) -} - -fn get_function_description( +pub(super) fn get_function_description( builder: &mut HoverBuilder, db: &DbIndex, semantic_decl_id: &LuaSemanticDeclId, @@ -574,146 +114,7 @@ fn get_function_description( description } -fn build_function_returns( - builder: &mut HoverBuilder, - return_docs: Vec, -) -> String { - let mut result = String::new(); - // 如果不是补全且存在名称, 我们需要多行显示 - let has_multiline = !builder.is_completion - && return_docs - .iter() - .any(|return_info| return_info.name.is_some()); - - for (i, return_info) in return_docs.iter().enumerate() { - if i == 0 && return_info.type_ref.is_nil() { - continue; - } - let type_text = build_function_return_type(builder, return_info, i); - - if has_multiline { - let prefix = if i == 0 { - result.push('\n'); - "-> ".to_string() - } else { - format!("{}. ", i + 1) - }; - let name = return_info.name.clone().unwrap_or_default(); - - result.push_str(&format!( - " {}{}{}\n", - prefix, - if !name.is_empty() { - format!("{}: ", name) - } else { - "".to_string() - }, - type_text, - )); - } else if i == 0 { - result.push_str(&format!(" -> {}", type_text)); - } else { - result.push_str(&format!(", {}", type_text)); - } - } - - result -} - -fn build_function_return_overload_rows( - builder: &mut HoverBuilder, - return_overloads: &[LuaDocReturnOverloadInfo], -) -> String { - let mut result = String::new(); - - for (row_idx, row) in return_overloads.iter().enumerate() { - if row.type_refs.is_empty() { - continue; - } - - let row_text = row - .type_refs - .iter() - .enumerate() - .map(|(i, typ)| build_return_type_text(builder, typ, i)) - .collect::>() - .join(", "); - - if row_idx == 0 { - result.push('\n'); - } - result.push_str(&format!(" -> {}\n", row_text)); - } - - result -} - -fn build_function_return_type( - builder: &mut HoverBuilder, - ret_info: &LuaDocReturnInfo, - i: usize, -) -> String { - build_return_type_text(builder, &ret_info.type_ref, i) -} - -fn build_return_type_text(builder: &mut HoverBuilder, typ: &LuaType, i: usize) -> String { - let type_expansion_count = builder.get_type_expansion_count(); - // 在这个过程中可能会设置`type_expansion` - let type_text = hover_humanize_type(builder, typ, Some(RenderLevel::Simple)); - if builder.get_type_expansion_count() > type_expansion_count { - // 重新设置`type_expansion` - if let Some(pop_type_expansion) = - builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) - { - let mut new_type_expansion = format!("return #{}", i + 1); - let mut seen = HashSet::new(); - for type_expansion in pop_type_expansion { - for line in type_expansion.lines().skip(1) { - if seen.insert(line.to_string()) { - new_type_expansion.push('\n'); - new_type_expansion.push_str(line); - } - } - } - builder.add_type_expansion(new_type_expansion); - } - }; - type_text -} - -// 函数是否为类字段, 任意一个为类字段我们都认为全部为类字段 -fn function_member_is_field(db: &DbIndex, semantic_decls: &[(LuaSemanticDeclId, LuaType)]) -> bool { - semantic_decls.iter().any(|(semantic_decl, _)| { - if let LuaSemanticDeclId::Member(id) = semantic_decl { - let member = db.get_member_index().get_member(id); - member.is_some() && member.unwrap().is_field() - } else { - false - } - }) -} - -fn hover_instantiate_function_type( - db: &DbIndex, - typ: &LuaType, - substitutor: &TypeSubstitutor, -) -> Option> { - if !typ.contain_tpl() { - return None; - } - match typ { - LuaType::DocFunction(f) => { - if let LuaType::DocFunction(f) = instantiate_doc_function(db, f, substitutor) { - Some(f) - } else { - None - } - } - _ => None, - } -} - -pub fn is_function(typ: &LuaType) -> bool { +pub(crate) fn is_function(typ: &LuaType) -> bool { typ.is_function() || match &typ { LuaType::Union(union) => union diff --git a/crates/emmylua_ls/src/handlers/hover/function/render.rs b/crates/emmylua_ls/src/handlers/hover/function/render.rs new file mode 100644 index 000000000..c57cbaf87 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/render.rs @@ -0,0 +1,404 @@ +use std::{collections::HashSet, fmt::Write, sync::Arc}; + +use emmylua_code_analysis::{ + AsyncState, DbIndex, LuaDocReturnInfo, LuaFunctionType, LuaMember, LuaMemberOwner, + LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, VariadicType, humanize_type, + infer_call_generic, +}; +use emmylua_parser::LuaCallExpr; + +use crate::handlers::hover::{ + HoverBuilder, + humanize_types::{ + extract_owner_name_from_element, extract_parent_type_from_element, hover_humanize_type, + }, + infer_prefix_global_name, +}; + +/// 函数签名渲染上下文,封装 `hover_doc_function_type` 所需的全部参数 +pub(super) struct FunctionRenderContext<'a> { + pub func: &'a LuaFunctionType, + pub semantic_decl: &'a LuaSemanticDeclId, + pub owner_member: Option<&'a LuaMember>, + pub return_docs: Vec, +} + +/// 根据函数类型分派渲染 +pub(super) fn process_function_type( + builder: &mut HoverBuilder, + db: &DbIndex, + typ: &LuaType, + semantic_decl: &LuaSemanticDeclId, + function_member: Option<&LuaMember>, + call_expr: Option<&LuaCallExpr>, +) -> Option> { + match typ { + LuaType::DocFunction(lua_func) => { + let lua_func = instantiate_function_for_call(builder, db, lua_func, call_expr); + let ctx = FunctionRenderContext { + func: lua_func.as_ref(), + semantic_decl, + owner_member: function_member, + return_docs: convert_function_return_to_docs(lua_func.as_ref()), + }; + let content = render_function(builder, db, ctx)?; + Some(vec![content]) + } + LuaType::Signature(signature_id) => { + let signature = db.get_signature_index().get(&signature_id)?; + let fake_doc_function = signature.to_doc_func_type(); + let mut contents = Vec::with_capacity(signature.overloads.len() + 1); + for (i, overload) in std::iter::once(&fake_doc_function) + .chain(signature.overloads.iter()) + .enumerate() + { + let overload = instantiate_function_for_call(builder, db, overload, call_expr); + let return_docs = signature_return_docs(signature, i, overload.as_ref()); + + let ctx = FunctionRenderContext { + func: overload.as_ref(), + semantic_decl, + owner_member: function_member, + return_docs, + }; + contents.push(render_function(builder, db, ctx)?); + } + Some(contents) + } + LuaType::Union(union) => { + let mut contents = Vec::new(); + for typ in union.into_vec() { + if let Some(content) = process_function_type( + builder, + db, + &typ, + semantic_decl, + function_member, + call_expr, + ) { + contents.extend(content); + } + } + Some(contents) + } + _ => None, + } +} + +fn instantiate_function_for_call( + builder: &HoverBuilder, + db: &DbIndex, + func: &Arc, + call_expr: Option<&LuaCallExpr>, +) -> Arc { + let Some(call_expr) = call_expr else { + return func.clone(); + }; + if !func.contain_tpl() { + return func.clone(); + } + + infer_call_generic( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + func.as_ref(), + call_expr.clone(), + ) + .map(Arc::new) + .unwrap_or_else(|_| func.clone()) +} + +fn signature_return_docs( + signature: &LuaSignature, + index: usize, + func: &LuaFunctionType, +) -> Vec { + let mut return_docs = convert_function_return_to_docs(func); + if index == 0 && !signature.return_docs.is_empty() { + for (return_doc, declared_doc) in return_docs.iter_mut().zip(&signature.return_docs) { + return_doc.name = declared_doc.name.clone(); + return_doc.description = declared_doc.description.clone(); + return_doc.attributes = declared_doc.attributes.clone(); + } + } + + return_docs +} + +/// 渲染单个函数签名的完整 hover 文本 +pub(super) fn render_function( + builder: &mut HoverBuilder, + db: &DbIndex, + ctx: FunctionRenderContext, +) -> Option { + let FunctionRenderContext { + func, + semantic_decl, + owner_member, + return_docs, + } = ctx; + + let async_label = match func.get_async_state() { + AsyncState::Async => "async ", + AsyncState::Sync => "sync ", + _ => "", + }; + let mut is_method = func.is_colon_define(); + let mut type_label = if owner_member.is_none() && semantic_decl_is_local(db, semantic_decl) { + "local function " + } else { + "function " + }; + + // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 + let full_name = if let Some(owner_member) = owner_member { + if semantic_decl_is_field(db, semantic_decl, owner_member) { + type_label = "(field) "; + } + + let member_key = owner_member.get_key().to_path(); + let mut name = String::with_capacity(member_key.len() + 16); + + let mut push_typed_owner_prefix = |prefix: &str, type_decl_id| { + name.push_str(prefix); + let owner_ty = LuaType::Ref(type_decl_id); + is_method = func.is_method(builder.semantic_model, Some(&owner_ty)); + if is_method { + type_label = "(method) "; + } + name.push(if is_method { ':' } else { '.' }); + }; + + let parent_owner = db + .get_member_index() + .get_current_owner(&owner_member.get_id()); + if let Some(parent_owner) = parent_owner { + match parent_owner { + LuaMemberOwner::Type(type_decl_id) => { + let prefix = infer_prefix_global_name(builder.semantic_model, owner_member) + .unwrap_or_else(|| type_decl_id.get_simple_name()); + push_typed_owner_prefix(prefix, type_decl_id.clone()); + } + LuaMemberOwner::Element(element_id) => { + if let Some(LuaType::Ref(type_decl_id) | LuaType::Def(type_decl_id)) = + extract_parent_type_from_element(builder.semantic_model, element_id) + { + push_typed_owner_prefix( + type_decl_id.get_simple_name(), + type_decl_id.clone(), + ); + } else if let Some(owner_name) = + extract_owner_name_from_element(builder.semantic_model, element_id) + { + name.push_str(&owner_name); + if is_method { + type_label = "(method) "; + } + name.push(if is_method { ':' } else { '.' }); + } + } + _ => {} + } + } + + name.push_str(&member_key); + name + } else { + semantic_decl_function_name(db, semantic_decl)? + }; + + let is_vararg = func.is_variadic(); + let last_idx = func.get_params().len().saturating_sub(1); + + let params = func + .get_params() + .iter() + .enumerate() + .map(|(index, param)| { + let mut name = param.0.clone(); + if is_vararg && index == last_idx && name != "..." { + name = format!("...{}", name); + } + if index == 0 && is_method && !func.is_colon_define() { + "".to_string() + } else if let Some(ty) = ¶m.1 { + format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) + } else { + name.to_string() + } + }) + .filter(|s| !s.is_empty()) + .collect::>(); + + let ret_detail = build_function_returns(builder, return_docs); + Some(format_function_type( + type_label, + async_label, + full_name, + params.join(", "), + ret_detail, + )) +} + +fn semantic_decl_is_field( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, + owner_member: &LuaMember, +) -> bool { + if let LuaSemanticDeclId::Member(member_id) = semantic_decl { + if db + .get_member_index() + .get_member(member_id) + .is_some_and(|member| member.is_field()) + { + return true; + } + } + + let member_index = db.get_member_index(); + let Some(owner) = member_index.get_current_owner(&owner_member.get_id()) else { + return false; + }; + member_index.get_members(owner).is_some_and(|members| { + members + .iter() + .any(|member| member.get_key() == owner_member.get_key() && member.is_field()) + }) +} + +fn semantic_decl_is_local(db: &DbIndex, semantic_decl: &LuaSemanticDeclId) -> bool { + match semantic_decl { + LuaSemanticDeclId::LuaDecl(decl_id) => db + .get_decl_index() + .get_decl(decl_id) + .is_some_and(|decl| decl.is_local()), + _ => false, + } +} + +fn semantic_decl_function_name(db: &DbIndex, semantic_decl: &LuaSemanticDeclId) -> Option { + match semantic_decl { + LuaSemanticDeclId::LuaDecl(decl_id) => Some( + db.get_decl_index() + .get_decl(decl_id)? + .get_name() + .to_string(), + ), + LuaSemanticDeclId::Member(member_id) => Some( + db.get_member_index() + .get_member(member_id)? + .get_key() + .to_path(), + ), + _ => None, + } +} + +fn format_function_type( + type_label: &str, + async_label: &str, + full_name: String, + params: String, + rets: String, +) -> String { + let prefix = if type_label.starts_with("function") { + format!("{}{}", async_label, type_label) + } else { + format!("{}{}", type_label, async_label) + }; + format!("{}{}({}){}", prefix, full_name, params, rets) +} + +pub(super) fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { + match func.get_ret() { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Base(base) => vec![LuaDocReturnInfo { + name: None, + type_ref: base.clone(), + description: None, + attributes: None, + }], + VariadicType::Multi(types) => types + .iter() + .map(|ty| LuaDocReturnInfo { + name: None, + type_ref: ty.clone(), + description: None, + attributes: None, + }) + .collect(), + }, + _ => vec![LuaDocReturnInfo { + name: None, + type_ref: func.get_ret().clone(), + description: None, + attributes: None, + }], + } +} + +fn build_function_returns( + builder: &mut HoverBuilder, + return_docs: Vec, +) -> String { + let mut result = String::new(); + // 如果不是补全且存在名称, 我们需要多行显示 + let has_multiline = !builder.is_completion + && return_docs + .iter() + .any(|return_info| return_info.name.is_some()); + + for (i, return_info) in return_docs.iter().enumerate() { + if i == 0 && return_info.type_ref.is_nil() { + continue; + } + let type_text = build_return_type_text(builder, &return_info.type_ref, i); + + if has_multiline { + if i == 0 { + result.push('\n'); + result.push_str(" -> "); + } else { + let _ = write!(result, " {}. ", i + 1); + } + if let Some(name) = return_info.name.as_deref().filter(|name| !name.is_empty()) { + let _ = write!(result, "{}: ", name); + } + result.push_str(&type_text); + result.push('\n'); + } else if i == 0 { + result.push_str(" -> "); + result.push_str(&type_text); + } else { + result.push_str(", "); + result.push_str(&type_text); + } + } + + result +} + +fn build_return_type_text(builder: &mut HoverBuilder, typ: &LuaType, i: usize) -> String { + let type_expansion_count = builder.get_type_expansion_count(); + // 在这个过程中可能会设置`type_expansion` + let type_text = hover_humanize_type(builder, typ, Some(RenderLevel::Simple)); + if builder.get_type_expansion_count() > type_expansion_count { + // 重新设置`type_expansion` + if let Some(pop_type_expansion) = + builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) + { + let mut new_type_expansion = format!("return #{}", i + 1); + let mut seen = HashSet::new(); + for type_expansion in pop_type_expansion { + for line in type_expansion.lines().skip(1) { + if seen.insert(line.to_string()) { + new_type_expansion.push('\n'); + new_type_expansion.push_str(line); + } + } + } + builder.add_type_expansion(new_type_expansion); + } + }; + type_text +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/table_field.rs b/crates/emmylua_ls/src/handlers/hover/function/table_field.rs new file mode 100644 index 000000000..67c543fd2 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/table_field.rs @@ -0,0 +1,104 @@ +use std::collections::HashMap; + +use emmylua_code_analysis::{DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, TypeSubstitutor}; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext}; + +use super::{ + define_hover::{HoverFunctionInfo, set_function_info_to_builder}, + extract_function_member, + generic::{instantiate_type_if_needed, owner_type_substitutor, unknown_type_substitutor}, + get_function_description, + render::process_function_type, +}; + +type OwnerSubstitutorCache = HashMap>; + +pub(super) fn build_table_field_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, + parent_table_type: &LuaType, +) -> Option<()> { + let mut function_infos = Vec::new(); + let mut substitutor_cache = OwnerSubstitutorCache::new(); + for decl_info in decl_context.ordered_decl_refs() { + let semantic_decl_id = decl_info.id(); + let typ = resolve_semantic_decl_type( + db, + semantic_decl_id, + decl_info.typ(), + parent_table_type, + &mut substitutor_cache, + ); + let function_member = extract_function_member(db, semantic_decl_id); + + let Some(contents) = + process_function_type(builder, db, &typ, semantic_decl_id, function_member, None) + else { + continue; + }; + if contents.is_empty() { + continue; + } + + let description = get_function_description(builder, db, semantic_decl_id); + if let Some(info) = HoverFunctionInfo::from_contents(contents, description) { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn resolve_semantic_decl_type( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, + typ: &LuaType, + parent_table_type: &LuaType, + substitutor_cache: &mut OwnerSubstitutorCache, +) -> LuaType { + if !typ.contain_tpl() { + return typ.clone(); + } + + let Some(owner_type_id) = semantic_decl_owner_type_id(db, semantic_decl) else { + return typ.clone(); + }; + let substitutor = + cached_substitutor_for_owner(db, parent_table_type, owner_type_id, substitutor_cache); + + substitutor + .and_then(|substitutor| instantiate_type_if_needed(db, typ, &substitutor)) + .unwrap_or_else(|| typ.clone()) +} + +fn cached_substitutor_for_owner( + db: &DbIndex, + parent_table_type: &LuaType, + owner_type_id: LuaTypeDeclId, + substitutor_cache: &mut OwnerSubstitutorCache, +) -> Option { + if let Some(substitutor) = substitutor_cache.get(&owner_type_id) { + return substitutor.clone(); + } + + let substitutor = owner_type_substitutor(db, parent_table_type, &owner_type_id) + .or_else(|| unknown_type_substitutor(db, &owner_type_id)); + substitutor_cache.insert(owner_type_id, substitutor.clone()); + substitutor +} + +fn semantic_decl_owner_type_id( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, +) -> Option { + match semantic_decl { + LuaSemanticDeclId::Member(id) => db + .get_member_index() + .get_current_owner(id)? + .get_type_id() + .cloned(), + _ => None, + } +} diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index d4cd08773..7b3f639a9 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -1,10 +1,7 @@ use emmylua_code_analysis::{ - GenericTplId, LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, - RenderLevel, SemanticModel, TypeSubstitutor, -}; -use emmylua_parser::{ - LuaAstNode, LuaCallExpr, LuaExpr, LuaLocalName, LuaLocalStat, LuaSyntaxKind, LuaSyntaxToken, + LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, RenderLevel, SemanticModel, }; +use emmylua_parser::LuaSyntaxToken; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; use crate::handlers::hover::humanize_types::{ @@ -19,8 +16,8 @@ pub struct HoverBuilder<'a> { pub primary: MarkedString, /// Full path of the class pub location_path: Option, - /// Function overload signatures, with the first being the primary overload - pub signature_overload: Option>, + /// Function overload signatures + pub signature_overload: Option>, /// Annotation descriptions, including function parameters and return values pub annotation_description: Vec, /// 一些类型的完整追加显示, 通常是 @alias @@ -30,17 +27,13 @@ pub struct HoverBuilder<'a> { trigger_token: Option, pub semantic_model: &'a SemanticModel<'a>, - pub compilation: &'a LuaCompilation, pub detail_render_level: RenderLevel, pub is_completion: bool, - // 默认的泛型替换器 - pub substitutor: Option, } impl<'a> HoverBuilder<'a> { pub fn new( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, token: Option, is_completion: bool, @@ -52,14 +45,7 @@ impl<'a> HoverBuilder<'a> { RenderLevel::Detailed }; - let substitutor = if let Some(token) = token.clone() { - infer_substitutor_base_type(semantic_model, token) - } else { - None - }; - Self { - compilation, semantic_model, primary: MarkedString::String("".to_string()), location_path: None, @@ -70,7 +56,6 @@ impl<'a> HoverBuilder<'a> { type_expansion: None, tag_content: None, detail_render_level, - substitutor, } } @@ -98,7 +83,7 @@ impl<'a> HoverBuilder<'a> { } } - pub fn add_signature_overload(&mut self, signature_overload: String) { + pub fn add_signature_overload(&mut self, signature_overload: String, comment: Option) { if signature_overload.is_empty() { return; } @@ -108,10 +93,7 @@ impl<'a> HoverBuilder<'a> { self.signature_overload .as_mut() .unwrap() - .push(MarkedString::from_language_code( - "lua".to_string(), - signature_overload, - )); + .push(HoverSignatureOverload::new(signature_overload, comment)); } pub fn add_type_expansion(&mut self, type_expansion: String) { @@ -237,15 +219,8 @@ impl<'a> HoverBuilder<'a> { let mut expansion = String::new(); if let Some(signature_overload) = &self.signature_overload { expansion.push_str("\n---\n"); - for signature in signature_overload { - match signature { - MarkedString::String(s) => { - expansion.push_str(&format!("\n{}\n", s)); - } - MarkedString::LanguageString(s) => { - expansion.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); - } - } + for overload in signature_overload { + overload.append_markdown(&mut expansion); } } @@ -281,67 +256,64 @@ impl<'a> HoverBuilder<'a> { pub fn get_trigger_token(&self) -> Option { self.trigger_token.clone() } +} + +#[derive(Debug, Clone)] +pub struct HoverSignatureOverload { + pub signature: MarkedString, + pub comment: Option, +} - pub fn get_call_expr(&self) -> Option { - if let Some(token) = self.trigger_token.clone() - && let Some(call_expr) = token.parent()?.parent() - && LuaCallExpr::can_cast(call_expr.kind().into()) - { - return LuaCallExpr::cast(call_expr); +impl HoverSignatureOverload { + fn new(signature: String, comment: Option) -> Self { + Self { + signature: MarkedString::from_language_code("lua".to_string(), signature), + comment: comment.filter(|comment| !comment.trim().is_empty()), } - None } -} -// 推断基础泛型替换器 -fn infer_substitutor_base_type( - semantic_model: &SemanticModel, - trigger_token: LuaSyntaxToken, -) -> Option { - let parent = trigger_token.parent()?; - match parent.kind().into() { - LuaSyntaxKind::LocalName => { - let target_local_name = LuaLocalName::cast(parent.clone())?; - let parent = parent.parent()?; - match parent.kind().into() { - LuaSyntaxKind::LocalStat => { - let local_stat = LuaLocalStat::cast(parent.clone())?; - let local_name_list = local_stat.get_local_name_list().collect::>(); - let value_expr_list = local_stat.get_value_exprs().collect::>(); - - for (index, name) in local_name_list.iter().enumerate() { - if target_local_name == *name { - let value_expr = value_expr_list.get(index)?; - return substitutor_form_expr(semantic_model, value_expr); - } + fn append_markdown(&self, content: &mut String) { + const LIMIT: usize = 80; + let inline_comment = self + .comment + .as_deref() + .filter(|comment| !comment.chars().any(|ch| ch == '\n' || ch == '\r')); + + match &self.signature { + MarkedString::String(s) => { + if let Some(comment) = inline_comment { + if s.chars().count() <= LIMIT { + content.push_str(&format!("\n{} -- {}\n", s, comment)); + } else { + content.push_str(&format!("\n{}\n-- {}\n", s, comment)); + } + } else { + content.push_str(&format!("\n{}\n", s)); + if let Some(comment) = self.comment.as_deref() { + content.push_str(&format!("\n{}\n", comment)); } } - _ => return None, } - } - _ => return None, - } - - None -} - -pub fn substitutor_form_expr( - semantic_model: &SemanticModel, - expr: &LuaExpr, -) -> Option { - if let LuaExpr::IndexExpr(index_expr) = expr { - let prefix_type = semantic_model - .infer_expr(index_expr.get_prefix_expr()?) - .ok()?; - let mut substitutor = TypeSubstitutor::new(); - if let LuaType::Generic(generic) = prefix_type { - for (i, param) in generic.get_params().iter().enumerate() { - substitutor.insert_type(GenericTplId::Type(i as u32), param.clone(), true); + MarkedString::LanguageString(s) => { + if let Some(comment) = inline_comment { + if s.value.chars().count() <= LIMIT { + content.push_str(&format!( + "\n```{}\n{} -- {}\n```\n", + s.language, s.value, comment + )); + } else { + content.push_str(&format!( + "\n```{}\n{}\n-- {}\n```\n", + s.language, s.value, comment + )); + } + } else { + content.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); + if let Some(comment) = self.comment.as_deref() { + content.push_str(&format!("\n{}\n", comment)); + } + } } - return Some(substitutor); - } else { - return None; } } - None } diff --git a/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs b/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs index 34e1048cf..60291f02d 100644 --- a/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs +++ b/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs @@ -1,6 +1,8 @@ use emmylua_code_analysis::{ - DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, RenderLevel, humanize_type, + DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, RenderLevel, + get_attribute_constructor_params, humanize_type, is_attribute_class, }; +use emmylua_parser::{LuaAstNode, LuaDocAttributeUse, LuaExpr}; use crate::handlers::hover::HoverBuilder; @@ -19,8 +21,8 @@ pub fn build_type_decl_hover( } } else if type_decl.is_enum() { format!("(enum) {}", type_decl.get_name()) - } else if type_decl.is_attribute() { - build_attribute(db, type_decl.get_name(), type_decl.get_attribute_type()) + } else if is_attribute_class(db, &type_decl_id) { + build_attribute(builder, db, type_decl.get_name(), &type_decl_id) } else { let humanize_text = humanize_type( db, @@ -35,16 +37,18 @@ pub fn build_type_decl_hover( Some(()) } -fn build_attribute(db: &DbIndex, attribute_name: &str, attribute_type: Option<&LuaType>) -> String { - let Some(LuaType::DocAttribute(attribute)) = attribute_type else { - return format!("(attribute) {}", attribute_name); - }; - let params = attribute - .get_params() - .iter() +fn build_attribute( + builder: &HoverBuilder, + db: &DbIndex, + attribute_name: &str, + type_decl_id: &LuaTypeDeclId, +) -> String { + let arg_types = get_hover_attribute_arg_types(builder); + let params = get_attribute_constructor_params(db, type_decl_id, &arg_types) + .into_iter() .map(|(name, typ)| match typ { Some(typ) => { - let type_name = humanize_type(db, typ, RenderLevel::Normal); + let type_name = humanize_type(db, &typ, RenderLevel::Normal); format!("{}: {}", name, type_name) } None => name.to_string(), @@ -52,8 +56,37 @@ fn build_attribute(db: &DbIndex, attribute_name: &str, attribute_type: Option<&L .collect::>(); if params.is_empty() { - format!("(attribute) {}", attribute_name) + format!("(class) {}", attribute_name) } else { - format!("(attribute) {}({})", attribute_name, params.join(", ")) + format!("(class) {}({})", attribute_name, params.join(", ")) } } + +fn get_hover_attribute_arg_types(builder: &HoverBuilder) -> Vec { + let Some(token) = builder.get_trigger_token() else { + return Vec::new(); + }; + + let mut node = token.parent(); + while let Some(current) = node { + if let Some(attribute_use) = LuaDocAttributeUse::cast(current.clone()) { + return attribute_use + .get_arg_list() + .map(|arg_list| { + arg_list + .get_args() + .map(|arg| { + builder + .semantic_model + .infer_expr(LuaExpr::LiteralExpr(arg)) + .unwrap_or(LuaType::Unknown) + }) + .collect() + }) + .unwrap_or_default(); + } + node = current.parent(); + } + + Vec::new() +} diff --git a/crates/emmylua_ls/src/handlers/hover/mod.rs b/crates/emmylua_ls/src/handlers/hover/mod.rs index c79ea8621..0ab483452 100644 --- a/crates/emmylua_ls/src/handlers/hover/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/mod.rs @@ -1,5 +1,5 @@ mod build_hover; -mod find_origin; +mod decl_context; mod function; mod hover_builder; mod humanize_type_decl; @@ -11,10 +11,10 @@ use crate::context::ServerContextSnapshot; use crate::util::{find_ref_at, resolve_ref_single}; pub use build_hover::build_hover_content_for_completion; use build_hover::build_semantic_info_hover; +pub(crate) use decl_context::{HoverDeclContext, HoverDeclInfo}; use emmylua_code_analysis::{EmmyLuaAnalysis, FileId, WorkspaceId}; use emmylua_parser::{LuaAstNode, LuaDocDescription, LuaTokenKind}; use emmylua_parser_desc::parse_ref_target; -pub use find_origin::{find_all_same_named_members, find_member_origin_owner}; pub use hover_builder::HoverBuilder; pub use humanize_types::infer_prefix_global_name; use keyword_hover::{hover_keyword, is_keyword}; @@ -101,7 +101,6 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let semantic_info = resolve_ref_single(db, file_id, &path, &detail)?; build_semantic_info_hover( - &analysis.compilation, &semantic_model, db, &document, @@ -120,7 +119,6 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let semantic_info = resolve_ref_single(db, file_id, &path, &doc_see)?; build_semantic_info_hover( - &analysis.compilation, &semantic_model, db, &document, @@ -135,15 +133,7 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let document = semantic_model.get_document(); let range = token.text_range(); - build_semantic_info_hover( - &analysis.compilation, - &semantic_model, - db, - &document, - token, - semantic_info, - range, - ) + build_semantic_info_hover(&semantic_model, db, &document, token, semantic_info, range) } } } diff --git a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs index 1eb3b1052..9ec787b3a 100644 --- a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs +++ b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs @@ -10,7 +10,7 @@ use emmylua_parser::{ }; use lsp_types::Location; -use crate::handlers::hover::find_member_origin_owner; +use crate::handlers::common::find_member_origin_owner; pub fn search_implementations( semantic_model: &SemanticModel, @@ -57,7 +57,7 @@ pub fn search_member_implementations( let mut semantic_cache = HashMap::new(); - let property_owner = find_member_origin_owner(compilation, semantic_model, member_id) + let property_owner = find_member_origin_owner(semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); for in_filed_syntax_id in index_references { let semantic_model = diff --git a/crates/emmylua_ls/src/handlers/mod.rs b/crates/emmylua_ls/src/handlers/mod.rs index 1042b0559..3f54e13ba 100644 --- a/crates/emmylua_ls/src/handlers/mod.rs +++ b/crates/emmylua_ls/src/handlers/mod.rs @@ -2,6 +2,7 @@ mod call_hierarchy; mod code_actions; mod code_lens; mod command; +mod common; mod completion; mod configuration; mod definition; diff --git a/crates/emmylua_ls/src/handlers/rename/rename_member.rs b/crates/emmylua_ls/src/handlers/rename/rename_member.rs index 6d0103b75..fef09e4f5 100644 --- a/crates/emmylua_ls/src/handlers/rename/rename_member.rs +++ b/crates/emmylua_ls/src/handlers/rename/rename_member.rs @@ -9,7 +9,7 @@ use emmylua_parser::{ }; use lsp_types::Uri; -use crate::handlers::hover::find_member_origin_owner; +use crate::handlers::common::find_member_origin_owner; #[allow(clippy::mutable_key_type)] pub fn rename_member_references( @@ -30,7 +30,7 @@ pub fn rename_member_references( .get_reference_index() .get_index_references(key)?; - let origin_property_owner = find_member_origin_owner(compilation, semantic_model, member_id) + let origin_property_owner = find_member_origin_owner(semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); let property_owner = LuaSemanticDeclId::Member(member_id); let mut semantic_cache = HashMap::new(); diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 8ea744c64..a12302748 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -11,9 +11,9 @@ use emmylua_code_analysis::{ }; use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaComment, LuaDocFieldKey, - LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocObjectFieldKey, LuaDocType, LuaExpr, - LuaGeneralToken, LuaKind, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, - LuaSyntaxToken, LuaTokenKind, LuaVarExpr, + LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocObjectFieldKey, LuaExpr, LuaGeneralToken, + LuaKind, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, LuaSyntaxToken, + LuaTokenKind, LuaVarExpr, }; use emmylua_parser_desc::{CodeBlockHighlightKind, DescItem, DescItemKind}; use lsp_types::SemanticToken; @@ -202,7 +202,6 @@ fn build_tokens_semantic_token( | LuaTokenKind::TkTagReturnCast | LuaTokenKind::TkTagReturnOverload | LuaTokenKind::TkLanguage - | LuaTokenKind::TkTagAttribute | LuaTokenKind::TKTagSchema => { builder.push_with_modifier( token, @@ -816,22 +815,6 @@ fn build_node_semantic_token( } } } - LuaAst::LuaDocTagAttribute(tag_attribute) => { - if let Some(name) = tag_attribute.get_name_token() { - builder.push_with_modifier( - name.syntax(), - SemanticTokenTypeKind::Type, - SemanticTokenModifierKind::DECLARATION, - ); - } - if let Some(LuaDocType::Attribute(attribute)) = tag_attribute.get_type() { - for param in attribute.get_params() { - if let Some(name) = param.get_name_token() { - builder.push(name.syntax(), SemanticTokenTypeKind::Parameter); - } - } - } - } LuaAst::LuaDocInferType(infer_type) => { // 推断出的泛型定义 if let Some(gen_decl) = infer_type.get_generic_decl() { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index 4e5cef58f..5dccd83d2 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -1,7 +1,7 @@ use emmylua_code_analysis::{ - DbIndex, InFiled, LuaCompilation, LuaFunctionType, LuaGenericType, LuaInstanceType, - LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, - LuaTypeDeclId, RenderLevel, SemanticModel, TypeSubstitutor, + DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaOperatorMetaMethod, + LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, + SemanticModel, TypeSubstitutor, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind}; use lsp_types::{ @@ -16,13 +16,12 @@ use super::signature_helper_builder::SignatureHelperBuilder; pub fn build_signature_helper( semantic_model: &SemanticModel, - compilation: &LuaCompilation, call_expr: LuaCallExpr, token: LuaSyntaxToken, ) -> Option { let prefix_expr = call_expr.get_prefix_expr()?; let prefix_expr_type = semantic_model.infer_expr(prefix_expr.clone()).ok()?; - let builder = SignatureHelperBuilder::new(compilation, semantic_model, call_expr.clone()); + let builder = SignatureHelperBuilder::new(semantic_model, call_expr.clone()); let colon_call = call_expr.is_colon_call(); let current_idx = get_current_param_index(&call_expr, &token)?; let help = match prefix_expr_type { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs index 868ac93c6..5799238b4 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs @@ -63,7 +63,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&semantic_model, &analysis.compilation, call_expr, token) + build_signature_helper(&semantic_model, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, @@ -90,7 +90,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&semantic_model, &analysis.compilation, call_expr, token) + build_signature_helper(&semantic_model, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, diff --git a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs index b508f9748..4bf26e501 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs @@ -1,18 +1,17 @@ use emmylua_code_analysis::{ - LuaCompilation, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, + LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; use lsp_types::{Documentation, MarkupContent, MarkupKind, ParameterInformation, ParameterLabel}; use rowan::NodeOrToken; -use crate::handlers::hover::{find_member_origin_owner, infer_prefix_global_name}; +use crate::handlers::{common::find_member_origin_owner, hover::infer_prefix_global_name}; use super::build_signature_helper::{build_function_label, generate_param_label}; #[derive(Debug)] pub struct SignatureHelperBuilder<'a> { pub semantic_model: &'a SemanticModel<'a>, - pub compilation: &'a LuaCompilation, pub call_expr: LuaCallExpr, pub prefix_name: Option, @@ -24,13 +23,8 @@ pub struct SignatureHelperBuilder<'a> { } impl<'a> SignatureHelperBuilder<'a> { - pub fn new( - compilation: &'a LuaCompilation, - semantic_model: &'a SemanticModel<'a>, - call_expr: LuaCallExpr, - ) -> Self { + pub fn new(semantic_model: &'a SemanticModel<'a>, call_expr: LuaCallExpr) -> Self { let mut builder = Self { - compilation, semantic_model, call_expr, prefix_name: None, @@ -72,8 +66,7 @@ impl<'a> SignatureHelperBuilder<'a> { // 推断为来源 semantic_decl = match semantic_decl { Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owner(self.compilation, semantic_model, member_id) - .or(semantic_decl) + find_member_origin_owner(semantic_model, member_id).or(semantic_decl) } Some(LuaSemanticDeclId::LuaDecl(_)) => semantic_decl, _ => None, diff --git a/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs b/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs index 0516bb726..fe4a445b0 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs @@ -32,6 +32,7 @@ mod tests { )); Ok(()) } + #[gtest] fn test_2() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index 5ee0ad588..8b180d7f8 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -1194,7 +1194,13 @@ mod tests { #[gtest] fn test_index_key_alias() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); - ws.def(" ---@attribute index_alias(name: string)"); + ws.def( + r#" + ---@class Attribute + ---@class index_alias: Attribute + ---@overload fun(name: string) + "#, + ); check!(ws.check_completion( r#" local export = { @@ -2322,8 +2328,6 @@ mod tests { r#" ---@alias std.RawGet unknown - ---@alias std.ConstTpl unknown - ---@generic T, K extends keyof T ---@param object T ---@param key K diff --git a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs index 790c76784..8287d144f 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs @@ -122,14 +122,14 @@ mod tests { local event = test3.event "#, VirtualHoverResult { - value: "```lua\n(method) Test3:event(event: \"B\", key: string)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"A\", key: string)\n```".to_string(), + value: "```lua\n(method) Test3:event(event: \"A\", key: string) (+1 overloads)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"B\", key: string)\n```".to_string(), }, )); Ok(()) } #[gtest] - fn test_union_function() -> Result<()> { + fn test_mixed_class_field_and_real_definition() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -149,10 +149,9 @@ mod tests { ---@class (partial) GameA ---@field event fun(self: self, event: "游戏-初始化"): Trigger ---@field event fun(self: self, event: "游戏-追帧完成"): Trigger - ---@field event fun(self: self, event: "游戏-逻辑不同步"): Trigger "#, VirtualHoverResult { - value: "```lua\n(method) GameA:event(event_type: EventTypeA, ...: any) -> Trigger\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-逻辑不同步\") -> Trigger\n```".to_string(), + value: "```lua\n(method) GameA:event(event_type: EventTypeA, ...: any) -> Trigger (+2 overloads)\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```".to_string(), }, )); Ok(()) @@ -192,7 +191,7 @@ mod tests { local alias = parse "#, VirtualHoverResult { - value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```" + value: "```lua\nlocal function parse() -> (true|false), (string|integer)\n```" .to_string(), }, ) @@ -213,7 +212,7 @@ mod tests { local alias = parse "#, VirtualHoverResult { - value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```\n\n---\n\n@*return_overload* #1 — success\n\n@*return_overload* #2 — failed".to_string(), + value: "```lua\nlocal function parse() -> (true|false), (string|integer)\n```\n\n---\n\n@*return_overload* #1 — success\n\n@*return_overload* #2 — failed".to_string(), }, )); Ok(()) @@ -221,6 +220,57 @@ mod tests { #[gtest] fn test_return_overload_call_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!( + ws.check_hover( + r#" + ---@class B + local B + + ---@generic T + ---@param x T + ---@return_overload true, T + ---@return_overload false, string + local function parse(x) + end + + parse(B) + "#, + VirtualHoverResult { + value: "```lua\nlocal function parse(x: B) -> (true|false), (B|string)\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_return_overload_hover_short_row_keeps_nil() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!( + ws.check_hover( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false + local function maybe(ok) + end + + local alias = maybe + "#, + VirtualHoverResult { + value: + "```lua\nlocal function maybe(ok: boolean) -> (true|false), integer?\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_return_overload_call_hover_short_generic_row_keeps_nil() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -230,14 +280,14 @@ mod tests { ---@generic T ---@param x T ---@return_overload true, T - ---@return_overload false, string + ---@return_overload false local function parse(x) end parse(B) "#, VirtualHoverResult { - value: "```lua\nlocal function parse(x: B)\n -> true, B\n -> false, string\n\n```".to_string(), + value: "```lua\nlocal function parse(x: B) -> (true|false), B?\n```".to_string(), }, )); Ok(()) @@ -258,7 +308,7 @@ mod tests { local a, b = pcall(foo) "#, VirtualHoverResult { - value: "```lua\nfunction pcall(f: sync fun(a: string, b: table) -> ((false|true),((string,string)|string)), a: string, b: table)\n -> true, (false|true), ((string,string)|string)\n -> false, string\n\n```\n\n---\n\n\nCalls function `f` with the given arguments in *protected mode*. This\nmeans that any error inside `f` is not propagated; instead, `pcall` catches\nthe error and returns a status code. Its first result is the status code (a\nboolean), which is true if the call succeeds without errors. In such case,\n`pcall` also returns all results from the call, after this first result. In\ncase of any error, `pcall` returns **false** plus the error message.".to_string(), + value: "```lua\nfunction pcall(f: sync fun(a: string, b: table) -> ((false|true),((string,string)|string)), a: string, b: table) -> (true|false), (false|true|string), (((string,string)|string))?\n```\n\n---\n\n\nCalls function `f` with the given arguments in *protected mode*. This\nmeans that any error inside `f` is not propagated; instead, `pcall` catches\nthe error and returns a status code. Its first result is the status code (a\nboolean), which is true if the call succeeds without errors. In such case,\n`pcall` also returns all results from the call, after this first result. In\ncase of any error, `pcall` returns **false** plus the error message.".to_string(), }, )); Ok(()) @@ -343,7 +393,7 @@ mod tests { } "#, VirtualHoverResult { - value: "```lua\n(field) T.func(a: (string|number))\n```\n\n---\n\n注释1\n\n注释2\n\n---\n\n```lua\n(field) T.func(a: string)\n```\n\n```lua\n(field) T.func(a: number)\n```" + value: "```lua\n(field) T.func(a: string) (+1 overloads)\n```\n\n---\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: number) -- 注释2\n```" .to_string(), }, )); @@ -360,10 +410,7 @@ mod tests { ---@field func fun(a:number) 注释2 ---@type T - local t = { - func = function(a) - end - } + local t t.func(1) "#, @@ -375,7 +422,7 @@ mod tests { } #[gtest] - fn test_origin_decl_1() -> Result<()> { + fn test_table_field_origin_decl() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -391,7 +438,7 @@ mod tests { local abc = t.func "#, VirtualHoverResult { - value: "```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: string)\n```".to_string(), + value: "```lua\n(field) T.func(a: string) (+1 overloads)\n```\n\n---\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: number) -- 注释2\n```".to_string(), }, )); Ok(()) @@ -651,6 +698,50 @@ mod tests { Ok(()) } + #[gtest] + fn test_call_hover_shows_all_overloads_when_no_match() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@overload fun(a: string): string + ---@overload fun(a: number): number + ---@param a table + function test(a) + end + + test(true) + "#, + VirtualHoverResult { + value: "```lua\nfunction test(a: table) (+2 overloads)\n```\n\n---\n\n---\n\n```lua\nfunction test(a: string) -> string\n```\n\n```lua\nfunction test(a: number) -> number\n```".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_call_hover_shows_all_generic_overloads_when_no_match() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@generic T, U + ---@overload fun(value: string, fallback: T): T, U + ---@overload fun(value: number, fallback: T): T, U + ---@param value table + ---@param fallback T + ---@return T + ---@return U + function generic_test(value, fallback) + end + + generic_test(true, false) + "#, + VirtualHoverResult { + value: "```lua\nfunction generic_test(value: table, fallback: boolean) -> boolean, unknown (+2 overloads)\n```\n\n---\n\n---\n\n```lua\nfunction generic_test(value: string, fallback: boolean) -> boolean, unknown\n```\n\n```lua\nfunction generic_test(value: number, fallback: boolean) -> boolean, unknown\n```".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_fix_method_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -773,4 +864,187 @@ mod tests { )); Ok(()) } + + #[gtest] + fn test_regression_generic_table_field_should_be_function_owner() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param params ObserverParams + function observe(params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe({ + ---@param value string + next = function(value) + return value + end + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_value_without_inference_source() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param params ObserverParams + function observe(params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe({ + next = 1 + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: unknown) -> unknown\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_hover_filters_union_parent_without_field() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@class OtherParams1 + ---@field other string + + ---@class OtherParams2 + ---@field wait fun(value: T): T # 测试2 + "#, + ); + check!( + ws.check_hover( + r#" + ---@type OtherParams2|ObserverParams|OtherParams1 + local params = { + next = function(value) + return value + end + } + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_table_field_hover_keeps_same_owner_same_name_overloads() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class OverloadedParams + ---@field next fun(value: string): string # 字符串 + ---@field next fun(value: number): number # 数字 + "#, + ); + check!( + ws.check_hover( + r#" + ---@type OverloadedParams + local params = { + next = function(value) + return value + end + } + "#, + VirtualHoverResult { + value: "```lua\n(field) OverloadedParams.next(value: string) -> string (+1 overloads)\n```\n\n---\n\n字符串\n\n---\n\n```lua\n(field) OverloadedParams.next(value: number) -> number -- 数字\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_function_candidate_checks_all_origin_decls() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class MixedOrigin + ---@field next string # 字符串 + ---@field next fun(): string # 函数 + "#, + ); + check!(ws.check_hover( + r#" + ---@type MixedOrigin + local params + local next = params.next + "#, + VirtualHoverResult { + value: + "```lua\n(field) MixedOrigin.next() -> string\n```\n\n---\n\n函数".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_uses_known_context_type() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param value T + ---@param params ObserverParams + function observe(value, params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe("x", { + next = function(value) + return value + end + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } } diff --git a/crates/emmylua_ls/src/handlers/test/hover_test.rs b/crates/emmylua_ls/src/handlers/test/hover_test.rs index 6acf90628..015e8f806 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_test.rs @@ -331,6 +331,25 @@ mod tests { Ok(()) } + #[gtest] + fn test_attribute_hover_uses_arg_types() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class custom_attribute: Attribute + ---@overload fun(value: string) + ---@overload fun(value: integer) + + ---@[custom_attribute(1)] + local a + "#, + VirtualHoverResult { + value: "```lua\n(class) custom_attribute(value: integer)\n```".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_alias_desc() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs index 4f62af542..6721c7a25 100644 --- a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs +++ b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs @@ -179,7 +179,13 @@ mod tests { #[gtest] fn test_index_key_alias_hint() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); - ws.def(" ---@attribute index_alias(name: string)"); + ws.def( + r#" + ---@class Attribute + ---@class index_alias: Attribute + ---@overload fun(name: string) + "#, + ); check!(ws.check_inlay_hint( r#" local export = { diff --git a/crates/emmylua_parser/locales/app.yml b/crates/emmylua_parser/locales/app.yml index 4be4e230b..dc4086e89 100644 --- a/crates/emmylua_parser/locales/app.yml +++ b/crates/emmylua_parser/locales/app.yml @@ -555,3 +555,7 @@ unfinished long comment: en: unfinished long comment zh_CN: 未完成的长注释 zh_HK: 未完成的長註釋 +Identifier expected. '%{reserved}' is a reserved word that cannot be used here.: + en: Identifier expected. '%{reserved}' is a reserved word that cannot be used here. + zh_CN: 应为标识符。'%{reserved}' 是保留字,不能在此处使用。 + zh_HK: 應為標識符。'%{reserved}' 是保留字,不能在此處使用。 diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 563f5acf2..b59a7185d 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -8,7 +8,7 @@ use crate::{ use super::{ expect_token, if_token_bump, parse_description, - types::{parse_fun_type, parse_type, parse_type_list, parse_typed_param}, + types::{parse_fun_type, parse_type, parse_type_list}, }; pub fn parse_tag(p: &mut LuaDocParser) { @@ -57,7 +57,6 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkTagUsing => parse_tag_using(p), LuaTokenKind::TkTagMeta => parse_tag_meta(p), LuaTokenKind::TkLanguage => parse_tag_language(p), - LuaTokenKind::TkTagAttribute => parse_tag_attribute(p), LuaTokenKind::TkDocAttributeUse => parse_tag_attribute_use(p, true), LuaTokenKind::TkCallGeneric => parse_tag_call_generic(p), LuaTokenKind::TKTagSchema => parse_tag_schema(p), @@ -151,6 +150,7 @@ pub(super) fn parse_generic_decl_list( // A = type fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { let m = p.mark(LuaSyntaxKind::DocGenericParameter); + parse_generic_modifier(p)?; expect_token(p, LuaTokenKind::TkName)?; if p.current_token() == LuaTokenKind::TkDots { p.bump(); @@ -169,6 +169,25 @@ fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } +fn parse_generic_modifier(p: &mut LuaDocParser) -> Result<(), LuaParseError> { + if p.current_token() == LuaTokenKind::TkName && p.current_token_text() == "const" { + let range = p.current_token_range(); + p.set_current_token_kind(LuaTokenKind::TkDocConst); + p.bump(); + if p.current_token() != LuaTokenKind::TkName { + return Err(LuaParseError::doc_error_from( + &t!( + "Identifier expected. '%{reserved}' is a reserved word that cannot be used here.", + reserved = "const" + ), + range, + )); + } + } + + Ok(()) +} + // ---@enum A // ---@enum A : number fn parse_tag_enum(p: &mut LuaDocParser) -> DocParseResult { @@ -683,40 +702,6 @@ fn parse_tag_language(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } -// ---@attribute 名称(参数列表) -fn parse_tag_attribute(p: &mut LuaDocParser) -> DocParseResult { - p.set_lexer_state(LuaDocLexerState::Normal); - let m = p.mark(LuaSyntaxKind::DocTagAttribute); - p.bump(); - - // 解析属性名称 - expect_token(p, LuaTokenKind::TkName)?; - - // 解析参数列表 - parse_type_attribute(p)?; - - p.set_lexer_state(LuaDocLexerState::Description); - parse_description(p); - Ok(m.complete(p)) -} - -// (param1: type1, param2: type2, ...) -fn parse_type_attribute(p: &mut LuaDocParser) -> DocParseResult { - let m = p.mark(LuaSyntaxKind::TypeAttribute); - expect_token(p, LuaTokenKind::TkLeftParen)?; - - if p.current_token() != LuaTokenKind::TkRightParen { - parse_typed_param(p)?; - while p.current_token() == LuaTokenKind::TkComma { - p.bump(); - parse_typed_param(p)?; - } - } - - expect_token(p, LuaTokenKind::TkRightParen)?; - Ok(m.complete(p)) -} - // ---@[attribute(arg1, arg2, ...)] // ---@[attribute] // ---@[attribute1, attribute2, ...] diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 12ab4db84..eee8ca5ec 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use crate::{LuaParser, parser::ParserConfig}; + use crate::{LuaParseErrorKind, LuaParser, parser::ParserConfig}; macro_rules! assert_ast_eq { ($lua_code:expr, $expected:expr) => { @@ -1055,6 +1055,68 @@ Syntax(Chunk)@0..92 assert_ast_eq!(code, result); } + #[test] + fn test_generic_const_modifier_doc() { + let code = "---@class A\n---@generic const R\n---@alias B const\n"; + + let result = r#" +Syntax(Chunk)@0..62 + Syntax(Block)@0..62 + Syntax(Comment)@0..61 + Token(TkDocStart)@0..4 "---@" + Syntax(DocTagClass)@4..20 + Token(TkTagClass)@4..9 "class" + Token(TkWhitespace)@9..10 " " + Token(TkName)@10..11 "A" + Syntax(DocGenericDeclareList)@11..20 + Token(TkLt)@11..12 "<" + Syntax(DocGenericParameter)@12..19 + Token(TkDocConst)@12..17 "const" + Token(TkWhitespace)@17..18 " " + Token(TkName)@18..19 "T" + Token(TkGt)@19..20 ">" + Token(TkEndOfLine)@20..21 "\n" + Token(TkDocStart)@21..25 "---@" + Syntax(DocTagGeneric)@25..40 + Token(TkTagGeneric)@25..32 "generic" + Token(TkWhitespace)@32..33 " " + Syntax(DocGenericDeclareList)@33..40 + Syntax(DocGenericParameter)@33..40 + Token(TkDocConst)@33..38 "const" + Token(TkWhitespace)@38..39 " " + Token(TkName)@39..40 "R" + Token(TkEndOfLine)@40..41 "\n" + Token(TkDocStart)@41..45 "---@" + Syntax(DocTagAlias)@45..61 + Token(TkTagAlias)@45..50 "alias" + Token(TkWhitespace)@50..51 " " + Token(TkName)@51..52 "B" + Syntax(DocGenericDeclareList)@52..55 + Token(TkLt)@52..53 "<" + Syntax(DocGenericParameter)@53..54 + Token(TkName)@53..54 "T" + Token(TkGt)@54..55 ">" + Token(TkWhitespace)@55..56 " " + Syntax(TypeName)@56..61 + Token(TkName)@56..61 "const" + Token(TkEndOfLine)@61..62 "\n" + "#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_generic_const_modifier_requires_identifier() { + let tree = LuaParser::parse("---@class A\n", ParserConfig::default()); + let errors = tree.get_errors(); + + assert!(errors.iter().any(|error| { + error.kind == LuaParseErrorKind::DocError + && error.message + == "Identifier expected. 'const' is a reserved word that cannot be used here." + })); + } + #[test] fn test_diagnostic_doc() { let code = r#" @@ -2838,7 +2900,6 @@ Syntax(Chunk)@0..263 #[test] fn test_attribute_doc() { let code = r#" - ---@attribute check_point(x: string, y: number) ---@[Skip, check_point("a", 0)] "#; // print_ast(code); @@ -2847,58 +2908,34 @@ Syntax(Chunk)@0..263 // check_point("a", 0) // "#); let result = r#" -Syntax(Chunk)@0..105 - Syntax(Block)@0..105 +Syntax(Chunk)@0..49 + Syntax(Block)@0..49 Token(TkEndOfLine)@0..1 "\n" Token(TkWhitespace)@1..9 " " - Syntax(Comment)@9..96 + Syntax(Comment)@9..40 Token(TkDocStart)@9..13 "---@" - Syntax(DocTagAttribute)@13..56 - Token(TkTagAttribute)@13..22 "attribute" - Token(TkWhitespace)@22..23 " " - Token(TkName)@23..34 "check_point" - Syntax(TypeAttribute)@34..56 - Token(TkLeftParen)@34..35 "(" - Syntax(DocTypedParameter)@35..44 - Token(TkName)@35..36 "x" - Token(TkColon)@36..37 ":" - Token(TkWhitespace)@37..38 " " - Syntax(TypeName)@38..44 - Token(TkName)@38..44 "string" - Token(TkComma)@44..45 "," - Token(TkWhitespace)@45..46 " " - Syntax(DocTypedParameter)@46..55 - Token(TkName)@46..47 "y" - Token(TkColon)@47..48 ":" - Token(TkWhitespace)@48..49 " " - Syntax(TypeName)@49..55 - Token(TkName)@49..55 "number" - Token(TkRightParen)@55..56 ")" - Token(TkEndOfLine)@56..57 "\n" - Token(TkWhitespace)@57..65 " " - Token(TkDocStart)@65..69 "---@" - Syntax(DocTagAttributeUse)@69..96 - Token(TkDocAttributeUse)@69..70 "[" - Syntax(DocAttributeUse)@70..74 - Syntax(TypeName)@70..74 - Token(TkName)@70..74 "Skip" - Token(TkComma)@74..75 "," - Token(TkWhitespace)@75..76 " " - Syntax(DocAttributeUse)@76..95 - Syntax(TypeName)@76..87 - Token(TkName)@76..87 "check_point" - Syntax(DocAttributeCallArgList)@87..95 - Token(TkLeftParen)@87..88 "(" - Syntax(LiteralExpr)@88..91 - Token(TkString)@88..91 "\"a\"" - Token(TkComma)@91..92 "," - Token(TkWhitespace)@92..93 " " - Syntax(LiteralExpr)@93..94 - Token(TkInt)@93..94 "0" - Token(TkRightParen)@94..95 ")" - Token(TkRightBracket)@95..96 "]" - Token(TkEndOfLine)@96..97 "\n" - Token(TkWhitespace)@97..105 " " + Syntax(DocTagAttributeUse)@13..40 + Token(TkDocAttributeUse)@13..14 "[" + Syntax(DocAttributeUse)@14..18 + Syntax(TypeName)@14..18 + Token(TkName)@14..18 "Skip" + Token(TkComma)@18..19 "," + Token(TkWhitespace)@19..20 " " + Syntax(DocAttributeUse)@20..39 + Syntax(TypeName)@20..31 + Token(TkName)@20..31 "check_point" + Syntax(DocAttributeCallArgList)@31..39 + Token(TkLeftParen)@31..32 "(" + Syntax(LiteralExpr)@32..35 + Token(TkString)@32..35 "\"a\"" + Token(TkComma)@35..36 "," + Token(TkWhitespace)@36..37 " " + Syntax(LiteralExpr)@37..38 + Token(TkInt)@37..38 "0" + Token(TkRightParen)@38..39 ")" + Token(TkRightBracket)@39..40 "]" + Token(TkEndOfLine)@40..41 "\n" + Token(TkWhitespace)@41..49 " " "#; assert_ast_eq!(code, result); } diff --git a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs index 7e4a5c825..68775a3ea 100644 --- a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs @@ -92,7 +92,6 @@ pub enum LuaSyntaxKind { DocTagReadonly, DocTagReturnCast, DocTagLanguage, - DocTagAttribute, DocTagAttributeUse, // '@[' DocTagCallGeneric, DocTagSchema, @@ -113,7 +112,6 @@ pub enum LuaSyntaxKind { TypeNullable, // ? TypeStringTemplate, // prefixName.`T` TypeMultiLineUnion, // | simple type # description - TypeAttribute, // declare. attribute<(paramList)> // follow donot support now TypeMatch, diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index e0a18a35c..a7fe47cb5 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -136,7 +136,6 @@ pub enum LuaTokenKind { TkTagReturnOverload, // return overload TkLanguage, // language TKTagSchema, // schema - TkTagAttribute, // attribute TkCallGeneric, // call generic. function_name--[[@]](...) TkDocOr, // | @@ -147,6 +146,7 @@ pub enum LuaTokenKind { TkDocAs, // as TkDocIn, // in TkDocInfer, // infer + TkDocConst, // const TkDocElse, // else (for return_cast) TkDocContinue, // --- TkDocContinueOr, // ---| or ---|+ or ---|> diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 1fa54e80f..f2e367154 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -828,7 +828,6 @@ fn to_tag(text: &str) -> LuaTokenKind { "using" => LuaTokenKind::TkTagUsing, "source" => LuaTokenKind::TkTagSource, "language" => LuaTokenKind::TkLanguage, - "attribute" => LuaTokenKind::TkTagAttribute, "schema" => LuaTokenKind::TKTagSchema, _ => LuaTokenKind::TkTagOther, } diff --git a/crates/emmylua_parser/src/syntax/node/doc/mod.rs b/crates/emmylua_parser/src/syntax/node/doc/mod.rs index 04a5b43ad..a159b1bfe 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/mod.rs @@ -365,6 +365,10 @@ impl LuaDocGenericDecl { pub fn is_variadic(&self) -> bool { self.token_by_kind(LuaTokenKind::TkDots).is_some() } + + pub fn has_const_modifier(&self) -> bool { + self.token_by_kind(LuaTokenKind::TkDocConst).is_some() + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 6149186de..f61dd4d46 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -17,7 +17,6 @@ pub enum LuaDocTag { Class(LuaDocTagClass), Enum(LuaDocTagEnum), Alias(LuaDocTagAlias), - Attribute(LuaDocTagAttribute), AttributeUse(LuaDocTagAttributeUse), Type(LuaDocTagType), Param(LuaDocTagParam), @@ -54,7 +53,6 @@ impl LuaAstNode for LuaDocTag { LuaDocTag::Class(it) => it.syntax(), LuaDocTag::Enum(it) => it.syntax(), LuaDocTag::Alias(it) => it.syntax(), - LuaDocTag::Attribute(it) => it.syntax(), LuaDocTag::Type(it) => it.syntax(), LuaDocTag::Param(it) => it.syntax(), LuaDocTag::Return(it) => it.syntax(), @@ -94,7 +92,6 @@ impl LuaAstNode for LuaDocTag { || kind == LuaSyntaxKind::DocTagEnum || kind == LuaSyntaxKind::DocTagAlias || kind == LuaSyntaxKind::DocTagType - || kind == LuaSyntaxKind::DocTagAttribute || kind == LuaSyntaxKind::DocTagParam || kind == LuaSyntaxKind::DocTagReturn || kind == LuaSyntaxKind::DocTagReturnOverload @@ -138,9 +135,6 @@ impl LuaAstNode for LuaDocTag { LuaSyntaxKind::DocTagAlias => { Some(LuaDocTag::Alias(LuaDocTagAlias::cast(syntax).unwrap())) } - LuaSyntaxKind::DocTagAttribute => Some(LuaDocTag::Attribute( - LuaDocTagAttribute::cast(syntax).unwrap(), - )), LuaSyntaxKind::DocTagAttributeUse => Some(LuaDocTag::AttributeUse( LuaDocTagAttributeUse::cast(syntax).unwrap(), )), @@ -1625,41 +1619,6 @@ impl LuaDocTagLanguage { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct LuaDocTagAttribute { - syntax: LuaSyntaxNode, -} - -impl LuaAstNode for LuaDocTagAttribute { - fn syntax(&self) -> &LuaSyntaxNode { - &self.syntax - } - - fn can_cast(kind: LuaSyntaxKind) -> bool { - kind == LuaSyntaxKind::DocTagAttribute - } - - fn cast(syntax: LuaSyntaxNode) -> Option { - if Self::can_cast(syntax.kind().into()) { - Some(Self { syntax }) - } else { - None - } - } -} - -impl LuaDocDescriptionOwner for LuaDocTagAttribute {} - -impl LuaDocTagAttribute { - pub fn get_name_token(&self) -> Option { - self.token() - } - - pub fn get_type(&self) -> Option { - self.child() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct LuaDocTagAttributeUse { syntax: LuaSyntaxNode, diff --git a/crates/emmylua_parser/src/syntax/node/doc/test.rs b/crates/emmylua_parser/src/syntax/node/doc/test.rs index 6854d7103..454e33526 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/test.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/test.rs @@ -206,4 +206,21 @@ mod test { "string" ); } + + #[test] + fn test_doc_generic_const_modifier_accessor() { + let tree = LuaParser::parse("---@class A\n", ParserConfig::default()); + let root = tree.get_chunk_node(); + let class = root.descendants::().next().unwrap(); + let generic_decl = class.get_generic_decl().unwrap(); + let mut params = generic_decl.get_generic_decl(); + + let const_param = params.next().unwrap(); + assert!(const_param.has_const_modifier()); + assert_eq!(const_param.get_name_token().unwrap().get_name_text(), "T"); + + let regular_param = params.next().unwrap(); + assert!(!regular_param.has_const_modifier()); + assert_eq!(regular_param.get_name_token().unwrap().get_name_text(), "U"); + } } diff --git a/crates/emmylua_parser/src/syntax/node/doc/types.rs b/crates/emmylua_parser/src/syntax/node/doc/types.rs index 80b894063..08bd25e13 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/types.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/types.rs @@ -25,7 +25,6 @@ pub enum LuaDocType { Generic(LuaDocGenericType), StrTpl(LuaDocStrTplType), MultiLineUnion(LuaDocMultiLineUnionType), - Attribute(LuaDocAttributeType), Mapped(LuaDocMappedType), IndexAccess(LuaDocIndexAccessType), } @@ -48,7 +47,6 @@ impl LuaAstNode for LuaDocType { LuaDocType::Generic(it) => it.syntax(), LuaDocType::StrTpl(it) => it.syntax(), LuaDocType::MultiLineUnion(it) => it.syntax(), - LuaDocType::Attribute(it) => it.syntax(), LuaDocType::Mapped(it) => it.syntax(), LuaDocType::IndexAccess(it) => it.syntax(), } @@ -75,7 +73,6 @@ impl LuaAstNode for LuaDocType { | LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::TypeStringTemplate | LuaSyntaxKind::TypeMultiLineUnion - | LuaSyntaxKind::TypeAttribute | LuaSyntaxKind::TypeMapped | LuaSyntaxKind::TypeIndexAccess ) @@ -119,9 +116,6 @@ impl LuaAstNode for LuaDocType { LuaSyntaxKind::TypeMultiLineUnion => Some(LuaDocType::MultiLineUnion( LuaDocMultiLineUnionType::cast(syntax)?, )), - LuaSyntaxKind::TypeAttribute => { - Some(LuaDocType::Attribute(LuaDocAttributeType::cast(syntax)?)) - } _ => None, } } @@ -846,41 +840,6 @@ impl LuaDocOneLineField { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct LuaDocAttributeType { - syntax: LuaSyntaxNode, -} - -impl LuaAstNode for LuaDocAttributeType { - fn syntax(&self) -> &LuaSyntaxNode { - &self.syntax - } - - fn can_cast(kind: LuaSyntaxKind) -> bool - where - Self: Sized, - { - kind == LuaSyntaxKind::TypeAttribute - } - - fn cast(syntax: LuaSyntaxNode) -> Option - where - Self: Sized, - { - if Self::can_cast(syntax.kind().into()) { - Some(Self { syntax }) - } else { - None - } - } -} - -impl LuaDocAttributeType { - pub fn get_params(&self) -> LuaAstChildren { - self.children() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct LuaDocMappedType { syntax: LuaSyntaxNode, diff --git a/crates/emmylua_parser/src/syntax/node/mod.rs b/crates/emmylua_parser/src/syntax/node/mod.rs index 9b39ad5e0..fe12482fa 100644 --- a/crates/emmylua_parser/src/syntax/node/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/mod.rs @@ -90,7 +90,6 @@ pub enum LuaAst { LuaDocTagAs(LuaDocTagAs), LuaDocTagReturnCast(LuaDocTagReturnCast), LuaDocTagLanguage(LuaDocTagLanguage), - LuaDocTagAttribute(LuaDocTagAttribute), LuaDocTagAttributeUse(LuaDocTagAttributeUse), // doc description LuaDocDescription(LuaDocDescription), @@ -181,7 +180,6 @@ impl LuaAstNode for LuaAst { LuaAst::LuaDocTagAsync(node) => node.syntax(), LuaAst::LuaDocTagAs(node) => node.syntax(), LuaAst::LuaDocTagReturnCast(node) => node.syntax(), - LuaAst::LuaDocTagAttribute(node) => node.syntax(), LuaAst::LuaDocTagAttributeUse(node) => node.syntax(), LuaAst::LuaDocTagLanguage(node) => node.syntax(), LuaAst::LuaDocDescription(node) => node.syntax(), @@ -370,9 +368,6 @@ impl LuaAstNode for LuaAst { LuaSyntaxKind::DocTagClass => LuaDocTagClass::cast(syntax).map(LuaAst::LuaDocTagClass), LuaSyntaxKind::DocTagEnum => LuaDocTagEnum::cast(syntax).map(LuaAst::LuaDocTagEnum), LuaSyntaxKind::DocTagAlias => LuaDocTagAlias::cast(syntax).map(LuaAst::LuaDocTagAlias), - LuaSyntaxKind::DocTagAttribute => { - LuaDocTagAttribute::cast(syntax).map(LuaAst::LuaDocTagAttribute) - } LuaSyntaxKind::DocTagType => LuaDocTagType::cast(syntax).map(LuaAst::LuaDocTagType), LuaSyntaxKind::DocTagParam => LuaDocTagParam::cast(syntax).map(LuaAst::LuaDocTagParam), LuaSyntaxKind::DocTagReturn => {