diff --git a/src/bias_act.jl b/src/bias_act.jl index ef7fb29d..935a5023 100644 --- a/src/bias_act.jl +++ b/src/bias_act.jl @@ -8,7 +8,7 @@ const RCR = RuleConfig{>:HasReverseMode} @inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x))) # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` -# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +# is independent of `x`, as `return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end """ @@ -57,7 +57,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA end # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, NotaNumber})) Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} function bias_act!_fastback(Δ) # Tempting to overwrite x again, but only safe if you call pullback at most once, @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA # # Slower path: can't overwrite x, but can use derivatives_given_output # # This case is WRONG and tests fail, but not sure why - # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + # elseif isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, T})) # Ω2 = fast_act(σ, x).(x) .+ b # @show σ b # function bias_act!_back2(Δ)