From 42d803df74c2a3222c7f018c42d87926b444f342 Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Mon, 20 Apr 2020 10:38:24 +0100 Subject: [PATCH 1/9] add inference by getfield use --- src/StaticLint.jl | 6 +++ src/imports.jl | 43 +++++++--------- src/server.jl | 4 +- src/type_inf.jl | 126 ++++++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 3 +- test/runtests.jl | 73 +++++++++++++++++++++++++++ 6 files changed, 228 insertions(+), 27 deletions(-) diff --git a/src/StaticLint.jl b/src/StaticLint.jl index aaf97c6b..cf81094a 100644 --- a/src/StaticLint.jl +++ b/src/StaticLint.jl @@ -69,6 +69,12 @@ function (state::State)(x::EXPR) # return to previous states state.scope != s0 && (state.scope = s0) + + if hasscope(x) && scopeof(x) !== state.scope && typof(x) !== CSTParser.ModuleH && typof(x) !== CSTParser.BareModule && typof(x) !== CSTParser.FileH && !CSTParser.defines_datatype(x) + for (n,b) in scopeof(x).names + infer_type_by_getfield_calls(b, state.server) + end + end state.delayed = delayed return state.scope end diff --git a/src/imports.jl b/src/imports.jl index 961402d4..8f8d54df 100644 --- a/src/imports.jl +++ b/src/imports.jl @@ -50,7 +50,15 @@ function resolve_import(x, state::State) end end -function _mark_import_arg(arg, par, state, u) +function add_to_imported_modules(scope::Scope, name::Symbol, val) + if scope.modules isa Dict + scope.modules[name] = val + else + modules = Dict(name => val) + end +end + +function _mark_import_arg(arg, par, state, usinged) if par !== nothing && (typof(arg) === IDENTIFIER || typof(arg) === MacroName) if par isa Binding # mark reference to binding push!(par.refs, arg) @@ -65,31 +73,16 @@ function _mark_import_arg(arg, par, state, u) end arg.meta.binding = Binding(arg, par, _typeof(par, state), [], nothing, nothing) end - if u && par isa SymbolServer.ModuleStore - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = par - else - state.scope.modules = Dict(Symbol(valof(arg)) => par) - end - elseif u && par isa Binding && par.val isa SymbolServer.ModuleStore - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = par.val - else - state.scope.modules = Dict(Symbol(valof(arg)) => par.val) - end - elseif u && par isa Binding && par.val isa EXPR && (typof(par.val) === CSTParser.ModuleH || typof(par.val) === CSTParser.BareModule) - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = scopeof(par.val) - else - state.scope.modules = Dict(Symbol(valof(arg)) => scopeof(par.val)) - end - elseif u && par isa Binding && par.val isa Binding && par.val.val isa EXPR && (typof(par.val.val) === CSTParser.ModuleH || typof(par.val.val) === CSTParser.BareModule) - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = scopeof(par.val.val) - else - state.scope.modules = Dict(Symbol(valof(arg)) => scopeof(par.val.val)) + if usinged + if par isa SymbolServer.ModuleStore + add_to_imported_modules(state.scope, Symbol(valof(arg)), par) + elseif par isa Binding && par.val isa SymbolServer.ModuleStore + add_to_imported_modules(state.scope, Symbol(valof(arg)), par.val) + elseif par isa Binding && par.val isa EXPR && (typof(par.val) === CSTParser.ModuleH || typof(par.val) === CSTParser.BareModule) + add_to_imported_modules(state.scope, Symbol(valof(arg)), scopeof(par.val)) + elseif par isa Binding && par.val isa Binding && par.val.val isa EXPR && (typof(par.val.val) === CSTParser.ModuleH || typof(par.val.val) === CSTParser.BareModule) + add_to_imported_modules(state.scope, Symbol(valof(arg)), scopeof(par.val.val)) end - end end end diff --git a/src/server.jl b/src/server.jl index a6e72481..349bf34b 100644 --- a/src/server.jl +++ b/src/server.jl @@ -14,8 +14,9 @@ mutable struct FileServer <: AbstractServer roots::Set{File} symbolserver::SymbolServer.EnvStore symbol_extends::Dict{SymbolServer.VarRef, Vector{SymbolServer.VarRef}} + symbol_fieldtypemap::Dict{Symbol, Vector{SymbolServer.VarRef}} end -FileServer() = FileServer(Dict{String,File}(), Set{File}(), deepcopy(SymbolServer.stdlibs), SymbolServer.collect_extended_methods(SymbolServer.stdlibs)) +FileServer() = FileServer(Dict{String,File}(), Set{File}(), deepcopy(SymbolServer.stdlibs), SymbolServer.collect_extended_methods(SymbolServer.stdlibs), fieldname_type_map(SymbolServer.stdlibs)) # Interface spec. # AbstractServer :-> (has/canload/load/set/get)file, getsymbolserver, getsymbolextends @@ -37,6 +38,7 @@ function loadfile(server::FileServer, path::String) end getsymbolserver(server::FileServer) = server.symbolserver getsymbolextendeds(server::FileServer) = server.symbol_extends +getsymbolfieldtypemap(server::FileServer) = server.symbol_fieldtypemap function scopepass(file, target = nothing) server = file.server diff --git a/src/type_inf.jl b/src/type_inf.jl index ca5b253c..abe914e8 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -58,3 +58,129 @@ function infer_type(binding::Binding, scope, state) end end end + +""" + is_getfield_lhs(x::EXPR) +x the `a` in `a.b` +""" +is_getfield_lhs(x::EXPR) = is_getfield(parentof(x)) && x === parentof(x)[1] + +""" + is_getfield_lhs_as_chain(x::EXPR) +x the `b` in `a.b.c` +""" +is_getfield_lhs_as_chain(x::EXPR) = parentof(x) isa EXPR && typof(parentof(x)) === CSTParser.Quotenode && StaticLint.is_getfield(parentof(parentof(x))) && StaticLint.is_getfield(parentof(parentof(parentof(x)))) && x === parentof(parentof(x))[3][1] + +function get_struct_fieldname(x::EXPR) + if _binary_assert(x, CSTParser.Tokens.DECLARATION) + return get_struct_fieldname(x[1]) + elseif typof(x) === CSTParser.InvisBrackets && length(x) == 3 + return get_struct_fieldname(x[2]) + elseif isidentifier(x) + return x + else + end + return nothing +end + +function cst_struct_fieldnames(x::EXPR) + fns = Symbol[] + if CSTParser.defines_mutable(x) + body = x[4] + elseif CSTParser.defines_struct(x) + body = x[3] + else + return fns + end + for arg in body + field_name = get_struct_fieldname(arg) + if field_name isa EXPR && isidentifier(field_name) + push!(fns, Symbol(CSTParser.str_value(field_name))) + end + end + return fns +end + +fieldname_type_map(s, server, l = Dict()) = l # fallback +function fieldname_type_map(s::Scope, server, l = Dict()) + for (n,b) in s.names + b = get_root_method(b, server) + if b isa Binding && b.val isa EXPR && CSTParser.defines_datatype(b.val) + for f in cst_struct_fieldnames(b.val) + f = Symbol(f) + if haskey(l, f) + push!(l[f], b) + else + l[f] = [b] + end + end + end + end + l +end + +function fieldname_type_map(cache::SymbolServer.ModuleStore, l) + for (n,v) in cache.vals + if v isa SymbolServer.DataTypeStore + for f in v.fieldnames + if haskey(l, f) + push!(l[f], v.name.name) + else + l[f] = [v.name.name] + end + end + elseif v isa SymbolServer.ModuleStore + fieldname_type_map(v, l) + end + end + l +end + +function fieldname_type_map(cache::SymbolServer.EnvStore, l = Dict()) + for (_,m) in cache + fieldname_type_map(m, l) + end + l +end + +function infer_type_by_getfield_calls(b::Binding, server) + b.type !== nothing && return # b already has a type + user_datatypes = fieldname_type_map(retrieve_toplevel_scope(b.val), server) + possibletypes = [] + for ref in b.refs + ref isa EXPR || continue # skip non-EXPR (i.e. used for handling of globals) + if is_getfield_lhs(ref) && typof(parentof(ref)[3]) === CSTParser.Quotenode + rhs = parentof(ref)[3][1] + elseif is_getfield_lhs_as_chain(ref) + rhs = parentof(parentof(parentof(ref)))[3][1] + else + continue + end + + if isidentifier(rhs) + rhs_sym = Symbol(CSTParser.str_value(rhs)) + new_possibles = [get(getsymbolfieldtypemap(server), rhs_sym, [])..., get(user_datatypes, rhs_sym, [])...] + + # @info new_possibles + if isempty(possibletypes) + possibletypes = new_possibles + elseif !isempty(new_possibles) + possibletypes = intersect(possibletypes, new_possibles) + end + if isempty(possibletypes) + return + end + end + end + if length(possibletypes) == 1 + type = first(possibletypes) + if type isa Binding + b.type = type + elseif type isa SymbolServer.VarRef + b.type = SymbolServer._lookup(type, getsymbolserver(server)) # could be nothing + else + end + end +end + + diff --git a/src/utils.jl b/src/utils.jl index 34dbb3c9..f2aa0891 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -276,7 +276,8 @@ isexportedby(k::String, m::SymbolServer.ModuleStore) = isexportedby(Symbol(k), m isexportedby(x::EXPR, m::SymbolServer.ModuleStore) = isexportedby(valof(x), m) isexportedby(k, m::SymbolServer.ModuleStore) = false -function retrieve_toplevel_scope(x) +function retrieve_toplevel_scope(x) end +function retrieve_toplevel_scope(x::EXPR) if scopeof(x) !== nothing && (typof(x) === CSTParser.ModuleH || typof(x) === CSTParser.BareModule || typof(x) === CSTParser.FileH) return scopeof(x) elseif parentof(x) isa EXPR diff --git a/test/runtests.jl b/test/runtests.jl index 674ec8b5..b42e317c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -735,4 +735,77 @@ end end end + +@testset "expr fieldnames" begin + let cst = parse_and_pass(""" + struct T + end + + struct T + a + end + + struct T + a + b + end + + struct T + a::S + b::S + end + + mutable struct T + a::S + b::S + end + """) + @test StaticLint.cst_struct_fieldnames(cst[1]) == [] + @test StaticLint.cst_struct_fieldnames(cst[2]) == [:a] + @test StaticLint.cst_struct_fieldnames(cst[3]) == [:a, :b] + @test StaticLint.cst_struct_fieldnames(cst[4]) == [:a, :b] + @test StaticLint.cst_struct_fieldnames(cst[5]) == [:a, :b] + end +end + + + +@testset "fieldname inference" begin + let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1, arg2, arg3) + arg1.fieldname1 + arg2.fieldname2 + arg3.fieldname1 + arg3.fieldname2 + end + """) + @test bindingof(cst[3][2][3]).type !== nothing + @test bindingof(cst[3][2][5]).type !== nothing + @test bindingof(cst[3][2][7]).type === nothing + end + let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1) + if arg1 isa T + arg1.fieldname1 + elseif arg1 isa S + arg1.fieldname2 + end + end + """) + @test bindingof(cst[3][2][3]).type === nothing + end +end + end From eeb2bace2896dd86b6d32663adf77b38bdbf90fe Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Mon, 20 Apr 2020 14:05:14 +0100 Subject: [PATCH 2/9] only run in target file --- src/StaticLint.jl | 2 +- src/type_inf.jl | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/StaticLint.jl b/src/StaticLint.jl index cf81094a..5bacb144 100644 --- a/src/StaticLint.jl +++ b/src/StaticLint.jl @@ -70,7 +70,7 @@ function (state::State)(x::EXPR) # return to previous states state.scope != s0 && (state.scope = s0) - if hasscope(x) && scopeof(x) !== state.scope && typof(x) !== CSTParser.ModuleH && typof(x) !== CSTParser.BareModule && typof(x) !== CSTParser.FileH && !CSTParser.defines_datatype(x) + if state.file == state.targetfile && hasscope(x) && scopeof(x) !== state.scope && typof(x) !== CSTParser.ModuleH && typof(x) !== CSTParser.BareModule && typof(x) !== CSTParser.FileH && !CSTParser.defines_datatype(x) for (n,b) in scopeof(x).names infer_type_by_getfield_calls(b, state.server) end diff --git a/src/type_inf.jl b/src/type_inf.jl index d024032f..b5c8af71 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -183,5 +183,3 @@ function infer_type_by_getfield_calls(b::Binding, server) end end end - - From 2d6c16a0feaa5245979002774ca5d7cc9005c478 Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Fri, 24 Apr 2020 22:41:15 +0100 Subject: [PATCH 3/9] fix order in resolve_getindex --- src/references.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/references.jl b/src/references.jl index 78983716..bbec529a 100644 --- a/src/references.jl +++ b/src/references.jl @@ -177,12 +177,12 @@ function resolve_getindex(x::EXPR, b::Binding, state::State)::Bool resolved = resolve_getindex(x, b.type.val, state) elseif b.val isa SymbolServer.ModuleStore resolved = resolve_getindex(x, b.val, state) - elseif b.type isa SymbolServer.DataTypeStore - resolved = resolve_getindex(x, b.type, state) elseif b.val isa EXPR && (typof(b.val) === ModuleH || typof(b.val) === BareModule) resolved = resolve_getindex(x, b.val, state) elseif b.val isa Binding && b.val.val isa EXPR && (typof(b.val.val) === ModuleH || typof(b.val.val) === BareModule) resolved = resolve_getindex(x, b.val.val, state) + elseif b.type isa SymbolServer.DataTypeStore + resolved = resolve_getindex(x, b.type, state) end return resolved end From 89a9d36223d98bf9419782d8d24c06153cf64473 Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Fri, 24 Apr 2020 22:42:15 +0100 Subject: [PATCH 4/9] add get_last_method --- src/utils.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index f2aa0891..665d63c5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -119,6 +119,18 @@ function get_root_method(b::Binding, server, b1 = nothing, visited_bindings = Bi end end +function get_last_method(b::Binding, server, visited_bindings = Binding[]) + if b.next === nothing || b == b.next || !(b.next isa Binding) || b in visited_bindings + return b + end + push!(visited_bindings, b) + if b.type == b.next.type == CoreTypes.Function + return get_last_method(b.next, server, visited_bindings) + else + return b + end +end + function retrieve_delayed_scope(x) if (CSTParser.defines_function(x) || CSTParser.defines_macro(x)) && scopeof(x) !== nothing if parentof(scopeof(x)) !== nothing From 9da29b82546c48f23fa74e6443d2fb1f81cdea3f Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Sat, 25 Apr 2020 09:13:49 +0100 Subject: [PATCH 5/9] Add inference through func arg use, add tests --- src/StaticLint.jl | 2 +- src/type_inf.jl | 205 ++++++++++++++++++++++++++++++++++++++++------ test/runtests.jl | 43 +--------- test/type_inf.jl | 138 +++++++++++++++++++++++++++++++ 4 files changed, 321 insertions(+), 67 deletions(-) create mode 100644 test/type_inf.jl diff --git a/src/StaticLint.jl b/src/StaticLint.jl index 5bacb144..7f369cd2 100644 --- a/src/StaticLint.jl +++ b/src/StaticLint.jl @@ -72,7 +72,7 @@ function (state::State)(x::EXPR) if state.file == state.targetfile && hasscope(x) && scopeof(x) !== state.scope && typof(x) !== CSTParser.ModuleH && typof(x) !== CSTParser.BareModule && typof(x) !== CSTParser.FileH && !CSTParser.defines_datatype(x) for (n,b) in scopeof(x).names - infer_type_by_getfield_calls(b, state.server) + infer_type_by_use(b, state.server) end end state.delayed = delayed diff --git a/src/type_inf.jl b/src/type_inf.jl index b5c8af71..c92091b5 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -56,6 +56,8 @@ function infer_type(binding::Binding, scope, state) binding.type = refof(t) end end + elseif binding.val isa EXPR && parentof(binding.val) isa EXPR && typof(parentof(binding.val)) === CSTParser.WhereOpCall + binding.type = CoreTypes.DataType end end end @@ -72,6 +74,8 @@ x the `b` in `a.b.c` """ is_getfield_lhs_as_chain(x::EXPR) = parentof(x) isa EXPR && typof(parentof(x)) === CSTParser.Quotenode && StaticLint.is_getfield(parentof(parentof(x))) && StaticLint.is_getfield(parentof(parentof(parentof(x)))) && x === parentof(parentof(x))[3][1] +isemptyvect(x::EXPR) = typof(x) === CSTParser.Vect && length(x) == 2 + function get_struct_fieldname(x::EXPR) if _binary_assert(x, CSTParser.Tokens.DECLARATION) return get_struct_fieldname(x[1]) @@ -102,10 +106,18 @@ function cst_struct_fieldnames(x::EXPR) return fns end -fieldname_type_map(s, server, l = Dict()) = l # fallback + +""" + fieldname_type_map(s::Union{Scope,ModuleStore,EnvStore}, server, l = Dict()) + +Returns a Dict where a fieldname (key) points to a collection of types that +have that field. +""" +fieldname_type_map(s, server, l = Dict{Symbol,Any}()) = l # fallback function fieldname_type_map(s::Scope, server, l = Dict()) for (n,b) in s.names b = get_root_method(b, server) + # Todo: Allow for const rebindings of datatypes (i.e. `const dt = DataType`) if b isa Binding && b.val isa EXPR && CSTParser.defines_datatype(b.val) for f in cst_struct_fieldnames(b.val) f = Symbol(f) @@ -117,10 +129,10 @@ function fieldname_type_map(s::Scope, server, l = Dict()) end end end - l + return l end -function fieldname_type_map(cache::SymbolServer.ModuleStore, l) +function fieldname_type_map(cache::SymbolServer.ModuleStore, l = Dict{Symbol,Any}()) for (n,v) in cache.vals if v isa SymbolServer.DataTypeStore for f in v.fieldnames @@ -134,52 +146,195 @@ function fieldname_type_map(cache::SymbolServer.ModuleStore, l) fieldname_type_map(v, l) end end - l + return l end -function fieldname_type_map(cache::SymbolServer.EnvStore, l = Dict()) +function fieldname_type_map(cache::SymbolServer.EnvStore, l = Dict{Symbol,Any}()) for (_,m) in cache fieldname_type_map(m, l) end - l + return l +end + +""" + check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + +Tries to infer the type of `ref` by looking at how getfield is used against it +and comparing these instances against the fields of all known datatypes. These +are pre-cached for packages in the server's EnvStore (`getsymbolfieldtypemap(server)`). +""" +function check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + if is_getfield_lhs(ref) && typof(parentof(ref)[3]) === CSTParser.Quotenode + rhs = parentof(ref)[3][1] + elseif is_getfield_lhs_as_chain(ref) + rhs = parentof(parentof(parentof(ref)))[3][1] + else + return + end + if isidentifier(rhs) + rhs_sym = Symbol(CSTParser.str_value(rhs)) + for t in get(getsymbolfieldtypemap(server), rhs_sym, []) + push!(new_possibles, t) + end + for t in get(user_datatypes, rhs_sym, []) + push!(new_possibles, t) + end + end +end + +""" + is_arg_of_resolved_call(x) + +Checks whether x is the argument of a function call. +""" +is_arg_of_resolved_call(x::EXPR) = parentof(x) isa EXPR && typof(parentof(x)) === Call && parentof(x)[1] !== x && +(hasref(parentof(x)[1]) || (is_getfield(parentof(x)[1]) && typof(parentof(x)[1][3]) === CSTParser.Quotenode && hasref(parentof(x)[1][3][1]))) + + +""" + get_arg_position_in_call(call, arg) + get_arg_position_in_call(arg) + +Returns the position of `arg` in `call` ignoring the function name and punctuation. +The single argument method assumes `parentof(arg) == call` +""" +function get_arg_position_in_call(call::EXPR, arg) + for (i,a) in enumerate(call) + a == arg && return div(i-1, 2) + end +end + +function get_arg_position_in_call(arg) + get_arg_position_in_call(parentof(arg), arg) +end + + +""" + get_arg_type_at_position(f, argi, types) + +Pushes to `types` the argument type (if not `Core.Any`) of a function +at position `argi`. +""" +function get_arg_type_at_position(f, argi, types) end + +function get_arg_type_at_position(b::Binding, argi, types) + argi1 = argi*2 + 1 + if b.val isa EXPR + sig = CSTParser.get_sig(b.val) + if sig !== nothing && + argi1 < length(sig) && + hasbinding(sig[argi1]) && + (argb = bindingof(sig[argi1]); argb isa Binding && argb.type !== nothing) && + !(argb.type in types) + push!(types, argb.type) + return + end + elseif b.val isa SymbolServer.SymStore + return get_arg_type_at_position(b.val, argi, types) + end + return +end + +function get_arg_type_at_position(f::T, argi, types) where T <: Union{SymbolServer.DataTypeStore,SymbolServer.FunctionStore} + for m in f.methods + get_arg_type_at_position(m, argi, types) + end +end + +function get_arg_type_at_position(m::SymbolServer.MethodStore, argi, types) + if length(m.sig) >= argi && m.sig[argi][2] != SymbolServer.VarRef(SymbolServer.VarRef(nothing, :Core), :Any) && !(m.sig[argi][2] in types) + push!(types, m.sig[argi][2]) + end +end + +""" + check_ref_against_calls(x, visitedmethods, new_possibles, server) + +Pushes to `new_possibles` +""" +function check_ref_against_calls(x, visitedmethods, new_possibles, server) + if is_arg_of_resolved_call(x) + # x is argument of function call (func) and we know what that function is + if CSTParser.isidentifier(parentof(x)[1]) + func = refof(parentof(x)[1]) + else + func = refof(parentof(x)[1][3][1]) + end + # make sure we've got the last binding for func + if func isa Binding + func = get_last_method(func, server) + end + # what slot does ref sit in? + argi = get_arg_position_in_call(x) + tls = retrieve_toplevel_scope(x) + while (func isa Binding && func.type == CoreTypes.Function) || func isa SymbolServer.SymStore + !(func in visitedmethods) ? push!(visitedmethods, func) : return # check whether we've been here before + if func isa Binding + get_arg_type_at_position(func, argi, new_possibles) + func = prev_method(func) + else + tls === nothing && return + iterate_over_ss_methods(func, tls, server, m->(get_arg_type_at_position(m, argi, new_possibles);false)) + return + end + end + end end -function infer_type_by_getfield_calls(b::Binding, server) +""" + infer_type_by_use(b::Binding, server) + +Tries to infer the type of Binding `b` by looking at how it is used. +""" +function infer_type_by_use(b::Binding, server) b.type !== nothing && return # b already has a type user_datatypes = fieldname_type_map(retrieve_toplevel_scope(b.val), server) possibletypes = [] + visitedmethods = [] for ref in b.refs + new_possibles = [] ref isa EXPR || continue # skip non-EXPR (i.e. used for handling of globals) - if is_getfield_lhs(ref) && typof(parentof(ref)[3]) === CSTParser.Quotenode - rhs = parentof(ref)[3][1] - elseif is_getfield_lhs_as_chain(ref) - rhs = parentof(parentof(parentof(ref)))[3][1] - else - continue - end - - if isidentifier(rhs) - rhs_sym = Symbol(CSTParser.str_value(rhs)) - new_possibles = [get(getsymbolfieldtypemap(server), rhs_sym, [])..., get(user_datatypes, rhs_sym, [])...] + check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + check_ref_against_calls(ref, visitedmethods, new_possibles, server) - # @info new_possibles - if isempty(possibletypes) - possibletypes = new_possibles - elseif !isempty(new_possibles) - possibletypes = intersect(possibletypes, new_possibles) - end + if isempty(possibletypes) + possibletypes = new_possibles + elseif !isempty(new_possibles) + possibletypes = intersect(possibletypes, new_possibles) if isempty(possibletypes) return end end end + # Only do something if we're left with a set of 1 at the end. if length(possibletypes) == 1 type = first(possibletypes) if type isa Binding b.type = type + elseif type isa SymbolServer.DataTypeStore + b.type = type elseif type isa SymbolServer.VarRef b.type = SymbolServer._lookup(type, getsymbolserver(server)) # could be nothing - else + elseif type isa SymbolServer.FakeTypeName && isempty(type.parameters) + b.type = SymbolServer._lookup(type.name, getsymbolserver(server)) # could be nothing end end end + +""" + isrebinding(b::Binding) + +Does `b` simply rebind another binding? +""" +function isrebinding(b::Binding) + b.val isa EXPR && CSTParser.is_assignment(b.val) && + b.val[1] == b.name && CSTParser.isidentifier(b.val[3]) && + hasbinding(b.val[3]) +end + +""" + getrebound(b::Binding) + +Assumes `isrebinding(b) == true` and gets the source binding (recursively). +""" +getrebound(b::Binding) = isrebinding(bindingof(b.val[3])) ? getrebound(bindingof(b.val[3])) : bindingof(b.val[3]) diff --git a/test/runtests.jl b/test/runtests.jl index b42e317c..87d36dea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,7 @@ function parse_and_pass(s) f = StaticLint.File("", s, CSTParser.parse(s, true), nothing, server) StaticLint.setroot(f, f) StaticLint.setfile(server, "", f) - StaticLint.scopepass(f) + StaticLint.scopepass(f, f) return f.cst end @@ -768,44 +768,5 @@ end end end - - -@testset "fieldname inference" begin - let cst = parse_and_pass(""" - struct T - fieldname1 - end - struct S - fieldname2 - end - function f(arg1, arg2, arg3) - arg1.fieldname1 - arg2.fieldname2 - arg3.fieldname1 - arg3.fieldname2 - end - """) - @test bindingof(cst[3][2][3]).type !== nothing - @test bindingof(cst[3][2][5]).type !== nothing - @test bindingof(cst[3][2][7]).type === nothing - end - let cst = parse_and_pass(""" - struct T - fieldname1 - end - struct S - fieldname2 - end - function f(arg1) - if arg1 isa T - arg1.fieldname1 - elseif arg1 isa S - arg1.fieldname2 - end - end - """) - @test bindingof(cst[3][2][3]).type === nothing - end -end - +include("type_inf.jl") end diff --git a/test/type_inf.jl b/test/type_inf.jl new file mode 100644 index 00000000..2df345ee --- /dev/null +++ b/test/type_inf.jl @@ -0,0 +1,138 @@ +@testset "fieldname inference" begin +# arg1 is inferred as T -> only a single (user defined) +# datatype has the field `fieldname1` +let cst = parse_and_pass(""" + struct T + fieldname1 + end + function f(arg1) + arg1.fieldname1 + end + """) + @test cst[2].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] +end + +# arg1 inferred as above +# arg2 as above but for `S` +# arg3 field use is conflicting -> no type assigned +let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1, arg2, arg3) + arg1.fieldname1 + arg2.fieldname2 + arg3.fieldname1 + arg3.fieldname2 + end + """) + @test cst[3].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[3].meta.scope.names["arg2"].type === cst.meta.scope.names["S"] + @test cst[3].meta.scope.names["arg3"].type === nothing +end + +# arg1 type inferred as above +# arg2 type not inferred as `sig` is also the fieldname of +# `Method` exported by Core. +let cst = parse_and_pass(""" + struct T + fieldname1 + sig + end + function f(arg1, arg2) + arg1.fieldname1 + arg2.sig + end + """) + @test cst[2].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[2].meta.scope.names["arg2"].type === nothing +end + +let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1) + if arg1 isa T + arg1.fieldname1 + elseif arg1 isa S + arg1.fieldname2 + end + end + """) + @test cst[3].meta.scope.names["arg1"].type === nothing +end +end + +@testset "inference by use as function argument" begin +# single method function with user defined datatype +let cst = parse_and_pass(""" + struct T end + function f(arg::T) end + function g(arg) end + let arg1 = unknownvalue, arg2 = unknownvalue + f(arg1) + g(arg1) + end + """) + @test cst[4].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[4].meta.scope.names["arg2"].type === nothing +end + +# as above against imported (symbolserver) types +let cst = parse_and_pass(""" + function f(arg::Int) end + let arg = unknownvalue + f(arg) + end + """) + @test cst[2].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(Int) +end + +# 2 methods, conflicting types so no inference +let cst = parse_and_pass(""" + function f(arg::Int) end + function f(arg::Float64) end + let arg = unknownvalue + f(arg) + end + """) + @test cst[3].meta.scope.names["arg"].type === nothing +end + +# 2 functions, 1 with two methods. +let cst = parse_and_pass(""" + function f(arg::Int) end + function f(arg::Float64) end + function g(arg::Int) end + let arg = unknownvalue + f(arg) + g(arg) + end + """) + @test cst[4].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(Int) +end + +# SymServer function w/ single method +let cst = parse_and_pass(""" + let arg = unknownvalue + dirname(arg) + end + """) + @test cst[1].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(AbstractString) +end +# As above but qualified name for function. +let cst = parse_and_pass(""" + let arg = unknownvalue + Base.dirname(arg) + end + """) + @test cst[1].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(AbstractString) +end +end \ No newline at end of file From 3780f60e04875506a28ef499abb7e435aa570dff Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Sun, 26 Apr 2020 18:06:26 +0100 Subject: [PATCH 6/9] search for user defined `get_property` fields --- src/type_inf.jl | 72 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/src/type_inf.jl b/src/type_inf.jl index c92091b5..a0bc8a84 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -118,13 +118,36 @@ function fieldname_type_map(s::Scope, server, l = Dict()) for (n,b) in s.names b = get_root_method(b, server) # Todo: Allow for const rebindings of datatypes (i.e. `const dt = DataType`) - if b isa Binding && b.val isa EXPR && CSTParser.defines_datatype(b.val) - for f in cst_struct_fieldnames(b.val) - f = Symbol(f) - if haskey(l, f) - push!(l[f], b) - else - l[f] = [b] + if b isa Binding && b.val isa EXPR + if CSTParser.defines_datatype(b.val) + for f in cst_struct_fieldnames(b.val) + f = Symbol(f) + if haskey(l, f) + push!(l[f], b) + else + l[f] = [b] + end + end + elseif CSTParser.defines_function(b.val) && n == "get_property" + # need to check this overwrites Base.get_property + # need to iterate over all methods + sig = CSTParser.get_sig(b.val) + if length(sig) > 5 && _binary_assert(sig[3], CSTParser.Tokens.DECLARATION) && hasref(sig[3][3]) + t_binding = refof(sig[3][3]) + if t_binding isa Binding + if t_binding.type !== CoreTypes.DataType + t_binding = get_root_method(t_binding, server) + t_binding.type !== CoreTypes.DataType && continue + end + for f in get_property_shadow_fields(b.val) + f = Symbol(f) + if haskey(l, f) + push!(l[f], t_binding) + else + l[f] = [t_binding] + end + end + end end end end @@ -132,6 +155,41 @@ function fieldname_type_map(s::Scope, server, l = Dict()) return l end +""" + get_property_shadow_fields(func) + +Assumes `func` is the definition of a function for `get_property`. Searches for +comparisons within the body between the second argument of the function and +symbols, returning a list of these symbols. + +e.g. +``` +function get_property(x::SomeType, f::Symbol) + if f === :asdf + elseif f == :sdgs + end +end +``` + +-> [:asdf, :sdgs] +""" +function get_property_shadow_fields(func) + # Get the argname for 2nd argument of get_property + str_sname = CSTParser.str_value(CSTParser.rem_decl(CSTParser.rem_where_decl(CSTParser.get_sig(func))[5])) + str_sname isa String || return [] + function trav(x, out = []) + if (_binary_assert(x, CSTParser.Tokens.EQEQEQ) || _binary_assert(x, CSTParser.Tokens.EQEQ)) && CSTParser.valof(x[1]) == str_sname && + CSTParser.typof(x[3]) === CSTParser.Quotenode && length(x[3]) ==2 && CSTParser.is_colon(x[3][1]) && CSTParser.isidentifier(x[3][2]) + push!(out, Expr(x[3][2])) + end + for a in x + trav(a, out) + end + out + end + trav(func) +end + function fieldname_type_map(cache::SymbolServer.ModuleStore, l = Dict{Symbol,Any}()) for (n,v) in cache.vals if v isa SymbolServer.DataTypeStore From 1b14634667bab69fa1c0ff6529f239a703d9938b Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Sun, 26 Apr 2020 18:13:44 +0100 Subject: [PATCH 7/9] fix test --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 2e450d3c..b13c6abf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -792,3 +792,4 @@ end include("type_inf.jl") end +end From d7470083c12df41769cf21d3b5649ecb940a61ba Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Thu, 30 Apr 2020 18:15:56 +0100 Subject: [PATCH 8/9] order fix --- src/references.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/references.jl b/src/references.jl index 663519a5..d63d595b 100644 --- a/src/references.jl +++ b/src/references.jl @@ -201,12 +201,12 @@ function resolve_getfield(x::EXPR, b::Binding, state::State)::Bool resolved = resolve_getfield(x, b.type.val, state) elseif b.val isa SymbolServer.ModuleStore resolved = resolve_getfield(x, b.val, state) - elseif b.type isa SymbolServer.DataTypeStore - resolved = resolve_getfield(x, b.type, state) elseif b.val isa EXPR && CSTParser.defines_module(b.val) resolved = resolve_getfield(x, b.val, state) elseif b.val isa Binding && b.val.val isa EXPR && CSTParser.defines_module(b.val.val) resolved = resolve_getfield(x, b.val.val, state) + elseif b.type isa SymbolServer.DataTypeStore + resolved = resolve_getfield(x, b.type, state) end return resolved end From 6d487d5645ee807f3df147e0a913c72e89a690d6 Mon Sep 17 00:00:00 2001 From: ZacNugent Date: Thu, 7 May 2020 19:17:10 +0100 Subject: [PATCH 9/9] fix --- src/StaticLint.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/StaticLint.jl b/src/StaticLint.jl index feeb0a83..69fa0594 100644 --- a/src/StaticLint.jl +++ b/src/StaticLint.jl @@ -76,7 +76,7 @@ function (state::Toplevel)(x::EXPR) infer_type_by_use(b, state.server) end end - state.delayed = delayed + return state.scope end @@ -95,6 +95,7 @@ function (state::Delayed)(x::EXPR) traverse(x, state) + # needs to call to add infer_type_by_use state.scope != s0 && (state.scope = s0) return state.scope end