Dubstep
This module uses Cassette.jl to modify programs by overdubbing their executions in a context.
TraceCtx
Builds hierarchical runtime value traces by running the program you pass it. You can change the metadata that you pass in order to collect different information; the default is Any[].
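Cassette builds such traces by intercepting every call. The hand-instrumented sketch below (plain Julia, no Cassette) only illustrates the shape of a hierarchical value trace, assuming one frame per call that holds its arguments, its return value and the frames of the calls it made. Frame, STACK and traced are hypothetical names; they are not the Dubstep API and not the actual layout of TraceCtx metadata.

```julia
# Hand-instrumented sketch of a hierarchical runtime trace.
# Frame, STACK and traced are hypothetical names, NOT the Dubstep API;
# Cassette records calls automatically instead of requiring `traced` wrappers.
mutable struct Frame
    name::Symbol              # label of the function that was called
    args::Tuple               # the arguments it received
    value::Any                # the value it returned
    children::Vector{Frame}   # frames for the calls it made while running
end

# The bottom of the stack is a root frame that collects top-level calls.
const STACK = Frame[Frame(:root, (), nothing, Frame[])]

# Call f(args...) and record it as a child of the currently running frame.
function traced(name::Symbol, f, args...)
    frame = Frame(name, args, nothing, Frame[])
    push!(STACK[end].children, frame)
    push!(STACK, frame)
    try
        frame.value = f(args...)
    finally
        pop!(STACK)           # always return to the caller's frame
    end
    return frame.value
end

inner(x) = traced(:+, +, x, 1)
outer(x) = traced(:*, *, traced(:inner, inner, x), 2)

result = traced(:outer, outer, 3)
top = STACK[1].children[1]
println([c.name for c in top.children])   # calls made directly inside outer
```

The trace of outer(3) has two children, inner and *, and the + call sits one level further down, inside inner.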
LPCtx
Replaces all calls to norm(x,p) with norm(x,ctx.metadata[p]), so you can change the norms that a code uses to compute its results.
Example
Here is an example of using Cassette to change an internal component of a mathematical operation by rewriting calls to the norm function.
First we define a function that uses norm, and another function that calls it.
subg(x,y) = norm([x x x]/6 - [y y y]/2, 2)
function g()
    a = 5+7
    b = 3+4
    c = subg(a,b)
    return c
end
We use Dubstep.LPCtx, whose definition is shown here.
Cassette.@context LPCtx
function Cassette.execute(ctx::LPCtx, args...)
    if Cassette.canoverdub(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata)
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end
using LinearAlgebra
function Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power)
    return f(arg, ctx.metadata[power])
end
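Note the method definition of Cassette.execute for LPCtx when called with the function LinearAlgebra.norm: because it dispatches on f::typeof(norm), it applies only to calls to norm, and it replaces the requested power with ctx.metadata[power].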
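The mechanism behind this method is ordinary multiple dispatch: a generic rule handles every call, and a more specific method on f::typeof(norm) takes over for norm. A minimal plain-Julia analogue without Cassette is sketched below. ManualLPCtx, execute and subg_manual are hypothetical names, and unlike overdubbing, every call has to be routed through execute by hand.

```julia
using LinearAlgebra

# Hypothetical stand-in for a Cassette context: a struct that only carries metadata.
struct ManualLPCtx{M}
    metadata::M
end

# Default rule: run the call unchanged (the role Cassette.fallback plays above).
execute(ctx::ManualLPCtx, f, args...) = f(args...)

# More specific rule, chosen by dispatch on typeof(norm), mirroring
# Cassette.execute(ctx::LPCtx, f::typeof(norm), arg, power).
execute(ctx::ManualLPCtx, f::typeof(norm), arg, power) = f(arg, ctx.metadata[power])

# Unlike overdubbing, each call is routed through `execute` explicitly.
subg_manual(ctx, x, y) = execute(ctx, norm, execute(ctx, -, [x x x]/6, [y y y]/2), 2)

ctx = ManualLPCtx(Dict(1=>2, 2=>1, Inf=>1))
println(subg_manual(ctx, 12, 7))                      # p = 2 is mapped to 1
println(subg_manual(ManualLPCtx(Dict(2=>2)), 12, 7))  # identity mapping keeps the 2-norm
```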
We then construct an instance of the context that configures how we want to do the substitution.
@testset "LP" begin
    @test 2.5980 < g() < 2.599
    ctx = Dubstep.LPCtx(metadata=Dict(1=>2, 2=>1, Inf=>1))
    @test Cassette.overdub(ctx, g) == 4.5
end
And just like that, we can control the execution of a program without rewriting it at the lexical level.
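The bounds in the test follow from the arithmetic inside g: without a context, subg evaluates a 2-norm, and under the mapping 2=>1 the overdubbed call evaluates the 1-norm of the same vector.

```math
\begin{aligned}
a &= 5 + 7 = 12, \qquad b = 3 + 4 = 7,\\
v &= \tfrac{1}{6}[12\ 12\ 12] - \tfrac{1}{2}[7\ 7\ 7] = [-1.5\ \ -1.5\ \ -1.5],\\
\|v\|_2 &= \sqrt{3 \cdot 1.5^2} = \sqrt{6.75} \approx 2.5981 \in (2.5980,\ 2.599),\\
\|v\|_1 &= 3 \cdot 1.5 = 4.5.
\end{aligned}
```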
Transformations
You can also transform a model by executing it in a context that changes its function calls. Eventually we will support writing compiler passes that modify models at the expression level, but for now function calls are a good entry point.
Example: Perturbations
This example comes from the unit tests in test/transform/ode.jl.
The first step is to define a context for solving models.
module ODEXform
using DifferentialEquations
using Cassette
using SemanticModels.Dubstep
Cassette.@context SolverCtx
function Cassette.execute(ctx::SolverCtx, args...)
    if Cassette.canoverdub(ctx, args...)
        #newctx = Cassette.similarcontext(ctx, metadata = ctx.metadata)
        return Cassette.overdub(ctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end
function Cassette.execute(ctx::SolverCtx, f::typeof(Base.vect), args...)
    @info "constructing a vector length $(length(args))"
    return Cassette.fallback(ctx, f, args...)
end
# We don't need to overdub basic math. This hopefully makes execution faster.
# If these overloads don't actually make it faster, they can be deleted.
function Cassette.execute(ctx::SolverCtx, f::typeof(+), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(-), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(*), args...)
    return Cassette.fallback(ctx, f, args...)
end
function Cassette.execute(ctx::SolverCtx, f::typeof(/), args...)
    return Cassette.fallback(ctx, f, args...)
end
end #module
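The Base.vect overload sees vector constructions because Julia lowers a vector literal, such as the [0.99,0.01,0.0] built in g() below, to a call to Base.vect. A small stdlib-only check of that lowering:

```julia
# A vector literal parses to an expression with head :vect ...
ex = :([0.99, 0.01, 0.0])
println(ex.head)

# ... and lowering turns it into a call to Base.vect, which is why an
# overload for typeof(Base.vect) sees every vector literal the model builds.
lowered = Meta.lower(Main, ex)
println(occursin("vect", string(lowered)))

# Calling Base.vect directly builds the same vector as the literal.
println(Base.vect(0.99, 0.01, 0.0) == [0.99, 0.01, 0.0])
```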
Then we define the right-hand side (RHS) of the differential equation du/dt = sir_ode(du, u, p, t). This function must be defined before we define the method for Cassette.execute with the signature Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...), because the function we want to overdub has to exist before we can specify how to overdub it.
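Each Julia function has its own singleton type, so a signature containing f::typeof(sir_ode) matches only that function, and typeof(sir_ode) is evaluated at the moment the method is defined. The plain-Julia sketch below shows both points using the hypothetical functions rhs_a, rhs_b and describe.

```julia
# Every function has its own type, so a method on ::typeof(rhs_a)
# is selected only for that one function.
rhs_a(x) = x + 1      # hypothetical functions for illustration
rhs_b(x) = x - 1
describe(f) = "some other function"
describe(::typeof(rhs_a)) = "rhs_a"

println(describe(rhs_a), " / ", describe(rhs_b))

# The signature evaluates typeof(...) when the method is defined, so naming a
# function that does not exist yet fails with an UndefVarError.
err = try
    @eval describe(::typeof(not_defined_yet)) = "too early"
    nothing
catch e
    e
end
println(typeof(err))
```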
using LinearAlgebra
using Test
using Cassette
using DifferentialEquations
using SemanticModels.Dubstep
"""
    sir_ode(du,u,p,t)

Computes the du/dt array for the SIR system. The parameters p are b,g = beta,gamma.
"""
sir_ode(du,u,p,t) = begin
    S,I,R = u
    b,g = p
    du[1] = -b*S*I
    du[2] = b*S*I-g*I
    du[3] = g*I
end
function Cassette.execute(ctx::ODEXform.SolverCtx, f::typeof(sir_ode), args...)
    y = Cassette.fallback(ctx, f, args...)
    # add a lagniappe of infection: scale dS/dt (args[1][1]) by 1 + factor
    # and subtract the same amount from dI/dt (args[1][2])
    extra = args[1][1] * ctx.metadata.factor
    push!(ctx.metadata.extras, extra)
    args[1][1] += extra
    args[1][2] -= extra
    return y
end
The key step is the execute method for sir_ode: it runs sir_ode, then computes the extra amount (the lagniappe) and adds it to dS/dt. The SIR model conserves the total population S + I + R, so it has the invariant dI/dt = -dS/dt - dR/dt. Since the perturbation leaves dR/dt unchanged, we subtract the same extra amount from dI/dt to preserve the invariant.
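One way to see that the perturbation respects the invariant is to apply the same arithmetic to a single RHS evaluation and check that dS/dt + dI/dt + dR/dt still vanishes. The sketch below copies sir_ode and reproduces the overdubbed method's arithmetic, without Cassette, in a hypothetical helper perturbed_sir!; the state and parameters are the ones used in g().

```julia
# Copy of sir_ode from above, so the sketch runs on its own.
function sir_ode(du, u, p, t)
    S, I, R = u
    b, g = p
    du[1] = -b*S*I
    du[2] = b*S*I - g*I
    du[3] = g*I
end

# Hypothetical helper reproducing the overdubbed method's arithmetic.
function perturbed_sir!(du, u, p, t, factor)
    sir_ode(du, u, p, t)
    extra = du[1] * factor   # the lagniappe, proportional to dS/dt
    du[1] += extra           # dS/dt gets the extra amount
    du[2] -= extra           # dI/dt compensates, so the total is unchanged
    return extra
end

du = zeros(3)
extra = perturbed_sir!(du, [0.99, 0.01, 0.0], [0.1, 0.05], 0.0, 0.05)
println(sum(du))   # dS/dt + dI/dt + dR/dt, zero up to rounding
```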
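The rest of this code defines and solves the model. Running h() outside of any context gives the unperturbed solution sol1, which the tests below use as a baseline.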
function g()
    parms = [0.1,0.05]
    init = [0.99,0.01,0.0]
    tspan = (0.0,200.0)
    sir_prob = Dubstep.construct(ODEProblem,sir_ode,init,tspan,parms)
    return sir_prob
end
function h()
    prob = g()
    return solve(prob, alg=Vern7())
end
# first call: includes compilation time
@time sol1 = h()
# second call: measures the run time alone
@time sol1 = h()
We define a perturbation function that sets up the context and collects the results. Note that we store the extras in ctx.metadata with the mutating function push!: the NamedTuple t itself cannot be modified, but the Float64[] vector it holds can.
"""
    perturb(f, factor)

Run the function f with a perturbation specified by factor.
"""
function perturb(f, factor)
    t = (factor=factor,extras=Float64[])
    ctx = ODEXform.SolverCtx(metadata = t)
    val = Cassette.overdub(ctx, f)
    return val, t
end
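Because perturb returns the same t that it passed to the context, the values pushed onto t.extras during overdubbing reach the caller, while the fields of t cannot be rebound. A plain-Julia illustration with made-up values:

```julia
# Same shape as the metadata built in perturb: an immutable NamedTuple
# holding a mutable vector. The pushed numbers are illustrative only.
t = (factor = 0.05, extras = Float64[])

push!(t.extras, -4.95e-5)    # allowed: mutates the vector the tuple refers to
push!(t.extras, -5.1e-5)

# Rebinding a field of the NamedTuple itself throws an error.
reassigned = try
    t.factor = 0.1
    true
catch
    false
end

println(length(t.extras), " ", reassigned)
```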
We collect the traces t and solutions s in order to quantify the effect of our perturbation on the answer computed by solve. We test to make sure that the bigger the perturbation, the bigger the error.
traces = Any[]
solns = Any[]
for f in [0.0, 0.01, 0.05, 0.10]
    val, t = perturb(h, f)
    push!(traces, t)
    push!(solns, val)
end
for (i, s) in enumerate(solns)
    @show s(100)
    @show traces[i].factor
    @show traces[i].extras[5]
    @show sum(traces[i].extras)/length(traces[i].extras)
end
@testset "ODE perturbation" begin
    @test norm(sol1(100) .- solns[1](100),2) < 1e-6
    @test norm(sol1(100) .- solns[2](100),2) > 1e-6
    @test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[3](100),2)
    @test norm(solns[1](100) .- solns[2](100),2) < norm(solns[1](100) .- solns[4](100),2)
end
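Algebraically, the lagniappe multiplies both b*S*I terms by (1 + factor), which is the same as raising the infection rate b by that factor, so larger factors should move the solution further from the baseline. The stand-alone sketch below checks this without Cassette or DifferentialEquations: it swaps in a hand-written forward-Euler integrator (step size 0.01, an assumption) for solve with Vern7 and compares states at t = 100, as the tests above do. sir_rhs! and euler_solve are hypothetical names.

```julia
using LinearAlgebra

# SIR right-hand side with the same lagniappe as the overdubbed sir_ode.
function sir_rhs!(du, u, p, factor)
    S, I, R = u
    b, g = p
    du[1] = -b*S*I
    du[2] = b*S*I - g*I
    du[3] = g*I
    extra = du[1] * factor
    du[1] += extra
    du[2] -= extra
    return du
end

# Forward Euler from t = 0 to t = T (an assumption: not the Vern7 solver used above).
function euler_solve(factor; u0 = [0.99, 0.01, 0.0], p = [0.1, 0.05], T = 100.0, dt = 0.01)
    u = copy(u0)
    du = similar(u)
    for _ in 1:round(Int, T / dt)
        sir_rhs!(du, u, p, factor)
        u .+= dt .* du
    end
    return u
end

baseline = euler_solve(0.0)
errs = [norm(euler_solve(f) .- baseline, 2) for f in (0.0, 0.01, 0.05, 0.10)]
println(errs)   # distance from the unperturbed state at t = 100
```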
This example illustrates how you can use a Cassette.Context to hijack the execution of a scientific model in order to change it in a meaningful way. It also shows how controlling the execution lets us examine the sensitivity of the solution with respect to perturbations of the derivative. This technique allows scientists to answer counterfactual questions about the execution of codes, such as "what if the model had a slightly different RHS?"
Reference
SemanticModels.Dubstep.LPCtx (Type)
LPCtx replaces all calls to LinearAlgebra.norm with a different p.
This context is useful for modifying statistical codes or machine learning regularizers.
SemanticModels.Dubstep.TraceCtx (Type)
Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache}}
A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:
julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B} where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
    where P<:Cassette.AbstractPass
    where T<:Union{Nothing,Tag}
    where M
Constructors
Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:
MyCtx(; metadata = nothing, pass = Cassette.NoPass())
To construct a new context instance using an existing context instance as a template, see the similarcontext function.
To enable contextual tagging for a given context instance, see the enabletagging function.
Fields
name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).
metadata::M<:Any: trace-local metadata as provided to the context constructor.
pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).
tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance.
bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings.
SemanticModels.Dubstep.TracedRun (Type)
TracedRun{T,V} captures the dataflow of a code execution. We store the trace and the value.
See also trace.
SemanticModels.Dubstep.replacenorm (Method)
replacenorm(f::Function, d::AbstractDict)
Run f, but replace every call to norm using the mapping in d.
SemanticModels.Dubstep.trace (Method)
trace(f)
Run the function f and return a TracedRun containing the trace and the output.