@@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue)
3737 return x
3838end
3939
40+ accum (x:: NamedTuple , y:: ChainRulesCore.Tangent ) = accum (x, wrap_chainrules_output (y))
41+ accum (x:: ChainRulesCore.Tangent , y:: NamedTuple ) = accum (wrap_chainrules_output (x), y)
42+
43+ accum (x, y:: AbstractThunk ) = @thunk (accum (x, unthunk (y)))
44+ accum (x:: AbstractThunk , y) = @thunk (accum (unthunk (x), y))
45+ accum (x:: AbstractThunk , y:: AbstractThunk ) = @thunk (accum (unthunk (x), unthunk (y)))
46+
4047# Core functions
41- @adjoint deepcopy (x) = deepcopy (x), ȳ -> (ȳ,)
48+ @_adjoint_keepthunks deepcopy (x) = deepcopy (x), ȳ -> (ȳ,)
4249
43- @adjoint (:: Type{V} )(x... ) where V<: Val = V (x... ), _ -> nothing
50+ @_adjoint_keepthunks (:: Type{V} )(x... ) where V<: Val = V (x... ), _ -> nothing
4451
45- @adjoint ifelse (cond:: Bool , t, f) =
52+ @_adjoint_keepthunks ifelse (cond:: Bool , t, f) =
4653 ifelse (cond, t, f),
4754 Δ -> cond ? (nothing , Δ, zero (Δ)) : (nothing , zero (Δ), Δ)
4855
49- @adjoint Base. typeassert (x, T) = Base. typeassert (x, T), Δ -> (Δ, nothing )
56+ @_adjoint_keepthunks Base. typeassert (x, T) = Base. typeassert (x, T), Δ -> (Δ, nothing )
5057
5158accum_param (:: Context{false} , _, Δ) = Δ
5259@generated function accum_param (cx:: Context , x, Δ)
7077
7178unwrap (x) = x
7279
73- @adjoint unwrap (x) = unwrap (x), x̄ -> (accum_param (__context__, x, x̄),)
80+ @_adjoint_keepthunks unwrap (x) = unwrap (x), x̄ -> (accum_param (__context__, x, x̄),)
7481
7582unwrap (ref, x) = x
7683
77- @adjoint unwrap (ref, x) = unwrap (x), function (x̄)
84+ @_adjoint_keepthunks unwrap (ref, x) = unwrap (x), function (x̄)
7885 accum_global (__context__, ref, x̄)
7986 (accum_param (__context__, x, x̄),)
8087end
@@ -88,7 +95,7 @@ function global_set(ref, val)
8895 end
8996end
9097
91- @adjoint ! function global_set (ref, x)
98+ @_adjoint_keepthunks ! function global_set (ref, x)
9299 global_set (ref, x), function (x̄)
93100 gs = cache (__context__)
94101 x̄ = accum (get (gs, ref, nothing ), x̄)
101108
102109using Base: tail
103110
104- @adjoint tuple (xs... ) = xs, identity
111+ @_adjoint_keepthunks tuple (xs... ) = xs, identity
105112
106- @adjoint function literal_getindex (xs:: NTuple{N,Any} , :: Val{i} ) where {N,i}
113+ @_adjoint_keepthunks function literal_getindex (xs:: NTuple{N,Any} , :: Val{i} ) where {N,i}
107114 val = xs[i]
108115 function back (Δ)
109116 accum_param (__context__, val, Δ) === nothing && return
@@ -112,7 +119,7 @@ using Base: tail
112119 val, back
113120end
114121
115- @adjoint function getindex (xs:: NTuple{N,Any} , i:: Integer ) where N
122+ @_adjoint_keepthunks function getindex (xs:: NTuple{N,Any} , i:: Integer ) where N
116123 val = xs[i]
117124 function back (Δ)
118125 accum_param (__context__, val, Δ) === nothing && return
@@ -121,10 +128,10 @@ end
121128 return val, back
122129end
123130
124- @adjoint getindex (xs:: NTuple{N,Any} , r:: AbstractUnitRange ) where N =
131+ @_adjoint_keepthunks getindex (xs:: NTuple{N,Any} , r:: AbstractUnitRange ) where N =
125132 (xs[r], Δ -> (ntuple (j -> j in r ? Δ[findfirst (isequal (j), r)] : nothing , Val (N)), nothing ))
126133
127- @adjoint function getindex (xs:: NTuple{N,Any} , r:: AbstractVector ) where N
134+ @_adjoint_keepthunks function getindex (xs:: NTuple{N,Any} , r:: AbstractVector ) where N
128135 val = xs[r]
129136 function back (Δ)
130137 dxs = ntuple (Val (length (xs))) do x
@@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
155162end
156163
157164# Needed for iteration lowering
158- @adjoint Core. getfield (xs:: NTuple{N,Any} , i:: Int ) where N =
165+ @_adjoint_keepthunks Core. getfield (xs:: NTuple{N,Any} , i:: Int ) where N =
159166 (xs[i], Δ -> (ntuple (j -> i == j ? Δ : nothing , Val (N)), nothing ))
160167
161- @adjoint Core. getfield (xs:: NamedTuple{K,<:NTuple{N,Any}} , i:: Int ) where {K,N} =
168+ @_adjoint_keepthunks Core. getfield (xs:: NamedTuple{K,<:NTuple{N,Any}} , i:: Int ) where {K,N} =
162169 (xs[i], Δ -> (NamedTuple {K} (ntuple (j -> i == j ? Δ : nothing , Val (N))), nothing ))
163170
164- @adjoint function Base. first (xs:: Tuple )
171+ @_adjoint_keepthunks function Base. first (xs:: Tuple )
165172 drest = map (_-> nothing , tail (xs))
166173 first (xs), Δ -> ((Δ, drest... ),)
167174end
168175
169- @adjoint Base. tail (xs:: Tuple ) = tail (xs), x̄s -> ((nothing , x̄s... ),)
176+ @_adjoint_keepthunks Base. tail (xs:: Tuple ) = tail (xs), x̄s -> ((nothing , x̄s... ),)
170177
171178_empty (x) = length (x)
172179_empty (x:: Union{Tuple,NamedTuple} ) = map (_-> nothing , x)
188195
189196unapply (t, xs) = _unapply (t, xs)[1 ]
190197
191- @adjoint ! function Core. _apply (f, args... )
198+ @_adjoint_keepthunks ! function Core. _apply (f, args... )
192199 y, back = Core. _apply (_pullback, (__context__, f), args... )
193200 st = map (_empty, args)
194201 y, function (Δ)
@@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
198205 end
199206end
200207
201- @adjoint ! function Core. _apply_iterate (:: typeof (iterate), f, args... )
208+ @_adjoint_keepthunks ! function Core. _apply_iterate (:: typeof (iterate), f, args... )
202209 y, back = Core. _apply (_pullback, (__context__, f), args... )
203210 st = map (_empty, args)
204211 y, function (Δ)
223230@generated pair (:: Val{k} , v, _= nothing ) where k = :($ k = v,)
224231@generated pair (:: Val{k} , v, :: NamedTuple{keys} ) where {k,keys} = k isa Int ? :($ (getfield (keys, k)) = v,) : :($ k = v,)
225232
226- @adjoint function literal_getfield (x, :: Val{f} ) where f
233+ @_adjoint_keepthunks function literal_getfield (x, :: Val{f} ) where f
227234 val = getfield (x, f)
228235 function back (Δ)
229236 accum_param (__context__, val, Δ) === nothing && return
@@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x)
273280 end
274281end
275282
276-
277- @adjoint! function setfield! (x, f, val)
283+ @_adjoint_keepthunks! function setfield! (x, f, val)
278284 y = setfield! (x, f, val)
279285 g = grad_mut (__context__, x)
280286 y, function (_)
@@ -290,13 +296,13 @@ end
290296
291297Jnew {T} (g) where T = Jnew {T,typeof(g)} (g)
292298
293- @adjoint ! function __new__ (T, args... )
299+ @_adjoint_keepthunks ! function __new__ (T, args... )
294300 x = __new__ (T, args... )
295301 g = ! ismutabletype (T) || fieldcount (T) == 0 ? nothing : grad_mut (__context__, x)
296302 x, Jnew {T,typeof(g),false} (g)
297303end
298304
299- @adjoint ! function __splatnew__ (T, args)
305+ @_adjoint_keepthunks ! function __splatnew__ (T, args)
300306 x = __splatnew__ (T, args)
301307 g = ! ismutabletype (T) || fieldcount (T) == 0 ? nothing : grad_mut (__context__, x)
302308 x, Jnew {T,typeof(g),true} (g)
0 commit comments