Dubstep

This module uses Cassette.jl to modify programs by overdubbing their executions in a context.

TraceCtx

Builds hierarchical runtime value traces by running the program you pass it. You can change out the metadata that you pass in order to collect different information. The default is Any[].

LPCtx

Replaces all calls to norm(x,p) with norm(x,ctx.metadata[p]) so you can change the norms that a code uses to compute.

Example

Here is an example that uses Cassette to change an internal component of a mathematical computation by rewriting calls to the norm function.

First we define a function that uses norm, and another function that calls it.

using LinearAlgebra

subg(x,y) = norm([x x x]/6 - [y y y]/2, 2)
function g()
    a = 5+7
    b = 3+4
    c = subg(a,b)
    return c
end
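For reference, the value g() returns can be worked out by hand. The sketch below uses only the LinearAlgebra standard library (no Cassette) and computes both the default Euclidean result and the 1-norm result that substituting p = 1 for p = 2 would produce; these are the two numbers the LP test checks.

```julia
using LinearAlgebra

# Same helper as above; g() calls it with a = 5 + 7 = 12 and b = 3 + 4 = 7.
subg(x, y) = norm([x x x]/6 - [y y y]/2, 2)

# [12 12 12]/6 - [7 7 7]/2 == [-1.5 -1.5 -1.5]
v = [12 12 12]/6 - [7 7 7]/2

euclid = norm(v, 2)   # sqrt(3 * 1.5^2) = sqrt(6.75), about 2.598
taxicab = norm(v, 1)  # 3 * 1.5 = 4.5

println("2-norm: ", euclid, "  1-norm: ", taxicab)
```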

We use Dubstep.LPCtx, which is defined as follows.

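using Cassette

Cassette.@context LPCtx

# Recurse into every call that Cassette can overdub, passing the same metadata
# along; anything else runs normally via fallback.
function Cassette.execute(ctx::LPCtx, args...)
    if Cassette.canoverdub(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata)
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end

using LinearAlgebra
# Intercept norm(arg, power) and substitute the power stored in the metadata.
function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power)
    return f(arg, ctx.metadata[power])
end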

Note the method of Cassette.execute for LPCtx that is specialized on the function LinearAlgebra.norm: it looks up the requested power in ctx.metadata and calls norm with that power instead.
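Stripped of Cassette, that norm method amounts to a table lookup on the requested power. The helper remap_norm below is hypothetical (it is not part of Dubstep); it is a minimal sketch of what a call norm(arg, power) turns into under the metadata dictionary used in the LP test.

```julia
using LinearAlgebra

# Hypothetical stand-in for the LPCtx norm method: the metadata maps the
# power the program asked for to the power that is actually used.
remap_norm(metadata::AbstractDict, x, p) = norm(x, metadata[p])

metadata = Dict(1 => 2, 2 => 1, Inf => 1)
x = [-1.5, -1.5, -1.5]

requested_2 = remap_norm(metadata, x, 2)      # program asks for 2, gets the 1-norm
requested_1 = remap_norm(metadata, x, 1)      # program asks for 1, gets the 2-norm
requested_inf = remap_norm(metadata, x, Inf)  # program asks for Inf, gets the 1-norm

println((requested_2, requested_1, requested_inf))
```

Integer and floating-point keys can be mixed here because Julia's Dict lookup uses isequal and hash, and isequal(2, 2.0) holds.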

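We then construct an instance of the context whose metadata configures the substitution: the dictionary maps the power the program requests to the power that is actually used, so the call norm(..., 2) inside subg computes a 1-norm and g() returns 4.5 instead of about 2.598.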

using Test

@testset "LP" begin
    @test 2.5980 < g() < 2.599
    ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1))
    @test Cassette.overdub(ctx, g) == 4.5
end

And just like that, we can control the execution of a program without rewriting it at the lexical level.

Transformations

You can also transform a model by executing it in a context that changes its function calls. Eventually we will support writing compiler passes that modify models at the expression level, but for now function calls are a good entry point.

Example: Perturbations

This example comes from the unit test file test/transform/ode.jl.

The first step is to define a context for solving models.

module ODEXform
using DifferentialEquations
using Cassette
using SemanticModels.Dubstep

Cassette.@context SolverCtx
function Cassette.execute(ctx::SolverCtx, args...)
    if Cassette.canoverdub(ctx, args...)
        #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata)
        return Cassette.overdub(ctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end

function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...)
    @info "constructing a vector length $(length(args))"
    return Cassette.fallback(ctx, f, args...)
end

# We don't need to overdub basic math; skipping it should make execution faster.
# If these overloads don't actually make it faster, they can be deleted.
function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...)
    return Cassette.fallback(ctx, f, args...)
end
end #module

Then we define the right-hand side (RHS) of the differential equation du/dt = sir_ode(du, u, p, t). This function must be defined before the method Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...), because the function we want to overdub has to exist before we can specify how to overdub it.
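The ordering constraint comes from how Julia dispatches on functions: every function has its own singleton type, and typeof(sir_ode) can only be written once sir_ode exists. A minimal plain-Julia sketch of the same dispatch pattern, with hypothetical names rhs and describe and no Cassette involved:

```julia
# Each Julia function has its own singleton type, so a method can be
# specialized on one particular function via f::typeof(thatfunction).
rhs(x) = 2x   # hypothetical stand-in for sir_ode; must be defined first,
              # otherwise the next line fails with UndefVarError: rhs

describe(f::typeof(rhs)) = "special-cased: result $(f(1))"
describe(f::Function) = "generic"

println(describe(rhs))  # picks the specialized method
println(describe(sin))  # picks the generic method
```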

using LinearAlgebra
using Test
using Cassette
using DifferentialEquations
using SemanticModels.Dubstep

"""   sir_ode(du,u,p,t)

Computes the du/dt array for the SIR system. The parameters are p = (b, g) = (beta, gamma).
"""
sir_ode(du,u,p,t) = begin
    S,I,R = u
    b,g = p
    du[1] = -b*S*I
    du[2] = b*S*I-g*I
    du[3] = g*I
end

function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...)
    y = Cassette.fallback(ctx, f, args...)
    # add a lagniappe of infection
    extra = args[1][1] * ctx.metadata.factor
    push!(ctx.metadata.extras, extra)
    args[1][1] += extra
    args[1][2] -= extra
    return y
end

The key step is the execute method: it runs sir_ode, computes the extra amount (the lagniappe) as the fraction ctx.metadata.factor of dS/dt, and adds that amount to dS/dt. The SIR model conserves the total population, so dS/dt + dI/dt + dR/dt = 0, that is, dI/dt = -dS/dt - dR/dt; we therefore subtract the same amount from dI/dt.
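To check that the adjustment preserves the invariant, here is a sketch that applies the same perturbation to du by hand, without Cassette. The helper perturbed_rhs! is hypothetical; the inputs reuse the initial condition and parameters from g() below, with an illustrative factor of 0.05.

```julia
# Same RHS as in the documentation above.
function sir_ode(du, u, p, t)
    S, I, R = u
    b, g = p
    du[1] = -b*S*I
    du[2] = b*S*I - g*I
    du[3] = g*I
end

# Hypothetical helper: evaluate the RHS, then apply the lagniappe the way the
# Cassette.execute method does (add `extra` to dS/dt, subtract it from dI/dt).
function perturbed_rhs!(du, u, p, t, factor)
    sir_ode(du, u, p, t)
    extra = du[1] * factor
    du[1] += extra
    du[2] -= extra
    return extra
end

du = zeros(3)
extra = perturbed_rhs!(du, [0.99, 0.01, 0.0], [0.1, 0.05], 0.0, 0.05)

# dS/dt + dI/dt + dR/dt is still zero up to rounding after the perturbation.
println("extra = ", extra, ", sum(du) = ", sum(du))
```

Because dS/dt is negative, extra is negative too: S falls faster and I grows faster, which is the extra infection the comment in the execute method refers to.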

The rest of this code runs the model in the context.

function g()
    parms = [0.1,0.05]
    init = [0.99,0.01,0.0]
    tspan = (0.0,200.0)
    sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms)
    return sir_prob
end

function h()
    prob = g()
    return solve(prob, alg=Vern7())
end

# first call: includes compilation time
@time sol1 = h()
# second call: measures the run time
@time sol1 = h()

We define a perturbation function that sets up the context and collects the results. Note that the extras are stored in ctx.metadata with the mutating function push!: the NamedTuple t is itself immutable, but the vector t.extras can grow.
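The distinction matters because the metadata is a NamedTuple. A small sketch of the pattern, with illustrative values standing in for the extras that the execute method would record:

```julia
# Same shape as the metadata built in perturb.
t = (factor = 0.05, extras = Float64[])

# Rebinding a field of a NamedTuple is an error...
threw = try
    t.factor = 1.0
    false
catch
    true
end

# ...but mutating the vector a field refers to is fine.
push!(t.extras, -1.0e-5)   # illustrative values only
push!(t.extras, -3.0e-5)

mean_extra = sum(t.extras) / length(t.extras)
println("rebinding threw: ", threw, ", mean extra: ", mean_extra)
```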

"""    perturb(f, factor)

run the function f with a perturbation specified by factor.
"""
function perturb(f, factor)
    t = (factor=factor,extras=Float64[])
    ctx = ODEXform.SolverCtx(metadata = t)
    val = Cassette.overdub(ctx, f)
    return val, t
end

We collect the metadata traces and the solutions in order to quantify the effect of our perturbation on the answer computed by solve. We then test that a bigger perturbation produces a bigger error.
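The same measurement can be sketched without DifferentialEquations.jl or Cassette. The following assumes a hypothetical fixed-step forward-Euler integrator, euler_sir, in place of solve(prob, alg=Vern7()), applies the lagniappe directly inside the loop, and measures the 2-norm distance from the unperturbed state at t = 100. It only illustrates the measurement; it does not reproduce the numbers from the test.

```julia
using LinearAlgebra

function sir_ode(du, u, p, t)
    S, I, R = u
    b, g = p
    du[1] = -b*S*I
    du[2] = b*S*I - g*I
    du[3] = g*I
end

# Hypothetical forward-Euler integrator standing in for `solve`.
function euler_sir(factor; u0 = [0.99, 0.01, 0.0], p = [0.1, 0.05], T = 100.0, dt = 0.1)
    u = copy(u0)
    du = zeros(3)
    for k in 0:round(Int, T/dt)-1
        sir_ode(du, u, p, k*dt)
        extra = du[1] * factor   # same lagniappe as the Cassette example
        du[1] += extra
        du[2] -= extra
        u .+= dt .* du
    end
    return u
end

base = euler_sir(0.0)
errors = [norm(euler_sir(f) .- base, 2) for f in [0.0, 0.01, 0.05, 0.10]]
println(errors)
```

Forward Euler is used only so the sketch runs with the standard library; the comparison pattern, not the integrator, is the point.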

traces = Any[]
solns = Any[]
for f in [0.0, 0.01, 0.05, 0.10]
    val, t = perturb(h, f)
    push!(traces, t)
    push!(solns, val)
end

for (i, s) in enumerate(solns)
    @show s(100)
    @show traces[i].factor
    @show traces[i].extras[5]
    @show sum(traces[i].extras)/length(traces[i].extras)
end

@testset "ODE perturbation" begin

@test norm(sol1(100) .- solns[1](100),2) < 1e-6
@test norm(sol1(100) .- solns[2](100),2) > 1e-6
@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2)
@test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2)

end

This example illustrates how you can use a Cassette.Context to hijack the execution of a scientific model in order to change the execution in a meaningful way. The execution also lets us examine the sensitivity of the solution with respect to the derivative. This technique allows scientists to answer counterfactual questions about the execution of codes, such as "what if the model had a slightly different RHS?"

Reference

LPCtx

Replaces all calls to LinearAlgebra.norm with a different p.

This context is useful for modifying statistical codes or machine learning regularizers.

Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache}}

A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:

julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B} where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
                                          where P<:Cassette.AbstractPass
                                          where T<:Union{Nothing,Tag}
                                          where M

Constructors

Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:

MyCtx(; metadata = nothing, pass = Cassette.NoPass())

To construct a new context instance using an existing context instance as a template, see the similarcontext function.

To enable contextual tagging for a given context instance, see the enabletagging function.

Fields

  • name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).

  • metadata::M<:Any: trace-local metadata as provided to the context constructor

  • pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).

  • tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance

  • bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings

TracedRun{T,V}

Captures the dataflow of a code execution. We store the trace and the value.

See also trace.

replacenorm(f::Function, d::AbstractDict)

Run f, but replace every call to norm using the mapping in d.

trace(f)

Run the function f and return a TracedRun containing the trace and the output.
