diff --git a/Project.toml b/Project.toml index 74a4ad7..6b3292e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.4.4" +version = "0.4.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/utils.jl b/src/utils.jl index f7a4afa..b9c25b3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -237,7 +237,8 @@ end @non_differentiable _partition_idxs(::Any...) # Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77 -function ∇chunk(dys, x, idxs, vd::Val{dim}) where {dim} +function ∇chunk(dys_raw, x, idxs, vd::Val{dim}) where {dim} + dys = unthunk.(unthunk(dys_raw)) # https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-2569227272 i1 = findfirst(dy -> !(dy isa AbstractZero), dys) if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x))))