diff --git a/vyper_lsp/analyzer/AstAnalyzer.py b/vyper_lsp/analyzer/AstAnalyzer.py index a9dbbe5..f9dcfac 100644 --- a/vyper_lsp/analyzer/AstAnalyzer.py +++ b/vyper_lsp/analyzer/AstAnalyzer.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Optional +from typing import List, Optional, Tuple import warnings from packaging.version import Version from lsprotocol.types import ( @@ -26,6 +26,7 @@ diagnostic_from_exception, is_internal_fn, is_state_var, + range_from_node, ) from lsprotocol.types import ( CompletionItem, @@ -124,7 +125,6 @@ def get_completions_in_doc( if element == "self": for fn in self.ast.get_internal_functions(): items.append(CompletionItem(label=fn)) - # TODO: This should exclude constants and immutables for var in self.ast.get_state_variables(): items.append(CompletionItem(label=var)) else: @@ -194,7 +194,9 @@ def _format_fn_signature(self, node: nodes.FunctionDef) -> str: function_def = match.group() return f"(Internal Function) {function_def}" - def hover_info(self, document: Document, pos: Position) -> Optional[str]: + def hover_info( + self, document: Document, pos: Position + ) -> Optional[Tuple[str, Range]]: if len(document.lines) < pos.line: return None @@ -204,34 +206,33 @@ def hover_info(self, document: Document, pos: Position) -> Optional[str]: if is_internal_fn(full_word): node = self.ast.find_function_declaration_node_for_name(word) - return node and self._format_fn_signature(node) + return node and (self._format_fn_signature(node), range_from_node(node)) if is_state_var(full_word): node = self.ast.find_state_variable_declaration_node_for_name(word) - if not node: - return None - variable_type = node.annotation.id - return f"(State Variable) **{word}** : **{variable_type}**" + return node and ( + f"(State Variable) **{word}** : **{node.annotation.id}**", + range_from_node(node), + ) if word in self.ast.get_structs(): node = self.ast.find_type_declaration_node_for_name(word) - return node and f"(Struct) **{word}**" + return node and (f"(Struct) **{word}**", range_from_node(node)) if word in self.ast.get_enums(): node = self.ast.find_type_declaration_node_for_name(word) - return node and f"(Enum) **{word}**" + return node and (f"(Enum) **{word}**", range_from_node(node)) if word in self.ast.get_events(): node = self.ast.find_type_declaration_node_for_name(word) - return node and f"(Event) **{word}**" + return node and (f"(Event) **{word}**", range_from_node(node)) if word in self.ast.get_constants(): node = self.ast.find_state_variable_declaration_node_for_name(word) - if not node: - return None - - variable_type = node.annotation.id - return f"(Constant) **{word}** : **{variable_type}**" + return node and ( + f"(Constant) **{word}** : **{node.annotation.id}**", + range_from_node(node), + ) return None diff --git a/vyper_lsp/ast.py b/vyper_lsp/ast.py index ad2c66c..5f8d658 100644 --- a/vyper_lsp/ast.py +++ b/vyper_lsp/ast.py @@ -13,7 +13,12 @@ class AST: ast_data_folded = None ast_data_unfolded = None - custom_type_node_types = (nodes.StructDef, nodes.EnumDef) + custom_type_node_types = ( + nodes.StructDef, + nodes.EnumDef, + nodes.InterfaceDef, + nodes.EventDef, + ) @classmethod def from_node(cls, node: VyperNode): @@ -67,16 +72,21 @@ def get_top_level_nodes(self, *args, **kwargs): return self.best_ast.get_children(*args, **kwargs) def get_enums(self) -> List[str]: - return [node.name for node in self.get_descendants(nodes.EnumDef)] + return [node.name for node in self.get_top_level_nodes(nodes.EnumDef)] def get_structs(self) -> List[str]: - return [node.name for node in self.get_descendants(nodes.StructDef)] + return [node.name for node in self.get_top_level_nodes(nodes.StructDef)] def get_events(self) -> List[str]: - return [node.name for node in self.get_descendants(nodes.EventDef)] + return [node.name for node in self.get_top_level_nodes(nodes.EventDef)] + + def get_interfaces(self): + return [node.name for node in self.get_top_level_nodes(nodes.InterfaceDef)] def get_user_defined_types(self): - return [node.name for node in self.get_descendants(self.custom_type_node_types)] + return [ + node.name for node in self.get_top_level_nodes(self.custom_type_node_types) + ] def get_constants(self): # NOTE: Constants should be fetched from self.ast_data, they are missing @@ -86,10 +96,29 @@ def get_constants(self): return [ node.target.id - for node in self.ast_data.get_children(nodes.VariableDecl) + for node in self.get_top_level_nodes(nodes.VariableDecl) if node.is_constant ] + def get_immutables(self): + return [ + node.target.id + for node in self.get_top_level_nodes(nodes.VariableDecl) + if node.is_immutable + ] + + def get_state_variables(self): + # NOTE: The state variables should be fetched from self.ast_data, they are + # missing from self.ast_data_unfolded and self.ast_data_folded when constants + if self.ast_data is None: + return [] + + return [ + node.target.id + for node in self.get_top_level_nodes(nodes.VariableDecl) + if not node.is_constant and not node.is_immutable + ] + def get_enum_variants(self, enum: str): enum_node = self.find_type_declaration_node_for_name(enum) if enum_node is None: @@ -104,16 +133,6 @@ def get_struct_fields(self, struct: str): return [node.target.id for node in struct_node.get_children(nodes.AnnAssign)] - def get_state_variables(self): - # NOTE: The state variables should be fetched from self.ast_data, they are - # missing from self.ast_data_unfolded and self.ast_data_folded when constants - if self.ast_data is None: - return [] - - return [ - node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl) - ] - def get_internal_function_nodes(self): function_nodes = self.get_descendants(nodes.FunctionDef) internal_nodes = [] @@ -138,14 +157,20 @@ def find_nodes_referencing_state_variable(self, variable: str): nodes.Attribute, {"value.id": "self", "attr": variable} ) - def find_nodes_referencing_constant(self, constant: str): - name_nodes = self.get_descendants(nodes.Name, {"id": constant}) + def find_nodes_referencing_constant_or_immutable(self, name: str): + name_nodes = self.get_descendants(nodes.Name, {"id": name}) return [ node for node in name_nodes if not isinstance(node.get_ancestor(), nodes.VariableDecl) ] + def find_nodes_referencing_constant(self, constant: str): + return self.find_nodes_referencing_constant_or_immutable(constant) + + def find_nodes_referencing_immutable(self, immutable: str): + return self.find_nodes_referencing_constant_or_immutable(immutable) + def get_attributes_for_symbol(self, symbol: str): node = self.find_type_declaration_node_for_name(symbol) if node is None: @@ -159,12 +184,8 @@ def get_attributes_for_symbol(self, symbol: str): return [] def find_function_declaration_node_for_name(self, function: str): - for node in self.get_descendants(nodes.FunctionDef): - name_match = node.name == function - not_interface_declaration = not isinstance( - node.get_ancestor(), nodes.InterfaceDef - ) - if name_match and not_interface_declaration: + for node in self.get_top_level_nodes(nodes.FunctionDef): + if node.name == function: return node return None @@ -175,15 +196,15 @@ def find_state_variable_declaration_node_for_name(self, variable: str): if self.ast_data is None: return None - for node in self.ast_data.get_descendants(nodes.VariableDecl): + for node in self.get_top_level_nodes(nodes.VariableDecl): if node.target.id == variable: return node return None def find_type_declaration_node_for_name(self, symbol: str): - searchable_types = self.custom_type_node_types + (nodes.EventDef,) - for node in self.get_descendants(searchable_types): + searchable_types = self.custom_type_node_types + for node in self.get_top_level_nodes(searchable_types): if node.name == symbol: return node if isinstance(node, nodes.EnumDef): @@ -193,17 +214,44 @@ def find_type_declaration_node_for_name(self, symbol: str): return None - def find_nodes_referencing_enum(self, enum: str): + def find_nodes_referencing_type(self, type_name: str): return_nodes = [] - for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": enum}): - return_nodes.append(node) - for node in self.get_descendants(nodes.Attribute, {"value.id": enum}): - return_nodes.append(node) - for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": enum}): - return_nodes.append(node) - for node in self.get_descendants(nodes.FunctionDef, {"returns.id": enum}): - return_nodes.append(node) + type_expressions = set() + + for node in self.get_descendants(): + if hasattr(node, "annotation"): + type_expressions.add(node.annotation) + elif hasattr(node, "returns") and node.returns: + type_expressions.add(node.returns) + + # TODO cover more builtin + for node in self.get_descendants(nodes.Call, {"func.id": "empty"}): + type_expressions.add(node.args[0]) + + for node in type_expressions: + for subnode in node.get_descendants(include_self=True): + if isinstance(subnode, nodes.Name) and subnode.id == type_name: + return_nodes.append(subnode) + + return return_nodes + + def find_nodes_referencing_callable_type(self, type_name: str): + return_nodes = self.find_nodes_referencing_type(type_name) + + for node in self.get_descendants(nodes.Call, {"func.id": type_name}): + # ERC20(foo) + # my_struct({x:0}) + return_nodes.append(node.func) + + return return_nodes + + def find_nodes_referencing_enum(self, type_name: str): + return_nodes = self.find_nodes_referencing_type(type_name) + + for node in self.get_descendants(nodes.Attribute, {"value.id": type_name}): + # A.o + return_nodes.append(node.value) return return_nodes @@ -212,19 +260,11 @@ def find_nodes_referencing_enum_variant(self, enum: str, variant: str): nodes.Attribute, {"attr": variant, "value.id": enum} ) - def find_nodes_referencing_struct(self, struct: str): - return_nodes = [] + def find_nodes_referencing_struct(self, type_name: str): + return self.find_nodes_referencing_callable_type(type_name) - for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": struct}): - return_nodes.append(node) - for node in self.get_descendants(nodes.Call, {"func.id": struct}): - return_nodes.append(node) - for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": struct}): - return_nodes.append(node) - for node in self.get_descendants(nodes.FunctionDef, {"returns.id": struct}): - return_nodes.append(node) - - return return_nodes + def find_nodes_referencing_interfaces(self, type_name: str): + return self.find_nodes_referencing_callable_type(type_name) def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]: for node in self.get_top_level_nodes(): diff --git a/vyper_lsp/main.py b/vyper_lsp/main.py index 88c1af2..e9ef081 100755 --- a/vyper_lsp/main.py +++ b/vyper_lsp/main.py @@ -143,7 +143,8 @@ def hover(ls: LanguageServer, params: HoverParams): document = ls.workspace.get_text_document(params.text_document.uri) hover_info = ast_analyzer.hover_info(document, params.position) if hover_info: - return Hover(contents=hover_info, range=None) + hover_content, range = hover_info + return Hover(contents=hover_content, range=range) @server.feature( diff --git a/vyper_lsp/navigation.py b/vyper_lsp/navigation.py index ea9c3a1..ea1c113 100644 --- a/vyper_lsp/navigation.py +++ b/vyper_lsp/navigation.py @@ -61,8 +61,14 @@ def _is_state_var_decl(self, line, word): return is_top_level and is_state_variable def _is_constant_decl(self, line, word): + is_top_level = not line[0].isspace() is_constant = "constant(" in line - return is_constant and self._is_state_var_decl(line, word) + return is_top_level and is_constant and word in self.ast.get_constants() + + def _is_immutable_decl(self, line, word): + is_top_level = not line[0].isspace() + is_immutable = "immutable(" in line + return is_top_level and is_immutable and word in self.ast.get_immutables() def _is_internal_fn(self, line, word, expression): is_def = line.startswith("def") @@ -89,12 +95,18 @@ def finalize(refs): if word in self.ast.get_structs() or word in self.ast.get_events(): return finalize(self.ast.find_nodes_referencing_struct(word)) + if word in self.ast.get_interfaces(): + return finalize(self.ast.find_nodes_referencing_interfaces(word)) + if self._is_internal_fn(og_line, word, expression): return finalize(self.ast.find_nodes_referencing_internal_function(word)) if self._is_constant_decl(og_line, word): return finalize(self.ast.find_nodes_referencing_constant(word)) + if self._is_immutable_decl(og_line, word): + return finalize(self.ast.find_nodes_referencing_immutable(word)) + if self._is_state_var_decl(og_line, word): return finalize(self.ast.find_nodes_referencing_state_variable(word)) @@ -139,9 +151,7 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range] return self._find_state_variable_declaration(word) elif word in self.ast.get_user_defined_types(): return self.find_type_declaration(word) - elif word in self.ast.get_events(): - return self.find_type_declaration(word) - elif word in self.ast.get_constants(): + elif word in self.ast.get_constants() or word in self.ast.get_immutables(): return self._find_state_variable_declaration(word) elif isinstance(top_level_node, FunctionDef): range_ = self._find_variable_declaration_under_node(top_level_node, word)