1 | | - |
| 1 | +# Internal function, used only for layers defined in this file. |
2 | 2 | _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active |
3 | 3 |
4 | | -_dropout_shape(s, ::Colon) = size(s) |
5 | | -_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) |
6 | | - |
7 | | -_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) |
8 | | - |
9 | | -""" |
10 | | - dropout([rng = rng_from_array(x)], x, p; dims=:, active=true) |
11 | | -
12 | | -The dropout function. If `active` is `true`, |
13 | | -for each input, either sets that input to `0` (with probability |
14 | | -`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions, |
15 | | -e.g. `dims=1` applies dropout along columns and `dims=2` along rows. |
16 | | -If `active` is `false`, it just returns the input `x`. |
17 | | -
18 | | -Specify `rng` for custom RNGs instead of the default RNG. |
19 | | -Note that custom RNGs are only supported on the CPU. |
20 | | -
21 | | -Warning: when using this function, you have to manually manage the activation |
22 | | -state. Usually in fact, dropout is used while training |
23 | | -but is deactivated in the inference phase. This can be |
24 | | -automatically managed using the [`Dropout`](@ref) layer instead of the |
25 | | -`dropout` function. |
26 | | -
27 | | -The [`Dropout`](@ref) layer is what you should use in most scenarios. |
28 | | -""" |
29 | | -function dropout(rng, x, p; dims=:, active::Bool=true) |
30 | | - active || return x |
31 | | - y = dropout_mask(rng, x, p, dims=dims) |
32 | | - return x .* y |
33 | | -end |
34 | | -dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...) |
35 | | - |
36 | | -dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) |
37 | | -dropout_mask(rng, x::CuArray, p; kwargs...) = |
38 | | - throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays.")) |
39 | | -dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) |
40 | | -function _dropout_mask(rng, x, p; dims=:) |
41 | | - realfptype = float(real(eltype(x))) |
42 | | - y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) |
43 | | - y .= _dropout_kernel.(y, p, 1 - p) |
44 | | - return y |
45 | | -end |
46 | | - |
47 | | -# TODO move this to NNlib |
48 | | -ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any) |
49 | | - |
50 | 4 | """ |
51 | | - Dropout(p; dims=:, rng = default_rng_value()) |
| 5 | + Dropout(p; [dims, rng]) |
52 | 6 |
53 | | -Dropout layer. |
| 7 | +Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability. |
| 8 | +This is used as a regularisation, i.e. to reduce overfitting. |
54 | 9 |
55 | | -While training, for each input, this layer either sets that input to `0` (with probability |
56 | | -`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the |
57 | | -`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input |
58 | | -(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during |
59 | | -training. |
| 10 | +While training, it sets each input to `0` (with probability `p`) |
| 11 | +or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function. |
| 12 | +While testing, it has no effect. |
60 | 13 |
61 | | -In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more |
62 | | -details. |
| 14 | +By default the mode will switch automatically, but it can also |
| 15 | +be controlled manually via [`Flux.testmode!`](@ref). |
63 | 16 |
64 | | -Specify `rng` to use a custom RNG instead of the default. |
65 | | -Custom RNGs are only supported on the CPU. |
| 17 | +By default every input is treated independently. With the `dims` keyword, |
| 18 | +instead it takes a random choice only along that dimension. |
| 19 | +For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input |
| 20 | +(also called 2D dropout). |
66 | 21 |
67 | | -Does nothing to the input once [`Flux.testmode!`](@ref) is `true`. |
| 22 | +Keyword `rng` lets you specify a custom random number generator. |
| 23 | +(Only supported on the CPU.) |
68 | 24 |
69 | 25 | # Examples |
70 | | -```jldoctest |
71 | | -julia> m = Chain(Dense(1 => 1), Dropout(1)); |
| 26 | +```julia |
| 27 | +julia> m = Chain(Dense(ones(3,2)), Dropout(0.4)) |
| 28 | +Chain( |
| 29 | + Dense(2 => 3), # 9 parameters |
| 30 | + Dropout(0.4), |
| 31 | +) |
72 | 32 |
73 | | -julia> Flux.trainmode!(m); |
| 33 | +julia> m(ones(2, 7)) # test mode, no effect |
| 34 | +3×7 Matrix{Float64}: |
| 35 | + 2.0 2.0 2.0 2.0 2.0 2.0 2.0 |
| 36 | + 2.0 2.0 2.0 2.0 2.0 2.0 2.0 |
| 37 | + 2.0 2.0 2.0 2.0 2.0 2.0 2.0 |
74 | 38 |
75 | | -julia> y = m([1]); |
| 39 | +julia> Flux.trainmode!(m); # would happen within gradient |
76 | 40 |
77 | | -julia> y == [0] |
78 | | -true |
| 41 | +julia> m(ones(2, 7)) |
| 42 | +3×7 Matrix{Float64}: |
| 43 | + 0.0 0.0 3.33333 0.0 0.0 0.0 0.0 |
| 44 | + 3.33333 0.0 3.33333 0.0 3.33333 0.0 3.33333 |
| 45 | + 3.33333 3.33333 0.0 3.33333 0.0 0.0 3.33333 |
79 | 46 |
80 | | -julia> m = Chain(Dense(1000 => 1000), Dropout(0.5)); |
| 47 | +julia> y = m(ones(2, 10_000)); |
81 | 48 |
82 | | -julia> Flux.trainmode!(m); |
| 49 | +julia> using Statistics |
83 | 50 |
84 | | -julia> y = m(ones(1000)); |
| 51 | +julia> mean(y) # is about 2.0, as for test mode |
| 52 | +1.9892222222222182 |
85 | 53 |
86 | | -julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1) |
87 | | -true |
| 54 | +julia> mean(iszero, y) # is about 0.4 |
| 55 | +0.40323333333333333 |
88 | 56 | ``` |
89 | 57 | """ |
90 | | -mutable struct Dropout{F,D,R<:AbstractRNG} |
| 58 | +mutable struct Dropout{F<:Real,D,R<:AbstractRNG} |
91 | 59 | p::F |
92 | 60 | dims::D |
93 | 61 | active::Union{Bool, Nothing} |
94 | 62 | rng::R |
95 | 63 | end |
96 | | -Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value()) |
| 64 | +Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value()) |
97 | 65 |
98 | | -function Dropout(p; dims=:, rng = default_rng_value()) |
99 | | - @assert 0 ≤ p ≤ 1 |
| 66 | +function Dropout(p::Real; dims=:, rng = default_rng_value()) |
| 67 | + 0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p")) |
100 | 68 | Dropout(p, dims, nothing, rng) |
101 | 69 | end |
102 | 70 |
103 | 71 | @functor Dropout |
104 | 72 | trainable(a::Dropout) = (;) |
105 | 73 |
106 | | -function (a::Dropout)(x) |
107 | | - _isactive(a, x) || return x |
108 | | - return dropout(a.rng, x, a.p; dims=a.dims, active=true) |
109 | | -end |
| 74 | +(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims) |
110 | 75 |
111 | 76 | testmode!(m::Dropout, mode=true) = |
112 | 77 | (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) |
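For reference, a minimal sketch (not part of the diff) of how the reworked layer behaves, assuming a Flux version where `Dropout` delegates to `NNlib.dropout` as shown above. Because the forward pass is `dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)`, the probability is multiplied by a `Bool`: an inactive layer passes `p = 0` and the input comes back unchanged, while an active layer applies ordinary dropout.

```julia
using Flux, Statistics

d = Dropout(0.4)            # dims defaults to :, so every entry is treated independently
x = ones(Float32, 3, 100)

Flux.testmode!(d)           # force inactive: p * false == 0, dropout is a no-op
@assert d(x) == x

Flux.trainmode!(d)          # force active: p * true == p
y = d(x)
mean(iszero, y)             # ≈ 0.4 of the entries are zeroed
mean(y)                     # ≈ 1.0, survivors are scaled by 1 / (1 - p)
```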