FluxTraining.jl gives you an endlessly extensible training loop for deep learning

Overview

FluxTraining.jl

Docs (master)

A powerful, extensible neural net training library.

FluxTraining.jl gives you an endlessly extensible training loop for deep learning inspired by fastai's training loop. It is the training backend for FastAI.jl.

It exposes a small set of extensible interfaces and uses them to implement

  • hyperparameter scheduling
  • metrics
  • logging
  • training history; and
  • model checkpointing

Install using ]add FluxTraining.

Read getting started first and the user guide if you want to know more.

Comments
  • TagBot trigger issue

    TagBot trigger issue

    This issue is used to trigger TagBot; feel free to unsubscribe.

    If you haven't already, you should update your TagBot.yml to include issue comment triggers. Please see this post on Discourse for instructions and more details.

    opened by JuliaTagBot 15
  • Use SnoopPrecompile.jl

    Use SnoopPrecompile.jl

    This adds a basic precompile statement using SnoopPrecompile.jl.

    This reduces the Time-to-first-fit! by

    Measurements:

    • using FluxTraining: 21s (this PR), 19s (master) -> 2s slower
    • fit!(testlearner(), 1): 14.5s (this PR), 30s (master) -> 15s faster
    • both: 35.5s (this PR), 49s (master) -> 13.5s/40% faster

    This seems like a clear win for me, except for the longer precompilation time which will only occur once for regular package usage. Has anyone tried using SnoopPrecompile.jl for other packages in the FluxML org?

    opened by lorenzoh 8
  • `Scheduler` causes cycle in execution DAG?

    `Scheduler` causes cycle in execution DAG?

    I have the following script:

    lossfn = Flux.Losses.logitcrossentropy
    
    # define schedule and optimizer
    initial_lr = 0.1
    schedule = Step(initial_lr, 0.5, 20)
    optim = Flux.Optimiser(Momentum(initial_lr), WeightDecay(1e-3))
    
    # callbacks
    logger = TensorBoardBackend("tblogs")
    schcb = Scheduler(LearningRate => schedule)
    hlogcb = LogHyperParams(logger)
    mlogcb = LogMetrics(logger)
    valcb = Metrics(Metric(accuracy; phase = TrainingPhase, name = "train_acc"),
                    Metric(accuracy; phase = ValidationPhase, name = "val_acc"))
    
    # setup learner object
    learner = Learner(m, lossfn;
                      data = (trainloader, valloader),
                      optimizer = optim,
                      callbacks = [ToGPU(), mlogcb, valcb])
    

    Any time I add schcb to the list of callbacks passed to the Learner, I get an error from FluxTraining that there is a cycle in the DAG. This did not happen in previous versions of FluxTraining (though I haven't been able to bisect the change yet).

    opened by darsnack 4
  • Use Optimisers.jl

    Use Optimisers.jl

    With Flux.jl 0.13 moving to use the explicit optimisers in Optimisers.jl, I think FluxTraining.jl should also use those as a default.

    This would also allow easier integration with alternative ADs like, PyCallChainRules.jl, see https://github.com/rejuvyesh/PyCallChainRules.jl/issues/19.

    @ToucheSir can this be done in a backward-compatible way, i.e. supporting Flux v0.12 and below or does Optimisers.jl depend on Flux v0.13?

    opened by lorenzoh 4
  • Break out Schedule

    Break out Schedule

    Does it make sense to break out Schedule from FluxTraining.jl? It seems like you hit upon a cool Julia package to use for scheduling, and we could use it as the base for implementing several common LR schedules. Would be nice to be able to write something like Cyclic(period = n) etc.

    I can give this a shot in a repo if you think it makes sense.

    opened by darsnack 4
  • Record time trained, training loss, validation loss and performance

    Record time trained, training loss, validation loss and performance

    For my application, I would love to be able to record the time trained, training loss, validation loss and classification performance at a given time-interval in the training loop. But currently, the History seems only able to store number of epochs, steps, and steps in current epoch.

    Would there be a way to make the History for extendable, so that users can record anything they want?

    A final detail would be that I want to record these stats only after a factor increase in training time, so that when I plot e.g. training loss again a logarithmic time scale, I get somewhat evenly distributed numbers. I am not sure how to make that happen, and I do not expect it to be built in functionality. I am just mentioning it in case it would be simple enough to implement.

    opened by KronosTheLate 3
  • Add SanityCheck callback

    Add SanityCheck callback

    Adds a SanityCheck callback as discussed on Zulip.

    If some checks don't pass the output will look like this:

    1/4 sanity checks failed:
    ---
    1: Model and loss function compatible with data (ERROR)
    
    To perform the optimization step, model and loss function need
    to be compatible with the data. This means the following must work:
    
    - `(x, y), _ = iterate(learner.data.training)`
    - `ŷ = learner.model(x)`
    - `loss = learner.lossfunction(learner.model(x), y)
    
    opened by lorenzoh 3
  • How not to have printing callbacks?

    How not to have printing callbacks?

    Is there a way of constructing a Learner without certain callbacks?

    julia> Learner(predict, lossfn; callbacks = [Metrics(accuracy)]).callbacks
    FluxTraining.Callbacks(FluxTraining.SafeCallback[Metrics(Loss(), Metric(Accuracy)), ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder()], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)
    
    julia> Learner(predict, lossfn).callbacks
    FluxTraining.Callbacks(FluxTraining.SafeCallback[ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder(), Metrics(Loss())], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)
    

    Or at least, to remove callbacks after construction? image

    opened by KronosTheLate 2
  • docs oddities

    docs oddities

    At the very top of this doc page https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2Fdocs%2Fcallbacks%2Fusage.md

    using FluxTraining
    using FluxTraining: Callback, Read, Write, stateaccess
    model, data, lossfn = nothing, (), nothing, nothing
    

    Is that intended?

    Also, if from that page I follow a couple of links and land to https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2FREADME.md&id=documents%2Fdocs%2Fcallbacks%2Fcustom.md&id=documents%2Fdocs%2Ftutorials%2Ftraining.md but then i don't see any buttons for closing all those panes or for going back. Even the browser back button doesn't have any effect.

    One last thing is that I don't see a link when visualizing the docstring of types/methods to jump to the source code.

    opened by CarloLucibello 2
  • (Documentation) Document `Learner` components

    (Documentation) Document `Learner` components

    This adds better documentation for the components of a Learner, i.e. model, data iterator, optimizer and loss function.

    Also makes the README clearer and fixes some broken links.

    opened by lorenzoh 2
  • Error displaying EarlyStopper

    Error displaying EarlyStopper

    Show method for EarlyStopping throws an error:

    julia> using FluxTraining
    
    julia> FluxTraining.EarlyStopping(1)
    Error showing value of type EarlyStopping:
    ERROR: type EarlyStopping has no field stopper
    Stacktrace:
      [1] getproperty(x::EarlyStopping, f::Symbol)
        @ Base ./Base.jl:33
      [2] show(io::IOContext{Base.TTY}, cb::EarlyStopping)
        @ FluxTraining ~/.julia/packages/FluxTraining/LfCE3/src/callbacks/earlystopping.jl:56
      [3] show(io::IOContext{Base.TTY}, #unused#::MIME{Symbol("text/plain")}, x::EarlyStopping)
        @ Base.Multimedia ./multimedia.jl:47
      [4] (::REPL.var"#38#39"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:220
      [5] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
      [6] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:213
      [7] display(d::REPL.REPLDisplay, x::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:225
      [8] display(x::Any)
        @ Base.Multimedia ./multimedia.jl:328
      [9] (::Media.var"#15#16"{EarlyStopping})()
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:28
     [10] hookless(f::Media.var"#15#16"{EarlyStopping})
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:14
     [11] render(#unused#::Media.NoDisplay, x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:27
     [12] render(x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/system.jl:160
     [13] display(#unused#::Media.DisplayHook, x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:9
     [14] display(x::Any)
        @ Base.Multimedia ./multimedia.jl:328
     [15] #invokelatest#2
        @ ./essentials.jl:708 [inlined]
     [16] invokelatest
        @ ./essentials.jl:706 [inlined]
     [17] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:247
     [18] (::REPL.var"#40#41"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:231
     [19] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
     [20] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:229
     [21] (::REPL.var"#do_respond#61"{Bool, Bool, REPL.var"#72#82"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:798
     [22] #invokelatest#2
        @ ./essentials.jl:708 [inlined]
     [23] invokelatest
        @ ./essentials.jl:706 [inlined]
     [24] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
        @ REPL.LineEdit /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/LineEdit.jl:2441
     [25] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:1126
     [26] (::REPL.var"#44#49"{REPL.LineEditREPL, REPL.REPLBackendRef})()
        @ REPL ./task.jl:411
    
    

    Thanks for the awesome package! Should I open a PR to fix this?

    opened by awadell1 2
  • Question regarding ProgressPrinter

    Question regarding ProgressPrinter

    Hi first off: wonderful package :)

    I have some issues with the ProgressPrinter not showing up even when using the defaultcallbacks.

    learner = Learner(model, loss; optimizer=opt, callbacks=[ToGPU()], usedefaultcallbacks=true)
    FluxTraining.fit!(learner, epochs, (dl, val_dl)) # where dl, dl_val are both Flux.DataLoader objects
    

    Do I need to do something specific when constructing the Learner which I have missed? From the code it seems like I would need to give it a Progress object, do I have to construct that myself? What requirements does my data-iterator have to fullfill to show up with the defaultcallbacks?

    opened by hv10 5
  • CompatHelper: bump compat for PrettyTables to 2, (keep existing compat)

    CompatHelper: bump compat for PrettyTables to 2, (keep existing compat)

    This pull request changes the compat entry for the PrettyTables package from 1, 1.1, 1.2 to 1, 1.1, 1.2, 2. This keeps the compat entries for earlier versions.

    Note: I have not tested your package with this new compat entry. It is your responsibility to make sure that your package tests pass before you merge this pull request.

    opened by github-actions[bot] 0
  • `accuracy` doesn't work with `DataLoader`

    `accuracy` doesn't work with `DataLoader`

    since there's an example of using accuracy in Metric, it probably:

    1. should be documented
    2. should be made working with either learner or (model, dataloader)
    opened by Moelf 0
  • Improve printing during training

    Improve printing during training

    With the single metric accuracy, the output-tables (which I love) look like this:

    Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
    ┌───────────────┬───────┬─────────┬──────────┐
    │         Phase │ Epoch │    Loss │ Accuracy │
    ├───────────────┼───────┼─────────┼──────────┤
    │ TrainingPhase │  11.0 │ 0.25969 │  0.92827 │
    └───────────────┴───────┴─────────┴──────────┘
    ┌─────────────────┬───────┬─────────┬──────────┐
    │           Phase │ Epoch │    Loss │ Accuracy │
    ├─────────────────┼───────┼─────────┼──────────┤
    │ ValidationPhase │  11.0 │ 0.26323 │  0.92731 │
    └─────────────────┴───────┴─────────┴──────────┘
    

    I suggest putting them into the same table, and making the Epoch vector of element type Int64, to make it look like this:

    Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
    ┌────────--───────┬───────┬─────────┬──────────┐
    │         Phase   │ Epoch │    Loss │ Accuracy │
    ├──────────--─────┼───────┼─────────┼──────────┤
    │ TrainingPhase   │  11   │ 0.25969 │  0.92827 │
    │ ValidationPhase │  11   │ 0.26323 │  0.92731 │
    └─────────────────┴───────┴─────────┴──────────┘
    
    opened by KronosTheLate 1
  • Quickstart tutorial broken

    Quickstart tutorial broken

    The example Training an image classifier currently uses the following code:

    xs, ys = (
        # convert each image into h*w*1 array of floats 
        [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
        # one-hot encode the labels
        [Float32.(Flux.onehot(y, 0:9)) for y in Flux.Data.MNIST.labels()],
    )
    

    However,

    (Project) pkg> st Flux
          Status `C:\Users\Dennis Bal\ProjectFolder\Project.toml`
      [587475ba] Flux v0.13.0
    
    julia> using Flux
    
    julia> Flux.Data.MNIST
    ERROR: UndefVarError: MNIST not defined
    Stacktrace:
     [1] getproperty(x::Module, f::Symbol)
       @ Base .\Base.jl:35
     [2] top-level scope
       @ REPL[16]:1
    
    

    So the example is broken. As a side note, I think the example would do great by using MLUtils instead of DataLoaders.jl and MLDataPattern. Also, Flux imports DataLoader so no need to explicitly import it.

    But I take a look at the docs and try to get started. So I make the following code, that works with Flux's base capacities:

    julia> using Flux
    
    julia> using Flux: onehotbatch, onecold
    
    julia> using FluxTraining
    
    julia> using MLUtils: flatten, unsqueeze
    
    julia> using MLDatasets
    
    julia> labels = 0:9
    0:9
    
    julia> traindata = MNIST.traindata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels));
    
    julia> size.(traindata)
    ((28, 28, 1, 60000), (10, 60000))
    
    julia> trainloader = DataLoader(traindata, batchsize=128);
    
    julia> validdata = MNIST.testdata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels)); 
    
    julia> size.(validdata)
    ((28, 28, 1, 10000), (10, 10000))
    
    julia> validloader = DataLoader(validdata, batchsize=128);
    
    julia> predict = Chain(flatten, Dense(28^2, 10))
    Chain(
      MLUtils.flatten,
      Dense(784 => 10),                     # 7_850 parameters
    )
    
    julia> lossfunc(x, y) = Flux.Losses.logitbinarycrossentropy(predict(x), y)
    lossfunc (generic function with 1 method)
    
    julia> optimizer=ADAM()
    ADAM(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())
    
    julia> callbacks = [Metrics(accuracy)]
    1-element Vector{Metrics}:
     Metrics(Loss(), Metric(Accuracy))
    
    julia> learner = Learner(predict, lossfunc; optimizer, callbacks)
    Learner()
    
    

    At this point, I start checking loss and training with Flux's train!:

    julia> lossfunc(validdata...)
    0.7624986f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.11266354f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.08880948f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.0801171f0
    

    Training no problem. However, when I try to train my learner, it seems like a single float is passed to predict, and not an array:

    julia> fit!(learner, 1, (traindata, validdata))
    Epoch 1 TrainingPhase() ...
    ERROR: MethodError: no method matching flatten(::Float32)
    Closest candidates are:
      flatten(::AbstractArray) at C:\Users\usrname\.julia\packages\MLUtils\QTRw7\src\utils.jl:424  
    Stacktrace:
      [1] macro expansion
        @ C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0 [inlined]     
      [2] _pullback(ctx::Zygote.Context, f::typeof(flatten), args::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:9        
      [3] macro expansion
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
      [4] _pullback
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
      [5] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
      [6] _pullback
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:51 [inlined]
      [7] _pullback(ctx::Zygote.Context, f::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
      [8] _pullback
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:54 [inlined]
      [9] _pullback(ctx::Zygote.Context, f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, args::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
     [10] _pullback
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70 [inlined]
     [11] _pullback(::Zygote.Context, ::FluxTraining.var"#73#74"{FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
     [12] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:352       
     [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:75        
     [14] _gradient(f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, #unused#::ADAM, m::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70      
     [15] (::FluxTraining.var"#69#71"{Learner})(handle::FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, state::FluxTraining.PropDict{Any})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:53      
     [16] runstep(stepfn::FluxTraining.var"#69#71"{Learner}, learner::Learner, phase::TrainingPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Float32, Float32}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:133     
     [17] step!
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:51 [inlined]
     [18] (::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})(#unused#::Function)
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:24      
     [19] runepoch(epochfn::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}}, learner::Learner, phase::TrainingPhase)     
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:105     
     [20] epoch!
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:22 [inlined]
     [21] fit!(learner::Learner, nepochs::Int64, ::Tuple{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:168     
     [22] top-level scope
        @ REPL[51]:1
    

    I am completely stuck as to what goes wrong. Pointers in that regard would be appreciated, but the main issue is making the example functional, and updating the packages used to load data and the utility functions that I take from MLUtils.

    To improve the reliability of this package, could doc testing be used to ensure that in the future, the documentation examples actually run?

    opened by KronosTheLate 8
  • Add metadata field to `Learner`

    Add metadata field to `Learner`

    This adds a "metadata" PropDict to Learner for storing information that is required for training but extraneous to the training state or callback state. This is useful for unconventional training methods (issue that I am currently dealing with). In the same way that the loss function is a "parameter" that needs to be specified to standard supervised training, the metadata field holds parameters that need to be specified for unconventional training. Of course, we can't know what these parameters will be like standard training, so instead of explicit names, we provide a container to hold them.

    opened by darsnack 7
Releases(v0.3.5)
  • v0.3.5(Dec 17, 2022)

  • v0.3.4(Oct 22, 2022)

  • v0.3.3(Sep 14, 2022)

    FluxTraining v0.3.3

    Diff since v0.3.2

    Closed issues:

    • docs oddities (#113)
    • Record time trained, training loss, validation loss and performance (#119)
    • How not to have printing callbacks? (#126)
    • ignore(f) is deprecated (#128)

    Merged pull requests:

    • How not to print docstring (#127) (@KronosTheLate)
    • Fix deprecate warnings (#129) (@yuehhua)
    • Fix the quickstart tutorial (#130) (@christiangnrd)
    • fix typo in docs (#133) (@Moelf)
    • Fix README links (#134) (@lorenzoh)
    • Add ability to record and log arbitrary learner values (#136) (@darsnack)
    • Use show method for phase in MetricsPrinter (#137) (@darsnack)
    Source code(tar.gz)
    Source code(zip)
  • v0.3.2(Jun 10, 2022)

  • v0.3.1(May 27, 2022)

    FluxTraining v0.3.1

    Diff since v0.3.0

    Closed issues:

    • Simpler Learner API (#104)
    • Use Optimisers.jl (#112)
    • Add callback support for Optimisers.jl (#115)

    Merged pull requests:

    • Update Pollen documentation to new PkgTemplate workflow (#111) (@lorenzoh)
    • Add support for Optimisers.jl (#114) (@lorenzoh)
    • CompatHelper: add new compat entry for Optimisers at version 0.2, (keep existing compat) (#116) (@github-actions[bot])
    • Add callback support for Optimisers.jl (#117) (@lorenzoh)
    • CompatHelper: add new compat entry for Setfield at version 0.8, (keep existing compat) (#118) (@github-actions[bot])
    Source code(tar.gz)
    Source code(zip)
  • v0.3.0(Apr 16, 2022)

    FluxTraining v0.3.0

    Diff since v0.2.4

    Closed issues:

    • Switch to ParameterSchedulers.jl (#106)

    Merged pull requests:

    • Add more convenient Learner method (#105) (@lorenzoh)
    • Switch to ParameterSchedulers.jl (#107) (@rejuvyesh)
    • Prepare v0.3.0 (#108) (@lorenzoh)
    Source code(tar.gz)
    Source code(zip)
  • v0.2.4(Mar 9, 2022)

    FluxTraining v0.2.4

    Diff since v0.2.3

    Closed issues:

    • Allow restricting phases during which a Metric runs (#84)

    Merged pull requests:

    • CompatHelper: bump compat for EarlyStopping to 0.2, (keep existing compat) (#92) (@github-actions[bot])
    • CompatHelper: bump compat for EarlyStopping to 0.3, (keep existing compat) (#94) (@github-actions[bot])
    • doc: fix link to TensorBoardLogger.jl (#95) (@visr)
    • Move documentation system to Pollen.jl (#96) (@lorenzoh)
    • Use ReTest.jl to run tests (#97) (@lorenzoh)
    • Add and improve a lot of docstrings and a few test cases (#99) (@lorenzoh)
    • Add phase argument to Metric (#100) (@lorenzoh)
    • CompatHelper: add new compat entry for InlineTest at version 0.2, (keep existing compat) (#101) (@github-actions[bot])
    • (Documentation) Document Learner components (#102) (@lorenzoh)
    • Add Flux 0.13 compatibility (#103) (@lorenzoh)
    Source code(tar.gz)
    Source code(zip)
  • v0.2.3(Oct 21, 2021)

  • v0.2.2(Oct 12, 2021)

  • v0.2.1(Jul 26, 2021)

  • v0.2.0(Jul 3, 2021)

    FluxTraining v0.2.0

    Diff since v0.1.3

    Added

    Changed

    • Batch* renamed to Step*:
      • events: BatchBegin now StepBegin, BatchEnd now StepEnd
      • CancelBatchException now CancelStepException.
      • field Learner.batch now Learner.step
    • Learner.step/batch is no longer a special struct but now a PropDict, allowing you to set arbitrary fields.
    • Learner.model can now be a NamedTuple/Tuple of models for use in custom training loops. Likewise, learner.params now resembles the structure of learner.model, allowing separate access to parameters of different models.
    • Callbacks
      • Added init! method for callback initilization, replacing the Init event which required a Phase to implement.
      • Scheduler now has internal step counter and no longer relies on Recorder's history. This makes it easier to replace the scheduler without needing to offset the new schedules.
      • EarlyStopping callback now uses criteria from EarlyStopping.jl

    Removed

    • Removed old training API. Methods fitbatch!, fitbatchphase!, fitepoch!, fitepochphase! have all been removed.

    Closed issues:

    • Scheduler applies schedules per batch by default (#68)
    • Recorder does not work with models with non-Array inputs. (#80)

    Merged pull requests:

    • CompatHelper: bump compat for "BSON" to "0.3" (#69) (@github-actions[bot])
    • use EarlyStopping.jl for stopping criteria (#72) (@lorenzoh)
    • CompatHelper: bump compat for "PrettyTables" to "0.12" (#73) (@github-actions[bot])
    • Move documentation to Pollen.jl (#77) (@lorenzoh)
    • Revert onecycle (#78) (@lorenzoh)
    • Remove samplesfield from History (#81) (@lorenzoh)
    • New training API and QoL improvements (v0.2.0) (#83) (@lorenzoh)
    Source code(tar.gz)
    Source code(zip)
  • v0.1.3(Feb 21, 2021)

    FluxTraining v0.1.3

    Diff since v0.1.2

    Closed issues:

    • Metrics wraps Metric(s) in Metric(s) (#66)
    • Unnecessary softmax in accuracy? (#67)

    Merged pull requests:

    • Internal callback changes (#65) (@lorenzoh)
    • fix #66 make Metric subtype AbstractMetric (#70) (@lorenzoh)
    • Lorenzoh/fix/67 (#71) (@lorenzoh)
    Source code(tar.gz)
    Source code(zip)
  • v0.1.2(Jan 28, 2021)

  • v0.1.1(Jan 18, 2021)

    FluxTraining v0.1.1

    Diff since v0.1.0

    Closed issues:

    • Link to Loss broken (#40)
    • Early stopping link broken (#41)
    • BatchEnd link broken (#42)
    • Break out Schedule (#44)
    • How to verify GPU is working? (#47)
    • What is the role of the test data? (#48)
    • What do you think of Data Modules? (#53)

    Merged pull requests:

    • Add testing as CI step (#36) (@lorenzoh)
    • Better docstrings (#37) (@lorenzoh)
    • Typo (#43) (@drozzy)
    • WRONG order of arguments to the Learner. (#45) (@drozzy)
    • Fix issue #42 - missing docstrings (#49) (@lorenzoh)
    • add EarlyStopping docstring (#50) (@lorenzoh)
    • Update callback reference section on metrics (#51) (@lorenzoh)
    • Remove no longer needed dependencies from Project.toml (#52) (@lorenzoh)
    • Add SanityCheck callback (#56) (@lorenzoh)
    • Fix GarbageCollect callback (#57) (@lorenzoh)
    • Fix TensorBoard image serialization (#58) (@lorenzoh)
    • Sanitycheck (#60) (@lorenzoh)
    • update compat bounds (#61) (@lorenzoh)
    Source code(tar.gz)
    Source code(zip)
  • v0.1.0(Nov 7, 2020)

    FluxTraining v0.1.0

    Closed issues:

    • Collaborating on a FastAI port? (#13)

    Merged pull requests:

    • Remove unused package LearnBase (#26) (@nrhodes)
    • Fix protected tests (#28) (@nrhodes)
    • Log model histograms in TensorBoard (#30) (@ToucheSir)
    Source code(tar.gz)
    Source code(zip)
Owner
CompSci student from Berlin, currently working on FastAI.jl
Extreme Rotation Estimation using Dense Correlation Volumes

Extreme Rotation Estimation using Dense Correlation Volumes This repository contains a PyTorch implementation of the paper: Extreme Rotation Estimatio

Ruojin Cai 29 Nov 18, 2022
AI-Fitness-Tracker - AI Fitness Tracker With Python

AI-Fitness-Tracker We have build a AI based Fitness Tracker using OpenCV and Pyt

Sharvari Mangale 5 Feb 09, 2022
Implementation of Bottleneck Transformer in Pytorch

Bottleneck Transformer - Pytorch Implementation of Bottleneck Transformer, SotA visual recognition model with convolution + attention that outperforms

Phil Wang 621 Jan 06, 2023
An official reimplementation of the method described in the INTERSPEECH 2021 paper - Speech Resynthesis from Discrete Disentangled Self-Supervised Representations.

Speech Resynthesis from Discrete Disentangled Self-Supervised Representations Implementation of the method described in the Speech Resynthesis from Di

Facebook Research 253 Jan 06, 2023
Normal Learning in Videos with Attention Prototype Network

Codes_APN Official codes of CVPR21 paper: Normal Learning in Videos with Attention Prototype Network (https://arxiv.org/abs/2108.11055) Overview of ou

11 Dec 13, 2022
(CVPR 2022) Energy-based Latent Aligner for Incremental Learning

Energy-based Latent Aligner for Incremental Learning Accepted to CVPR 2022 We illustrate an Incremental Learning model trained on a continuum of tasks

Joseph K J 37 Jan 03, 2023
Integrated physics-based and ligand-based modeling.

ComBind ComBind integrates data-driven modeling and physics-based docking for improved binding pose prediction and binding affinity prediction. Given

Dror Lab 44 Oct 26, 2022
This repository is based on Ultralytics/yolov5, with adjustments to enable rotate prediction boxes.

Rotate-Yolov5 This repository is based on Ultralytics/yolov5, with adjustments to enable rotate prediction boxes. Section I. Description The codes are

xinzelee 90 Dec 13, 2022
[CVPR 2022] CoTTA Code for our CVPR 2022 paper Continual Test-Time Domain Adaptation

CoTTA Code for our CVPR 2022 paper Continual Test-Time Domain Adaptation Prerequisite Please create and activate the following conda envrionment. To r

Qin Wang 87 Jan 08, 2023
(CVPR2021) Kaleido-BERT: Vision-Language Pre-training on Fashion Domain

Kaleido-BERT: Vision-Language Pre-training on Fashion Domain Mingchen Zhuge*, Dehong Gao*, Deng-Ping Fan#, Linbo Jin, Ben Chen, Haoming Zhou, Minghui

250 Jan 08, 2023
FwordCTF 2021 Infrastructure and Source code of Web/Bash challenges

FwordCTF 2021 You can find here the source code of the challenges I wrote (Web and Bash) in FwordCTF 2021 and the source code of the platform with our

Kahla 5 Nov 25, 2022
Apollo optimizer in tensorflow

Apollo Optimizer in Tensorflow 2.x Notes: Warmup is important with Apollo optimizer, so be sure to pass in a learning rate schedule vs. a constant lea

Evan Walters 1 Nov 09, 2021
Adversarial Attacks on Probabilistic Autoregressive Forecasting Models.

Attack-Probabilistic-Models This is the source code for Adversarial Attacks on Probabilistic Autoregressive Forecasting Models. This repository contai

SRI Lab, ETH Zurich 25 Sep 14, 2022
FairFuzz: AFL extension targeting rare branches

FairFuzz An AFL extension to increase code coverage by targeting rare branches. FairFuzz has a particular advantage on programs with highly nested str

Caroline Lemieux 222 Nov 16, 2022
With this package, you can generate mixed-integer linear programming (MIP) models of trained artificial neural networks (ANNs) using the rectified linear unit (ReLU) activation function

With this package, you can generate mixed-integer linear programming (MIP) models of trained artificial neural networks (ANNs) using the rectified linear unit (ReLU) activation function. At the momen

ChemEngAI 40 Dec 27, 2022
FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI

FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI 声明: 本项目仅限于学习交流,不可用于非法用途,包括但不限于:用于游戏外挂等,使用本项目产生的任何后果与本人无关! 简介 本项目基于yolov5,实现了一款FPS类游戏(CF、CSGO等)的自瞄AI,本项目旨在使用现

Fabian 246 Dec 28, 2022
Thermal Control of Laser Powder Bed Fusion using Deep Reinforcement Learning

This repository is the implementation of the paper "Thermal Control of Laser Powder Bed Fusion Using Deep Reinforcement Learning", linked here. The project makes use of the Deep Reinforcement Library

BaratiLab 11 Dec 27, 2022
A scanpy extension to analyse single-cell TCR and BCR data.

Scirpy: A Scanpy extension for analyzing single-cell immune-cell receptor sequencing data Scirpy is a scalable python-toolkit to analyse T cell recept

ICBI 145 Jan 03, 2023
Blender Add-On for slicing meshes with planes

MeshSlicer Blender Add-On for slicing meshes with multiple overlapping planes at once. This is a simple Blender addon to slice a silmple mesh with mul

52 Dec 12, 2022
Global-Local Attention for Emotion Recognition

Global-Local Attention for Emotion Recognition Requirements Python 3 Install tensorflow (or tensorflow-gpu) = 2.0.0 Install some other packages pip i

Minh Nhat Le 15 Apr 21, 2022