Adding static type checking to Julia in 100 lines of code
Type annotations are optional in Julia, but the type system is designed so that the types of variables in a function can be inferred for given input types. We can take advantage of this to check, before a function is ever called, that its return type is as expected.
Given that Julia is dynamic (types can be added at any time) and generic by default (meaning a huge number of types would need to be checked in many cases), this kind of check might not be very useful in practice. However, it's a good excuse to explore Julia's metaprogramming and introspection tools.
We will first deal with the simplest case possible (a method annotated with concrete types) and then generalize to parametric methods.
Non-parametric methods
We want to build a macro @checked that takes a return type and a function definition as input and adds, after the function definition, a test that the inferred return type matches the declared one. In short, it transforms the code:
    @checked Int16 function add(x::Int8, y::Int16)
        x + y
    end
Into:
    function add(x::Int8, y::Int16)
        x + y
    end
    @assert inferred(add, Tuple{Int8, Int16}) == Int16
In order to do this we need to:
1. Extract the type annotations (Int8, Int16) from the Abstract Syntax Tree (AST)
2. Write an inferred method that returns the inferred return type
3. Generate the final expression
For step 1 we will use the MLStyle package, which adds pattern matching for ASTs and makes this task very easy.
For example, if we want to extract the body of a comprehension [i^2 for i=1:10] using MLStyle's @match macro, we simply need to replace each element by a wildcard, [$body for $var=$range]:
    using MLStyle

    match_signature(ast) = @match ast begin
        quote
            $(::LineNumberNode)
            [$body for $var=$range]
        end => body
    end

    match_signature( quote [i^2 for i=1:10] end )
Note that Julia adds line numbers in its AST, so we also need to include a line number node in the pattern.
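To see where that line number node comes from, it can help to inspect a quoted block by hand. The following is a minimal sketch that uses only Base (no MLStyle); the variable names are illustrative and not part of the original code.

```julia
# Quote the same comprehension as in the example above.
ast = quote
    [i^2 for i = 1:10]
end

# A `quote ... end` block is an Expr with head :block, and the parser puts a
# LineNumberNode in front of each statement it contains.
first_is_line = ast.args[1] isa LineNumberNode

# The comprehension is the second argument; its body is the first argument
# of the nested :generator expression.
comp = ast.args[2]
gen = comp.args[1]
body = gen.args[1]

println(ast.head, " | ", first_is_line, " | ", comp.head, " | ", body)
```

This is the structure the pattern has to mirror: one line number node, then the expression of interest.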
We can match our function signature in a similar fashion, returning the matched elements as a named tuple:
    match_signature(ast) = @match ast begin
        quote
            $(::LineNumberNode)
            function $name($(args...)); $(block...) end
        end => (name = name, args = args)
    end

    sig = match_signature(quote function add(x::Int8,y::Int16) x+y end end)
Similarly, we define a method to extract type annotations (x::T):
    match_type_annotation(ast) = @match ast begin
        :($(arg)::$(ann)) => ann
    end

    match_type_annotation.(sig.args)
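For comparison, the same pieces can be pulled out of the AST by indexing into the Expr directly. This sketch uses only Base and is not the approach taken here; it is only meant to show what the patterns above are matching against, and all variable names are illustrative.

```julia
ast = quote
    function add(x::Int8, y::Int16)
        x + y
    end
end

fdef = ast.args[2]        # ast.args[1] is the LineNumberNode
call = fdef.args[1]       # the signature, :(add(x::Int8, y::Int16))
name = call.args[1]       # function name
args = call.args[2:end]   # annotated arguments, each an Expr with head :(::)

# In :(x::Int8) the first argument is the variable, the second the annotation.
annotations = [a.args[2] for a in args]

println(name, " ", annotations)
```

Positional indexing like this breaks as soon as the shape of the definition changes (a where clause, a return type annotation), which is where pattern matching pays off.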
We can simply run type inference using Base.code_typed:
    function inferred(f, args)
        # code_typed returns a vector of (CodeInfo => return type) pairs
        codeinfo = Base.code_typed(f, args)[1]
        codeinfo.second
    end

    inferred((x,y) -> x+y, Tuple{Int8,Int16}) == Int16
When adding two integers of different types, the result gets automatically promoted (here to Int16). The promoted type can be computed using promote_type(T,K).
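As a quick check of the promotion claim, the sketch below compares the promoted type with what inference reports. It uses Base.return_types, an unexported Base helper that returns one inferred return type per matching method; using it here instead of code_typed is an assumption made for brevity, not what the rest of this post relies on.

```julia
# Inferred return type(s) of + for an (Int8, Int16) call.
inferred_types = Base.return_types(+, Tuple{Int8, Int16})

# Type both arguments are promoted to before the addition.
promoted = promote_type(Int8, Int16)

println(inferred_types, " ", promoted)
```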
We can now write a macro that will perform type checking:
    macro checked(out,f)
        sig = match_signature(quote $f end)
        args = match_type_annotation.(sig.args)
        checks = quote @assert inferred($(sig.name),Tuple{$(args...)}) == $out end
        esc(quote
            $f
            $checks
        end)
    end

    @macroexpand @checked Int16 function add(x::Int8,y::Int16)
        x+y
    end
Here we use @macroexpand to see the code built by the macro. It first defines the function itself, and then adds a type check with the declared input and output types. We can now test that our macro works correctly:
    @checked Int16 function add(x::Int8,y::Int16)
        x+y
    end
If the inferred type is incorrect, the generated check will throw an AssertionError when the function is defined:
    using Test

    @test_throws AssertionError @checked Int8 function add(x::Int8,y::Int16)
        x+y
    end
Parametric methods
Next we will consider a simple parametric case. We will also replace the output type annotation by a predicate on the inferred type I to allow for more general (and useful) checks. For example, in the following the inferred type is compared to the type of the first argument:
    @checked I -> I == T function add(x::T, y::K) where {T<:Integer, K<:Integer}
        x + y
    end
We first generalise our match_signature method to handle the parametric case, using an OR clause to handle functions without where:
    match_signature(ast) = @match ast begin
        quote
            $(::LineNumberNode)
            function $name($(args...)) where {$(cov...)}; $(block...) end
        end ||
        quote
            $(::LineNumberNode)
            function $name($(args...)); $(block...) end
        end && Do(cov = []) => (name = name, args = args, cov = cov)
    end

    sig = match_signature(quote function add(x::T,y::K) where {T<:Integer, K<:Integer} x+y end end)
We need a small method to extract subtype relations (T <: Integer):
    match_subtype(ast) = @match ast begin
        :($(sub)<:$(sup)) => (sub=sub,sup=sup)
    end

    match_subtype.(sig.cov)
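The shape of a where clause in the AST can also be inspected by hand. The sketch below uses only Base and illustrative variable names; it is meant to show what the two patterns above capture, not to replace them.

```julia
ast = quote
    function add(x::T, y::K) where {T<:Integer, K<:Integer}
        x + y
    end
end

fdef = ast.args[2]              # skip the leading LineNumberNode
sig = fdef.args[1]              # Expr with head :where wrapping the call
constraints = sig.args[2:end]   # one Expr with head :(<:) per type parameter

# Each constraint holds the parameter name first and its upper bound second.
relations = [(sub = c.args[1], sup = c.args[2]) for c in constraints]

println(sig.head, " ", relations)
```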
In order to test every combination of parameters under the constraints above, we need to enumerate all concrete subtypes of Integer. This can be done by recursively calling InteractiveUtils.subtypes:
    import InteractiveUtils.subtypes

    function allsubtypes(T::Type)
        isconcretetype(T) && return [T]
        t = Type[]
        map(K->allsubtypes(K,t), subtypes(T))
        t
    end
    allsubtypes(T,t) = isconcretetype(T) ? push!(t,T) : map(K->allsubtypes(K,t), subtypes(T))
    allsubtypes(T::Symbol) = allsubtypes(Core.eval(Main,T)) #convert Symbol to Type

    allsubtypes(Integer)
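The recursion is needed because subtypes only returns direct subtypes, some of which are themselves abstract. A small sketch of a single level, using only InteractiveUtils (variable names are illustrative):

```julia
using InteractiveUtils: subtypes

# Direct subtypes of Integer: a mix of concrete and abstract types.
direct = subtypes(Integer)

# The abstract ones have to be expanded further to reach concrete types.
still_abstract = filter(!isconcretetype, direct)

# One level down from Signed we reach concrete types such as Int8.
leaves = subtypes(Signed)

println(direct, " ", still_abstract, " ", leaves)
```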
In order to enumerate every combination of concrete types we can use Iterators.product:
collect(Iterators.product([Bool,BigInt],[Bool,BigInt]))
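A short sketch of what the product looks like once collected, assuming the same two candidate lists as above. The first iterator varies fastest, and the number of combinations is the product of the list lengths.

```julia
candidates = [[Bool, BigInt], [Bool, BigInt]]

# Collecting a product of two iterators gives a 2-dimensional array of tuples.
combos = collect(Iterators.product(candidates...))

# Number of combinations, computed without materialising them.
n_combos = prod(length, candidates)

println(size(combos), " ", n_combos)
```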
We also need a method that will take abstract type annotations ([:T,:K]) and replace them with the corresponding concrete subtypes:
    function concretize!(args, concrete_args, idx)
        for (j,K) in enumerate(concrete_args)
            for i in idx[j]
                args[i] = Symbol("$K")
            end
        end
    end

    args = Symbol[:T, :K]
    concrete_args = (Bool, Bool)
    idx = [[1], [2]]
    concretize!(args, concrete_args, idx)
    args
Here idx contains the indices at which each abstract type annotation occurs in the method signature. An annotation can occur at several positions: for f(x::T, y::T, z::K) where {T,K} we get idx = [[1,2],[3]].
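The index lists for that example can be computed with findall. This is a minimal sketch with illustrative variable names:

```julia
# Annotations of f(x::T, y::T, z::K), in argument order.
args = [:T, :T, :K]

# Type parameters declared in the where clause.
params = [:T, :K]

# For each parameter, the positions where it appears among the annotations.
idx = [findall(==(p), args) for p in params]

println(idx)
```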
We need to perform the same operation on our predicate (e.g. I -> I == T), which can be an arbitrary expression. We will thus write a generic method that can replace a symbol by another in an expression:
    #find all pair.first in expression and replace them by pair.second
    function ast_replace!(ex::Expr, pair::Pair)
        for i=1:length(ex.args)
            if ex.args[i] == pair.first
                ex.args[i] = pair.second
            else
                ast_replace!(ex.args[i], pair)
            end
        end
        ex
    end
    ast_replace!(ex::Symbol, pair) = nothing
    ast_replace!(ex::LineNumberNode, pair) = nothing

    ast_replace!(:(I -> I == T), :T => Int8)
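A non-mutating variant is sketched below as an alternative. The name replace_symbol is hypothetical and not part of the code in this post. Unlike ast_replace!, it only substitutes symbols (not whole sub-expressions), and its fallback method passes literals and other leaves through untouched.

```julia
# Rebuild an expression with every occurrence of pair.first replaced.
replace_symbol(ex::Expr, pair::Pair) =
    Expr(ex.head, [replace_symbol(a, pair) for a in ex.args]...)

# A symbol is swapped only when it matches exactly.
replace_symbol(s::Symbol, pair::Pair) = s == pair.first ? pair.second : s

# Anything else (numbers, strings, LineNumberNodes, ...) is returned as is.
replace_symbol(x, pair::Pair) = x

original = :(O == promote_type(T, K))
replaced = replace_symbol(original, :T => :Int8)

println(original, "  =>  ", replaced)
```

Because the input is left untouched, no copy is needed before substituting several parameters in turn.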
We can now put all this together in a method that will generate the expression that performs the check for every combination of concrete parameters:
    function check_types(fun, predicate, args, covs)
        #setup
        out = Expr[]
        covs = match_subtype.(covs)
        covs_symbols = [c.sub for c in covs]
        idx = [findall(args .== c) for c in covs_symbols]
        for p in Iterators.product([allsubtypes(c.sup) for c in covs]...)
            #build concrete arguments
            concrete_args = copy(args)
            concretize!(concrete_args, p, idx)
            #build concrete predicate
            concrete_predicate = copy(predicate)
            for pairs in [Pair(c,Symbol(t)) for (c,t) in zip(covs_symbols,p)]
                ast_replace!(concrete_predicate,pairs)
            end
            #build final expression
            t = :( Tuple{$(concrete_args...)} )
            I = :( inferred($fun,$t) )
            push!(out, quote report_type_check($concrete_predicate,$I,$fun,$concrete_args) end)
        end
        out
    end

    fun, args = :add, [:T,:K]
    covs = Any[:(T <: Integer), :(K <: Integer)]
    out_type = :(O -> O == T)
    check_types(fun, out_type, args, covs)[1:2]
In the generated code we call a helper function that will display the result of the type check instead of throwing an error:
    function report_type_check(predicate,inferred,fun,types)
        passed = predicate(inferred)
        sig = join(types,",")
        status = passed ? "SUCCESS" : "FAILURE"
        println( "$status: $(fun)($sig) \t → $inferred" )
    end
And finally we can update our checked macro:
    macro checked(out,f)
        sig = match_signature(quote $f end)
        args = match_type_annotation.(sig.args)
        checks = check_types(Symbol(sig.name),out,args,sig.cov)
        esc(quote
            $f
            $(checks...)
        end)
    end

    @checked O -> O == promote_type(T,K) function add(x::T, y::K) where {T<: Integer, K <: Integer}
        x + y
    end
All tests succeed except the first one (add(Bool,Bool) → Int64), since adding two booleans returns an Int64. This could be fixed by slightly modifying the predicate to handle this particular case.
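One possible way to special-case booleans is sketched below. The helper name expected_sum_type is hypothetical, and this is only one of several ways the predicate could be adjusted. Int is used rather than Int64 so that the check also holds on 32-bit systems.

```julia
# Expected result type of x + y for integer types T and K:
# Bool + Bool is widened to Int, everything else follows promote_type.
expected_sum_type(T, K) = (T === Bool && K === Bool) ? Int : promote_type(T, K)

# Compare against the types actually produced by a few additions.
bool_case = typeof(true + true) == expected_sum_type(Bool, Bool)
mixed_case = typeof(Int8(1) + Int16(1)) == expected_sum_type(Int8, Int16)

println(bool_case, " ", mixed_case)
```

With such a helper defined before the macro call, the predicate could be written as O -> O == expected_sum_type(T,K), since T and K are replaced by concrete type names before the check runs.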
Limitations
One limitation of the implementation above is that it doesn't handle parametric types (e.g. Vector{T} where T <: Number), but it should be relatively easy to generalize it to handle such a case. More general cases, however, might get tricky (e.g. Array{T,N} where {T <: Union{Number, String}, N}).
The sheer number of type combinations to check can also become an issue. For example, there are 16 concrete subtypes of Number in Base, so a function taking two numbers will generate 256 checks. There are 8998 concrete types in Base, so testing a function that has a single parameter annotated as Any would take that many checks, which is not very realistic.
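The number of checks can be estimated before generating them. The sketch below counts concrete subtypes recursively; count_concrete is a hypothetical helper, and the exact counts depend on the Julia version and on which packages are loaded, so no specific numbers are asserted here.

```julia
using InteractiveUtils: subtypes

# Count the concrete types reachable below T (T itself if it is concrete).
function count_concrete(T::Type)
    isconcretetype(T) && return 1
    return mapreduce(count_concrete, +, subtypes(T); init = 0)
end

n = count_concrete(Integer)

# A method with two independent Integer parameters needs n^2 checks.
println("concrete Integer subtypes: ", n, ", checks for two parameters: ", n^2)
```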
For that reason this kind of tool might be better suited to automated tests, run only during testing for some specific cases.
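As a rough sketch of what that could look like, the test set below checks the inferred return type for a handful of hand-picked type combinations using the Test standard library. The helper infers_to is hypothetical, and Base.return_types is used as a stand-in for the inferred method defined earlier so that the example is self-contained.

```julia
using Test

add(x::T, y::K) where {T<:Integer, K<:Integer} = x + y

# Hypothetical helper: does inference agree with the expected return type?
infers_to(f, argtypes, expected) = Base.return_types(f, argtypes)[1] == expected

@testset "add return types" begin
    # Only a few representative combinations instead of the full product.
    for (T, K) in [(Int8, Int16), (UInt8, Int32), (Int64, Int64)]
        @test infers_to(add, Tuple{T, K}, promote_type(T, K))
    end
end
```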
Final code
    using MLStyle
    import InteractiveUtils.subtypes

    match_signature(ast) = @match ast begin
        quote
            $(::LineNumberNode)
            function $name($(args...)) where {$(cov...)}; $(block...) end
        end ||
        quote
            $(::LineNumberNode)
            function $name($(args...)); $(block...) end
        end && Do(cov = []) => (name = name, args = args, cov = cov)
    end

    match_type_annotation(ast) = @match ast begin
        :($(arg)::$(ann)) => ann
    end

    match_subtype(ast) = @match ast begin
        :($(sub)<:$(sup)) => (sub=sub,sup=sup)
    end

    function inferred(f,args)
        codeinfo = Base.code_typed(f,args)[1]
        codeinfo.second
    end

    function allsubtypes(T::Type)
        isconcretetype(T) && return [T]
        t = Type[]
        map(K->allsubtypes(K,t), subtypes(T))
        t
    end
    allsubtypes(T,t) = isconcretetype(T) ? push!(t,T) : map(K->allsubtypes(K,t), subtypes(T))
    allsubtypes(T::Symbol) = allsubtypes(Core.eval(Main,T)) #convert Symbol to Type

    function concretize!(args, concrete_args, idx)
        for (j,K) in enumerate(concrete_args)
            for i in idx[j]
                args[i] = Symbol("$K")
            end
        end
    end

    function ast_replace!(ex::Expr, pair::Pair)
        for i=1:length(ex.args)
            if ex.args[i] == pair.first
                ex.args[i] = pair.second
            else
                ast_replace!(ex.args[i], pair)
            end
        end
        ex
    end
    ast_replace!(ex::Symbol, pair) = nothing
    ast_replace!(ex::LineNumberNode, pair) = nothing

    function check_types(fun, predicate, args, covs)
        #setup
        out = Expr[]
        covs = match_subtype.(covs)
        covs_symbols = [c.sub for c in covs]
        idx = [findall(args .== c) for c in covs_symbols]
        for p in Iterators.product([allsubtypes(c.sup) for c in covs]...)
            #build concrete arguments
            concrete_args = copy(args)
            concretize!(concrete_args, p, idx)
            #build concrete predicate
            concrete_predicate = copy(predicate)
            for pairs in [Pair(c,Symbol(t)) for (c,t) in zip(covs_symbols,p)]
                ast_replace!(concrete_predicate,pairs)
            end
            #build final expression
            t = :( Tuple{$(concrete_args...)} )
            I = :( inferred($fun,$t) )
            push!(out, quote report_type_check($concrete_predicate,$I,$fun,$concrete_args) end)
        end
        out
    end

    function report_type_check(predicate,inferred,fun,types)
        passed = predicate(inferred)
        sig = join(types,",")
        status = passed ? "SUCCESS" : "FAILURE"
        println( "$status: $(fun)($sig) \t → $inferred" )
    end

    macro checked(out,f)
        sig = match_signature(quote $f end)
        args = match_type_annotation.(sig.args)
        checks = check_types(Symbol(sig.name),out,args,sig.cov)
        esc(quote
            $f
            $(checks...)
        end)
    end

    @checked O -> O == promote_type(T,K) function add(x::T, y::K) where {T<: Integer, K <: Integer}
        x + y
    end