Skip to content

Commit d1aa910

Browse files
oschulzToucheSirpxl-th
authored
Utilize ChainRulesCore thunks (#966)
* Don't force unthunking of ChainRulesCore thunks Introduces @_adjoint_keepthunks to mark adjoints that should pass thunks through. * Use @_adjoint_keepthunks where appropriate * Use wrap_chainrules_output in unthunk_tangent * Fix unthunk_tangent for array of thunks * Don't unthunk explicitly in unbroadcast * Define unthunk_tangent for IdDict to support Params * Make unthunk_tangent for IdDict non-differentiable Co-authored-by: Brian Chen <[email protected]> * Revert "Don't unthunk explicitly in unbroadcast" This reverts commit 34865ea. * Fix problems related to unthunk_tangent for IdDict Co-authored-by: Brian Chen <[email protected]> * Resolve duplicate rrule for unthunk_tangent with IdDict * Make unthunk_tangent recurse into arrays * Fix tests * Unthunk in collect(::Generator) * Update deps * Disable thunks for 2nd order AD * Temporarily use fork of CRC * Remove hook * Fix * Cleanup * Up deps * Cleanup * Remove extra unthunk_tangent --------- Co-authored-by: Brian Chen <[email protected]> Co-authored-by: Anton Smirnov <[email protected]>
1 parent bc6cd09 commit d1aa910

File tree

10 files changed

+82
-42
lines changed

10 files changed

+82
-42
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ ZygoteTrackerExt = "Tracker"
3838

3939
[compat]
4040
AbstractFFTs = "1.3.1"
41-
ChainRules = "1.44.1"
42-
ChainRulesCore = "1.9"
41+
ChainRules = "1.72.2"
42+
ChainRulesCore = "1.25.1"
4343
ChainRulesTestUtils = "1"
4444
Colors = "0.12, 0.13"
4545
DiffRules = "1.4"

src/Zygote.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ module Zygote
33
using LinearAlgebra, Statistics
44
using LinearAlgebra: copytri!, AbstractTriangular
55

6+
import ZygoteRules
67
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
78
literal_getproperty, literal_getfield, unthunk_tangent
89

910
using ChainRulesCore
10-
using ChainRules: ChainRules, rrule, unthunk, canonicalize
11+
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
1112
using IRTools
1213
using MacroTools, Requires
1314
using MacroTools: @forward

src/compiler/chainrules.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
2+
# ZygoteRules here?
3+
function unthunk_tangent end
4+
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
5+
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
6+
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
7+
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
8+
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
9+
@non_differentiable unthunk_tangent(::IdDict)
10+
11+
112
struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
213
context::CTX
314
end
@@ -107,7 +118,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
107118
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
108119
"""
109120
@inline wrap_chainrules_output(x) = x
110-
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
111121
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
112122
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
113123
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@@ -261,7 +271,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
261271
_pullback(config.context, f_args...)
262272
end
263273

264-
ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
274+
ad_pullback(Δ) = zygote2differential(
275+
pb(wrap_chainrules_output(unthunk_tangent(Δ))),
276+
f_args)
265277
return y, ad_pullback
266278
end
267279

src/compiler/interface.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ end
3737
_pullback(f, args...) = _pullback(Context(), f, args...)
3838

3939
tailmemaybe(::Nothing) = nothing
40-
tailmemaybe(x::Tuple) = Base.tail(x)
40+
tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x))
41+
42+
# unthunking is essentially an identity operation on a lazy value, but
43+
# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make
44+
# nested AD work, so define
45+
@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),)
46+
4147

4248
"""
4349
pullback(f, args...)
@@ -351,6 +357,9 @@ function copy!(x::AbstractVector, ps::Params)
351357
x
352358
end
353359

360+
_maybe_unthunk(x::AbstractThunk) = unthunk(x)
361+
_maybe_unthunk(x) = x
362+
354363
"""
355364
Grads(...)
356365
@@ -385,7 +394,7 @@ end
385394

386395
function Base.getindex(gs::Grads, x)
387396
isbits(x) && error("Only reference types can be differentiated with `Params`.")
388-
return gs.grads[x]
397+
return _maybe_unthunk(gs.grads[x])
389398
end
390399

391400
"""
@@ -468,7 +477,7 @@ function pullback(f, ps::Params)
468477
cache(cx)[p] = nothing
469478
end
470479
back(Δ)
471-
Grads(cx.cache, ps) # TODO make a copy
480+
Grads(_maybe_unthunk(cx.cache), ps)
472481
end
473482
end
474483

src/compiler/reverse.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
33
insertafter!, finish, expand!, prune!, substitute!, substitute,
44
block, block!, branch!, return!, stmt, meta
55

6+
7+
# TODO: Temporary, to be removed when ChainRulesCore rrules are required to
8+
# support thunks as an input and all instances of _adjoint_keepthunks in
9+
Zygote have been replaced by rrules:
10+
macro _adjoint_keepthunks(ex)
11+
ZygoteRules.gradm(ex, false, true)
12+
end
13+
macro _adjoint_keepthunks!(ex)
14+
ZygoteRules.gradm(ex, true, true)
15+
end
16+
17+
618
@inline tuple_va(N, xs) = xs
719
@inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...)
820
@inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N))

src/lib/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
237237
x̄ = reconstruct_if_dict(x̄, _keys) # return a dictionary if needed
238238
(nothing, (f = f̄, iter = x̄),)
239239
end
240-
y, collect_pullback
240+
y, collect_pullback ∘ unthunk_tangent
241241
end
242242

243243
collect_if_dict(x::Dict) = collect(x), collect(keys(x))

src/lib/broadcast.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
5353
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
5454
end
5555

56-
function unbroadcast(x::AbstractArray, x̄)
56+
function unbroadcast(x::AbstractArray, maybethunked_x̄)
57+
x̄ = unthunk_tangent(maybethunked_x̄)
5758
N = ndims(x̄)
5859
if length(x) == length(x̄)
5960
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors

src/lib/lib.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue)
3737
return x
3838
end
3939

40+
accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
41+
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)
42+
43+
accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
44+
accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
45+
accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))
46+
4047
# Core functions
41-
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
48+
@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
4249

43-
@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
50+
@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
4451

45-
@adjoint ifelse(cond::Bool, t, f) =
52+
@_adjoint_keepthunks ifelse(cond::Bool, t, f) =
4653
ifelse(cond, t, f),
4754
Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ)
4855

49-
@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
56+
@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
5057

5158
accum_param(::Context{false}, _, Δ) = Δ
5259
@generated function accum_param(cx::Context, x, Δ)
@@ -70,11 +77,11 @@ end
7077

7178
unwrap(x) = x
7279

73-
@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
80+
@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
7481

7582
unwrap(ref, x) = x
7683

77-
@adjoint unwrap(ref, x) = unwrap(x), function (x̄)
84+
@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄)
7885
accum_global(__context__, ref, x̄)
7986
(accum_param(__context__, x, x̄),)
8087
end
@@ -88,7 +95,7 @@ function global_set(ref, val)
8895
end
8996
end
9097

91-
@adjoint! function global_set(ref, x)
98+
@_adjoint_keepthunks! function global_set(ref, x)
9299
global_set(ref, x), function (x̄)
93100
gs = cache(__context__)
94101
x̄ = accum(get(gs, ref, nothing), x̄)
@@ -101,9 +108,9 @@ end
101108

102109
using Base: tail
103110

104-
@adjoint tuple(xs...) = xs, identity
111+
@_adjoint_keepthunks tuple(xs...) = xs, identity
105112

106-
@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
113+
@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
107114
val = xs[i]
108115
function back(Δ)
109116
accum_param(__context__, val, Δ) === nothing && return
@@ -112,7 +119,7 @@ using Base: tail
112119
val, back
113120
end
114121

115-
@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N
122+
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N
116123
val = xs[i]
117124
function back(Δ)
118125
accum_param(__context__, val, Δ) === nothing && return
@@ -121,10 +128,10 @@ end
121128
return val, back
122129
end
123130

124-
@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
131+
@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
125132
(xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing))
126133

127-
@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
134+
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
128135
val = xs[r]
129136
function back(Δ)
130137
dxs = ntuple(Val(length(xs))) do x
@@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
155162
end
156163

157164
# Needed for iteration lowering
158-
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
165+
@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
159166
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))
160167

161-
@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
168+
@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
162169
(xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing))
163170

164-
@adjoint function Base.first(xs::Tuple)
171+
@_adjoint_keepthunks function Base.first(xs::Tuple)
165172
drest = map(_->nothing, tail(xs))
166173
first(xs), Δ -> ((Δ, drest...),)
167174
end
168175

169-
@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
176+
@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
170177

171178
_empty(x) = length(x)
172179
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
@@ -188,7 +195,7 @@ end
188195

189196
unapply(t, xs) = _unapply(t, xs)[1]
190197

191-
@adjoint! function Core._apply(f, args...)
198+
@_adjoint_keepthunks! function Core._apply(f, args...)
192199
y, back = Core._apply(_pullback, (__context__, f), args...)
193200
st = map(_empty, args)
194201
y, function (Δ)
@@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
198205
end
199206
end
200207

201-
@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...)
208+
@_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...)
202209
y, back = Core._apply(_pullback, (__context__, f), args...)
203210
st = map(_empty, args)
204211
y, function (Δ)
@@ -223,7 +230,7 @@ end
223230
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
224231
@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,)
225232

226-
@adjoint function literal_getfield(x, ::Val{f}) where f
233+
@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f
227234
val = getfield(x, f)
228235
function back(Δ)
229236
accum_param(__context__, val, Δ) === nothing && return
@@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x)
273280
end
274281
end
275282

276-
277-
@adjoint! function setfield!(x, f, val)
283+
@_adjoint_keepthunks! function setfield!(x, f, val)
278284
y = setfield!(x, f, val)
279285
g = grad_mut(__context__, x)
280286
y, function (_)
@@ -290,13 +296,13 @@ end
290296

291297
Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)
292298

293-
@adjoint! function __new__(T, args...)
299+
@_adjoint_keepthunks! function __new__(T, args...)
294300
x = __new__(T, args...)
295301
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
296302
x, Jnew{T,typeof(g),false}(g)
297303
end
298304

299-
@adjoint! function __splatnew__(T, args)
305+
@_adjoint_keepthunks! function __splatnew__(T, args)
300306
x = __splatnew__(T, args)
301307
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
302308
x, Jnew{T,typeof(g),true}(g)

test/gradcheck.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ function ngradient(f, xs::AbstractArray...)
2020
return grads
2121
end
2222

23-
function gradcheck(f, xs...)
23+
function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5)
2424
grad_zygote = gradient(f, xs...)
2525
grad_finite_difference = ngradient(f, xs...)
26-
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5))
26+
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
2727
end
2828

29-
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
30-
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
29+
gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kwargs...)
30+
gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...)
3131

3232
# utilities for using gradcheck with complex matrices
3333
_splitreim(A) = (real(A),)
@@ -160,8 +160,8 @@ end
160160
@test gradient(y, x, z) == ([1, 1, 2], nothing)
161161

162162
# https://github.com/FluxML/Zygote.jl/issues/376
163-
_, back = Zygote._pullback(x->x[1]*im, randn(2))
164-
@test back(1.0)[2] == real([-im, 0]) == [0, 0]
163+
_, back = Zygote.pullback(x -> x[1] * im, randn(2))
164+
@test back(1.0)[1] == real([-im, 0]) == [0, 0]
165165

166166
# _droplike
167167
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
@@ -949,8 +949,8 @@ end
949949
_hermsymtype(::Type{<:Symmetric}) = Symmetric
950950
_hermsymtype(::Type{<:Hermitian}) = Hermitian
951951

952-
function _gradtest_hermsym(f, ST, A)
953-
gradtest(_splitreim(collect(A))...) do (args...)
952+
function _gradtest_hermsym(f, ST, A; kwargs...)
953+
gradtest(_splitreim(collect(A))...; kwargs...) do (args...)
954954
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
955955
return sum(_splitreim(B))
956956
end

test/interface.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,5 +269,4 @@ end
269269
@test sgs[d.b] ≈ fill(1.f0, size(d.b))
270270
end
271271

272-
273272
end

0 commit comments

Comments
 (0)