Adding static type checking to Julia in 100 lines of code

While type annotations are optional in Julia, the type system is designed so that the types of variables in a function can be inferred for given input types. By taking advantage of this, we can check that a function's return types are as expected before the function is ever called.

Given that Julia is dynamic (types can be added at any time) and generic by default (meaning a huge number of types would need to be checked in many cases), this kind of check might not be very useful in practice. However, it's a good excuse to explore Julia's metaprogramming and introspection tools.

We will first deal with the simplest case possible (a method annotated with concrete types) and then generalize to parametric methods.

Non-parametric methods

We will deal with the simplest case to begin with. We want to build a macro @checked that takes a return type and a function definition as input and, right after the function definition, adds a check that the inferred return type matches the declared one. In short, it transforms the code:

@checked Int16 function add(x::Int8, y::Int16)    
  x+y
end
Julia

Into:

function add(x::Int8, y::Int16)    
  x+y
end
@assert inferred(add, Tuple{Int8, Int16}) == Int16
Julia

In order to do this we need to:

  1. Extract the type annotations (Int8, Int16) from the Abstract Syntax Tree (AST)
  2. Write an inferred method that returns the inferred return type
  3. Generate the final expression

For 1. we will use the MLStyle package which adds pattern matching for ASTs, making this task very easy.

For example, if we want to extract the body of a comprehension [i^2 for i=1:10] using MLStyle's @match macro, we simply need to replace each part we want to capture with an interpolated variable, [$body for $var=$range]:

using MLStyle

# Toy matcher: given a quoted array comprehension, return its body expression.
# The leading `$(::LineNumberNode)` is required because a `quote ... end`
# block starts with a line-number node before the comprehension itself.
match_signature(ast) = @match ast begin
    quote
        $(::LineNumberNode)
        [$body for $var=$range]
    end => body
end

# Returns the captured comprehension body, :(i ^ 2).
match_signature( quote [i^2 for i=1:10] end )
:(i ^ 2)

Note that Julia adds line number nodes to its AST, so we also need to include a line number node in the pattern.

We can match our function signature in a similar fashion, returning the matched elements as a named tuple:

# Match a quoted long-form `function name(args...) ... end` definition and
# return its name and its argument list as a named tuple. Each argument is
# kept as an expression (e.g. :(x::Int8)); the body is captured in `block`
# but not returned.
match_signature(ast) = @match ast begin
    quote
        $(::LineNumberNode)
        function $name($(args...)); $(block...) end
    end =>
    (name = name,
     args = args)
end

# Apply the matcher to a quoted definition (the resulting named tuple is shown below).
sig = match_signature(quote 
    function add(x::Int8,y::Int16)
        x+y
    end
end)
(name = :add, args = Any[:(x::Int8), :(y::Int16)])

Similarly, we define a method that extracts the type T from an argument annotation x::T:

# Extract the type from an argument annotation such as `x::Int8` (giving :Int8).
match_type_annotation(ast) = @match ast begin
    :($(argname)::$(argtype)) => argtype
end
# Broadcast over the matched arguments to get one annotated type per argument.
match_type_annotation.(sig.args) 
2-element Array{Symbol,1}: :Int8 :Int16

We can simply run type inference using Base.code_typed:

# Return the return type that Julia's type inference computes for `f` when
# called with arguments of the types listed in the tuple type `args`.
function inferred(f, args)
    # `code_typed` yields `CodeInfo => return type` pairs; keep the first one's type.
    _, rettype = Base.code_typed(f, args)[1]
    return rettype
end

# Int8 + Int16 is promoted to Int16, so inference should report Int16.
@assert inferred((x,y) -> x+y, Tuple{Int8,Int16}) == Int16

When adding two integers of different type, the result gets automatically promoted (here to Int16). The promoted type can be computed using promote_type(T,K).

We can now write a macro that will perform type checking:

# Define the function `f` and, right after it, assert that the return type
# inferred for its annotated argument types equals `out`. The whole expansion
# is escaped, so `inferred`, the function name and the types are resolved in
# the caller's scope.
macro checked(out, f)
    # Wrapping `f` in a quote block prepends the LineNumberNode match_signature expects.
    parsed = match_signature(quote $f end)
    argtypes = map(match_type_annotation, parsed.args)
    check = :(@assert inferred($(parsed.name), Tuple{$(argtypes...)}) == $out)
    return esc(Expr(:block, f, check))
end

# Show the code generated by the macro without evaluating it.
@macroexpand @checked Int16 function add(x::Int8,y::Int16)
    x+y
end
quote #= cell:10 =# function add(x::Int8, y::Int16) #= cell:16 =# x + y end #= cell:11 =# begin #= cell:6 =# if inferred(add, Tuple{Int8, Int16}) == Int16 nothing else (Base.throw)((Base.AssertionError)("inferred(add, Tuple{Int8, Int16}) == Int16")) end end end

Here we use @macroexpand to see the code generated by the macro. It first defines the function itself, and then adds a type check with the declared input and output types. We can now test that our macro works correctly:

# Defines `add`; the generated assertion passes because Int8 + Int16 infers to Int16.
@checked Int16 function add(x::Int8,y::Int16)
    x+y
end

If the inferred type does not match the declared one, the generated check throws an AssertionError as soon as the function is defined:

using Test
# Declaring Int8 is wrong (the sum is promoted to Int16), so the generated
# @assert throws while the definition is being evaluated.
@test_throws AssertionError @checked Int8 function add(x::Int8,y::Int16)
    x+y
end
Test Passed Thrown: AssertionError

Parametric methods

Next we will consider a simple parametric case. We will also replace the output type annotation with a predicate on the inferred type I, to allow for more general (and useful) checks. For example, in the following the inferred type will be compared to the type of the first argument:

# Target syntax: the first argument is a predicate on the inferred type `I`,
# which may refer to the method's type parameters (`T`, `K`).
@checked I -> I == T function add(x::T, y::K) where {T<:Integer, K<:Integer} 
    x + y
end
Julia

We first generalise our match_signature method to handle the parametric case, using an OR pattern (||) so that functions without a where clause are still matched:

# Generalized matcher returning (name, args, cov). It accepts a long-form
# definition with a `where {...}` clause, capturing the type-parameter
# declarations as `cov`. It also accepts a definition without one, in which
# case `Do(cov = [])` binds `cov` to an empty vector.
match_signature(ast) = @match ast begin
    quote
        $(::LineNumberNode)
        function $name($(args...)) where {$(cov...)}; $(block...) end
    end ||
    quote
        $(::LineNumberNode)
        function $name($(args...)); $(block...) end
    end && Do(cov = []) =>
    (name = name,
     args = args,
     cov  = cov)
end

# Match a parametric definition; `cov` holds the `where` bounds (shown below).
sig = match_signature(quote 
    function add(x::T,y::K) where {T<:Integer, K<:Integer}
        x+y
    end
end)
(name = :add, args = Any[:(x::T), :(y::K)], cov = Any[:(T <: Integer), :(K <: Integer)])

We need a small method to extract subtype relations (T <: Integer):

# Split a subtype bound such as `T <: Integer` into (sub = :T, sup = :Integer).
match_subtype(ast) = @match ast begin
    :($(param) <: $(bound)) => (sub = param, sup = bound)
end
# One (sub, sup) named tuple per `where` bound.
match_subtype.(sig.cov)
2-element Array{NamedTuple{(:sub, :sup),Tuple{Symbol,Symbol}},1}: (sub = :T, sup = :Integer) (sub = :K, sup = :Integer)

In order to test every combination of parameters under the constraints above, we need to enumerate all concrete subtypes of Integer. This can be done by recursively calling InteractiveUtils.subtypes:

import InteractiveUtils.subtypes

# Collect every concrete subtype of `T` (just `[T]` when `T` is itself
# concrete) by walking the type tree with `subtypes`. Always returns a
# `Vector{Type}`, whichever branch is taken.
function allsubtypes(T::Type)
    return allsubtypes(T, Type[])
end

# Accumulator form: push the concrete subtypes of `T` onto `t` and return `t`.
function allsubtypes(T, t)
    if isconcretetype(T)
        push!(t, T)
    else
        # recursion is run for its side effect on `t`, so use foreach, not map
        foreach(K -> allsubtypes(K, t), subtypes(T))
    end
    return t
end

# Resolve a type name such as :Integer (as found in a `where` bound) in module
# `mod` (default `Main`) and enumerate its concrete subtypes.
allsubtypes(T::Symbol, mod::Module = Main) = allsubtypes(Core.eval(mod, T))

# Enumerate the concrete integer types (output below).
allsubtypes(Integer)
12-element Array{Type,1}: Bool BigInt Int128 Int16 Int32 Int64 Int8 UInt128 UInt16 UInt32 UInt64 UInt8

In order to enumerate every combination of concrete types, we can use Iterators.product:

# Cartesian product of two lists of types: every (T, K) pair.
collect(Iterators.product([Bool,BigInt],[Bool,BigInt]))
2×2 Array{Tuple{DataType,DataType},2}: (Bool, Bool) (Bool, BigInt) (BigInt, Bool) (BigInt, BigInt)

We also need a method that will take an abstract type annotation ([:T,:K]) and replace it with the corresponding concrete subtypes:

# Replace abstract type parameters in the argument-type list `args` by
# concrete types, in place:
#   concrete_args[j] - the concrete type chosen for the j-th type parameter
#   idx[j]           - the positions in `args` annotated with that parameter
# Types are stored as Symbols so the list can be spliced into generated code.
# Returns the mutated `args`, following Julia's convention for `!` functions.
function concretize!(args, concrete_args, idx)
    for (j, K) in enumerate(concrete_args)
        for i in idx[j]
            args[i] = Symbol(K)
        end
    end
    return args
end

# Example: parameter T sits at position 1 and K at position 2; both become Bool.
args = Symbol[:T, :K]
concrete_args = (Bool, Bool)
idx = [[1], [2]]

concretize!(args, concrete_args, idx)
args
2-element Array{Symbol,1}: :Bool :Bool

Here idx contains the indices where each abstract type annotation occurs in the method signature (a parameter can occur at several positions, e.g. for f(x::T, y::T, z::K) where {T,K} we have idx = [[1,2],[3]]).

We need to perform the same operation on our predicate (e.g. I -> I == T), which can be an arbitrary expression. We will thus write a generic method that can replace a symbol by another in an expression:

# Find every occurrence of `pair.first` anywhere inside the expression `ex`
# and replace it, in place, by `pair.second`; returns the mutated `ex`.
# Only `Expr` nodes have children to rewrite. Every other node (Symbols,
# LineNumberNodes, number/string literals, ...) is a leaf with nothing to do.
function ast_replace!(ex::Expr, pair::Pair)
    for i in eachindex(ex.args)
        # isequal (not ==) so leaves such as `missing` cannot break the comparison
        if isequal(ex.args[i], pair.first)
            ex.args[i] = pair.second
        else
            ast_replace!(ex.args[i], pair)
        end
    end
    return ex
end
# Leaf fallback: previously only Symbol and LineNumberNode were handled, so a
# literal such as `1` inside the predicate raised a MethodError.
ast_replace!(ex, pair) = nothing

# Replace `T` by the type Int8 inside an anonymous-function predicate.
ast_replace!(:(I -> I == T), :T => Int8)
:(I->begin #= cell:15 =# I == Int8 end)

We can now put all this together in a method that will generate the expression that performs the check for every combination of concrete parameters:

2.5s
MLStyle (Julia)
# Build one `report_type_check(...)` call expression for each combination of
# concrete types allowed by the method's type-parameter bounds.
#   fun       - name (Symbol) of the function to check
#   predicate - expression for a one-argument predicate on the inferred return
#               type; it may mention the type parameters (e.g. :(O -> O == T))
#   args      - argument type annotations, e.g. [:T, :K]
#   covs      - `where` bounds, e.g. Any[:(T <: Integer), :(K <: Integer)]
# Returns a Vector{Expr}. An empty `covs` yields a single check, because
# `Iterators.product()` produces exactly one empty tuple.
function check_types(fun, predicate, args, covs)
    out = Expr[]
    bounds = match_subtype.(covs)
    params = [b.sub for b in bounds]
    # positions in `args` at which each type parameter appears
    idx = [findall(args .== param) for param in params]

    for combo in Iterators.product([allsubtypes(b.sup) for b in bounds]...)
        # concrete argument types for this combination
        concrete_args = copy(args)
        concretize!(concrete_args, combo, idx)

        # Substitute the concrete types into the predicate. Only Expr trees are
        # mutated by ast_replace!, so only they need a private copy. Atoms such
        # as a bare function name (e.g. `isconcretetype`) are used as-is.
        concrete_predicate = predicate isa Expr ? copy(predicate) : predicate
        for (param, K) in zip(params, combo)
            ast_replace!(concrete_predicate, param => Symbol(K))
        end

        # final expression: infer, then report through the predicate
        t = :(Tuple{$(concrete_args...)})
        I = :(inferred($fun, $t))
        push!(out, quote
            report_type_check($concrete_predicate, $I, $fun, $concrete_args)
        end)
    end
    return out
end

# Example inputs: function name, argument annotations, where-bounds and predicate.
fun, args = :add, [:T,:K]
covs = Any[:(T <: Integer), :(K <: Integer)]
out_type = :(O -> O == T)

# Show the first two generated checks.
check_types(fun, out_type, args, covs)[1:2]
2-element Array{Expr,1}: quote #= cell:24 =# report_type_check((O->begin #= cell:32 =# O == Bool end), inferred(add, Tuple{Bool, Bool}), add, Symbol[:Bool, :Bool]) end quote #= cell:24 =# report_type_check((O->begin #= cell:32 =# O == BigInt end), inferred(add, Tuple{BigInt, Bool}), add, Symbol[:BigInt, :Bool]) end

In the generated code we call a helper function that will display the result of the type check instead of throwing an error:

# Report whether `predicate` accepts `rettype`, the inferred return type of
# `fun` for argument types `types`. Prints one line such as
# "SUCCESS: add(Int8,Int16) \t → Int16" to `io` (stdout by default) and
# returns the Bool outcome, so callers can also aggregate results.
# (The positional argument is named `rettype` so it does not shadow the
# `inferred` function.)
function report_type_check(predicate, rettype, fun, types; io::IO = stdout)
    passed = predicate(rettype)
    sig = join(types, ",")
    status = passed ? "SUCCESS" : "FAILURE"
    println(io, "$status: $(fun)($sig) \t → $rettype")
    return passed
end
report_type_check (generic function with 1 method)

And finally we can update our checked macro:

# Define the function `f`, then emit one `report_type_check` call for each
# combination of concrete types permitted by its `where` bounds. Each call
# checks the inferred return type against the predicate `out`. The expansion
# is escaped, so every name resolves in the caller's scope.
macro checked(out, f)
    # Wrapping `f` in a quote block prepends the LineNumberNode match_signature expects.
    parsed = match_signature(quote $f end)
    argtypes = match_type_annotation.(parsed.args)
    reports = check_types(Symbol(parsed.name), out, argtypes, parsed.cov)
    return esc(Expr(:block, f, reports...))
end

# Check every combination of concrete Integer subtypes for T and K. Each
# combination prints one SUCCESS/FAILURE line.
@checked O -> O == promote_type(T,K) function add(x::T, y::K) where {T<: Integer, K <: Integer} 
    x + y
end

All checks succeed except the first one (add(Bool,Bool) → Int64), because adding two booleans returns an Int64 rather than promote_type(Bool,Bool), which is Bool. This could be fixed by modifying the predicate slightly to handle this particular case.

Limitations

One limitation of the implementation above is that it doesn't handle parametric types (e.g. Vector{T} where T <: Number), but it should be relatively easy to generalize it to handle such a case. More general ones however might get tricky (e.g. Array{T,N} where {T <: Union{Number, String}, N}).

The sheer number of type combinations to check can also become an issue. For example, there are 16 concrete subtypes of Number in Base, so a function taking two numbers will generate 256 checks. There are 8998 concrete types in Base, so testing a function that has a single parameter annotated as Any would require that many checks, which is not very realistic.

For that reason, this kind of tool might be better suited to automated tests, run only during testing for some specific cases.

Final code

using MLStyle
import InteractiveUtils.subtypes

# Match a quoted long-form function definition and return (name, args, cov).
# `args` holds the argument expressions (e.g. :(x::T)). `cov` holds the
# `where {...}` declarations. For a definition without a where clause,
# `Do(cov = [])` binds `cov` to an empty vector.
match_signature(ast) = @match ast begin
    quote
        $(::LineNumberNode)
        function $name($(args...)) where {$(cov...)}; $(block...) end
    end ||
    quote
        $(::LineNumberNode)
        function $name($(args...)); $(block...) end
    end && Do(cov = []) =>
    (name = name,
     args = args,
     cov  = cov)
end

# Extract the type from an argument annotation such as `x::T` (giving :T).
match_type_annotation(ast) = @match ast begin
    :($(argname)::$(argtype)) => argtype
end
# Split a subtype bound such as `T <: Integer` into (sub = :T, sup = :Integer).
match_subtype(ast) = @match ast begin
    :($(param) <: $(bound)) => (sub = param, sup = bound)
end

# Return the return type that Julia's type inference computes for `f` when
# called with arguments of the types listed in the tuple type `args`.
function inferred(f, args)
    # `code_typed` yields `CodeInfo => return type` pairs; keep the first one's type.
    _, rettype = Base.code_typed(f, args)[1]
    return rettype
end

# Collect every concrete subtype of `T` (just `[T]` when `T` is itself
# concrete) by walking the type tree with `subtypes`. Always returns a
# `Vector{Type}`, whichever branch is taken.
function allsubtypes(T::Type)
    return allsubtypes(T, Type[])
end

# Accumulator form: push the concrete subtypes of `T` onto `t` and return `t`.
function allsubtypes(T, t)
    if isconcretetype(T)
        push!(t, T)
    else
        # recursion is run for its side effect on `t`, so use foreach, not map
        foreach(K -> allsubtypes(K, t), subtypes(T))
    end
    return t
end

# Resolve a type name such as :Integer (as found in a `where` bound) in module
# `mod` (default `Main`) and enumerate its concrete subtypes.
allsubtypes(T::Symbol, mod::Module = Main) = allsubtypes(Core.eval(mod, T))

# Replace abstract type parameters in the argument-type list `args` by
# concrete types, in place:
#   concrete_args[j] - the concrete type chosen for the j-th type parameter
#   idx[j]           - the positions in `args` annotated with that parameter
# Types are stored as Symbols so the list can be spliced into generated code.
# Returns the mutated `args`, following Julia's convention for `!` functions.
function concretize!(args, concrete_args, idx)
    for (j, K) in enumerate(concrete_args)
        for i in idx[j]
            args[i] = Symbol(K)
        end
    end
    return args
end

# Find every occurrence of `pair.first` anywhere inside the expression `ex`
# and replace it, in place, by `pair.second`; returns the mutated `ex`.
# Only `Expr` nodes have children to rewrite. Every other node (Symbols,
# LineNumberNodes, number/string literals, ...) is a leaf with nothing to do.
function ast_replace!(ex::Expr, pair::Pair)
    for i in eachindex(ex.args)
        # isequal (not ==) so leaves such as `missing` cannot break the comparison
        if isequal(ex.args[i], pair.first)
            ex.args[i] = pair.second
        else
            ast_replace!(ex.args[i], pair)
        end
    end
    return ex
end
# Leaf fallback: previously only Symbol and LineNumberNode were handled, so a
# literal such as `1` inside the predicate raised a MethodError.
ast_replace!(ex, pair) = nothing

# Build one `report_type_check(...)` call expression for each combination of
# concrete types allowed by the method's type-parameter bounds.
#   fun       - name (Symbol) of the function to check
#   predicate - expression for a one-argument predicate on the inferred return
#               type; it may mention the type parameters (e.g. :(O -> O == T))
#   args      - argument type annotations, e.g. [:T, :K]
#   covs      - `where` bounds, e.g. Any[:(T <: Integer), :(K <: Integer)]
# Returns a Vector{Expr}. An empty `covs` yields a single check, because
# `Iterators.product()` produces exactly one empty tuple.
function check_types(fun, predicate, args, covs)
    out = Expr[]
    bounds = match_subtype.(covs)
    params = [b.sub for b in bounds]
    # positions in `args` at which each type parameter appears
    idx = [findall(args .== param) for param in params]

    for combo in Iterators.product([allsubtypes(b.sup) for b in bounds]...)
        # concrete argument types for this combination
        concrete_args = copy(args)
        concretize!(concrete_args, combo, idx)

        # Substitute the concrete types into the predicate. Only Expr trees are
        # mutated by ast_replace!, so only they need a private copy. Atoms such
        # as a bare function name (e.g. `isconcretetype`) are used as-is.
        concrete_predicate = predicate isa Expr ? copy(predicate) : predicate
        for (param, K) in zip(params, combo)
            ast_replace!(concrete_predicate, param => Symbol(K))
        end

        # final expression: infer, then report through the predicate
        t = :(Tuple{$(concrete_args...)})
        I = :(inferred($fun, $t))
        push!(out, quote
            report_type_check($concrete_predicate, $I, $fun, $concrete_args)
        end)
    end
    return out
end

# Report whether `predicate` accepts `rettype`, the inferred return type of
# `fun` for argument types `types`. Prints one line such as
# "SUCCESS: add(Int8,Int16) \t → Int16" to `io` (stdout by default) and
# returns the Bool outcome, so callers can also aggregate results.
# (The positional argument is named `rettype` so it does not shadow the
# `inferred` function.)
function report_type_check(predicate, rettype, fun, types; io::IO = stdout)
    passed = predicate(rettype)
    sig = join(types, ",")
    status = passed ? "SUCCESS" : "FAILURE"
    println(io, "$status: $(fun)($sig) \t → $rettype")
    return passed
end

# Define the function `f`, then emit one `report_type_check` call for each
# combination of concrete types permitted by its `where` bounds. Each call
# checks the inferred return type against the predicate `out`. The expansion
# is escaped, so every name resolves in the caller's scope.
macro checked(out, f)
    # Wrapping `f` in a quote block prepends the LineNumberNode match_signature expects.
    parsed = match_signature(quote $f end)
    argtypes = match_type_annotation.(parsed.args)
    reports = check_types(Symbol(parsed.name), out, argtypes, parsed.cov)
    return esc(Expr(:block, f, reports...))
end

# Check every combination of concrete Integer subtypes for T and K. Each
# combination prints one SUCCESS/FAILURE line.
@checked O -> O == promote_type(T,K) function add(x::T, y::K) where {T<: Integer, K <: Integer} 
    x + y
end
Julia