FluxTraining.jl gives you an endlessly extensible training loop for deep learning

FluxTraining.jl

A powerful, extensible neural net training library.

FluxTraining.jl gives you an endlessly extensible training loop for deep learning inspired by fastai's training loop. It is the training backend for FastAI.jl.

It exposes a small set of extensible interfaces and uses them to implement

  • hyperparameter scheduling
  • metrics
  • logging
  • training history; and
  • model checkpointing

Install using ]add FluxTraining.

Read getting started first and the user guide if you want to know more.

Comments
  • TagBot trigger issue

    This issue is used to trigger TagBot; feel free to unsubscribe.

    If you haven't already, you should update your TagBot.yml to include issue comment triggers. Please see this post on Discourse for instructions and more details.

    opened by JuliaTagBot 15
  • Use SnoopPrecompile.jl

    This adds a basic precompile statement using SnoopPrecompile.jl.

    This reduces the time to first fit!, as the measurements below show.

    Measurements:

    • using FluxTraining: 21s (this PR), 19s (master) -> 2s slower
    • fit!(testlearner(), 1): 14.5s (this PR), 30s (master) -> 15s faster
    • both: 35.5s (this PR), 49s (master) -> 13.5s/40% faster

    This seems like a clear win to me. The only cost is the longer precompilation time, which occurs only once in regular package usage. Has anyone tried using SnoopPrecompile.jl for other packages in the FluxML org?

    opened by lorenzoh 8
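
The totals in the measurements above can be checked with a few lines of arithmetic. The sketch below uses only the numbers quoted in the pull request. It also suggests that the quoted 40% is closest to the gain measured against the new total (about 38%), while the share of master's time saved is about 28%.

```julia
# Timings in seconds, copied from the measurements above.
load_pr, load_master = 21.0, 19.0        # `using FluxTraining`
fit_pr, fit_master = 14.5, 30.0          # `fit!(testlearner(), 1)`

total_pr = load_pr + fit_pr              # both steps, this PR
total_master = load_master + fit_master  # both steps, master
saved = total_master - total_pr          # seconds saved overall

# Two ways to express the gain as a percentage.
reduction = 100 * saved / total_master   # share of master's time that is saved
speedup = 100 * saved / total_pr         # gain relative to the new total

println("saved ", saved, "s; reduction ", round(reduction; digits = 1),
        "%; speedup ", round(speedup; digits = 1), "%")
```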
  • `Scheduler` causes cycle in execution DAG?

    I have the following script:

    lossfn = Flux.Losses.logitcrossentropy
    
    # define schedule and optimizer
    initial_lr = 0.1
    schedule = Step(initial_lr, 0.5, 20)
    optim = Flux.Optimiser(Momentum(initial_lr), WeightDecay(1e-3))
    
    # callbacks
    logger = TensorBoardBackend("tblogs")
    schcb = Scheduler(LearningRate => schedule)
    hlogcb = LogHyperParams(logger)
    mlogcb = LogMetrics(logger)
    valcb = Metrics(Metric(accuracy; phase = TrainingPhase, name = "train_acc"),
                    Metric(accuracy; phase = ValidationPhase, name = "val_acc"))
    
    # setup learner object
    learner = Learner(m, lossfn;
                      data = (trainloader, valloader),
                      optimizer = optim,
                      callbacks = [ToGPU(), mlogcb, valcb])
    

    Any time I add schcb to the list of callbacks passed to the Learner, I get an error from FluxTraining that there is a cycle in the DAG. This did not happen in previous versions of FluxTraining (though I haven't been able to bisect the change yet).

    opened by darsnack 4
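
For reference, the schedule in the script above is a step decay. The sketch below assumes that `Step(initial_lr, 0.5, 20)` means "start at 0.1 and halve the rate every 20 steps". `step_decay` is a hypothetical stand-in written in plain Julia, not the scheduling package's implementation, and it says nothing about the cause of the cycle error.

```julia
# Hypothetical stand-in for a step schedule: multiply `start` by `decay`
# once every `interval` steps (steps are 1-based).
step_decay(start, decay, interval, t) = start * decay^((t - 1) ÷ interval)

initial_lr = 0.1
for t in (1, 20, 21, 40, 41)
    println("step ", t, " => lr ", step_decay(initial_lr, 0.5, 20, t))
end
```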
  • Use Optimisers.jl

    With Flux.jl 0.13 moving to use the explicit optimisers in Optimisers.jl, I think FluxTraining.jl should also use those as a default.

    This would also allow easier integration with alternative ADs like PyCallChainRules.jl; see https://github.com/rejuvyesh/PyCallChainRules.jl/issues/19.

    @ToucheSir can this be done in a backward-compatible way, i.e. supporting Flux v0.12 and below, or does Optimisers.jl depend on Flux v0.13?

    opened by lorenzoh 4
  • Break out Schedule

    Does it make sense to break out Schedule from FluxTraining.jl? It seems like you hit upon a cool Julia package to use for scheduling, and we could use it as the base for implementing several common LR schedules. Would be nice to be able to write something like Cyclic(period = n) etc.

    I can give this a shot in a repo if you think it makes sense.

    opened by darsnack 4
  • Record time trained, training loss, validation loss and performance

    For my application, I would love to be able to record the time trained, training loss, validation loss and classification performance at a given time-interval in the training loop. But currently, the History seems only able to store number of epochs, steps, and steps in current epoch.

    Would there be a way to make the History more extendable, so that users can record anything they want?

    A final detail would be that I want to record these stats only after a factor increase in training time, so that when I plot e.g. training loss against a logarithmic time scale, I get somewhat evenly distributed numbers. I am not sure how to make that happen, and I do not expect it to be built-in functionality. I am just mentioning it in case it would be simple enough to implement.

    opened by KronosTheLate 3
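
One way to get roughly evenly spaced points on a logarithmic time axis is to record only when the elapsed training time has grown by a fixed factor since the last record. The sketch below is plain Julia and independent of FluxTraining. `GeometricRecorder`, `maybe_record!`, and the simulated timings and losses are all hypothetical, chosen only for illustration.

```julia
# Records a value only when elapsed time has reached the next threshold;
# after each record the threshold grows by `factor`.
mutable struct GeometricRecorder
    next_time::Float64                        # next elapsed time (s) at which to record
    factor::Float64                           # growth factor between records
    records::Vector{Tuple{Float64,Float64}}   # (elapsed time, value) pairs
end

GeometricRecorder(first_time, factor) =
    GeometricRecorder(first_time, factor, Tuple{Float64,Float64}[])

function maybe_record!(r::GeometricRecorder, elapsed, value)
    elapsed < r.next_time && return false
    push!(r.records, (elapsed, value))
    # Advance past any thresholds that have already been reached.
    while r.next_time <= elapsed
        r.next_time *= r.factor
    end
    return true
end

# Simulated training: one step per second with a made-up, decaying loss.
rec = GeometricRecorder(1.0, 2.0)
for step in 1:100
    elapsed = float(step)
    loss = 1 / step
    maybe_record!(rec, elapsed, loss)
end
println(first.(rec.records))   # elapsed times at which a record was taken
```

In a real training loop the call to `maybe_record!` would sit wherever per-step values are available, with `elapsed` taken from a wall-clock timer instead of the step counter.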
  • Add SanityCheck callback

    Adds a SanityCheck callback as discussed on Zulip.

    If some checks don't pass the output will look like this:

    1/4 sanity checks failed:
    ---
    1: Model and loss function compatible with data (ERROR)
    
    To perform the optimization step, model and loss function need
    to be compatible with the data. This means the following must work:
    
    - `(x, y), _ = iterate(learner.data.training)`
    - `ŷ = learner.model(x)`
    - `loss = learner.lossfunction(learner.model(x), y)`
    
    opened by lorenzoh 3
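
The first check can be reproduced by hand. The sketch below runs the three listed steps against a toy linear model and a squared-error loss in plain Julia. `model`, `lossfn`, and `training` are hypothetical stand-ins for `learner.model`, `learner.lossfunction`, and `learner.data.training`.

```julia
# Toy stand-ins for the learner's components.
W = [1.0 2.0; 3.0 4.0]
model(x) = W * x                          # plays the role of learner.model
lossfn(ŷ, y) = sum(abs2, ŷ .- y)          # plays the role of learner.lossfunction
training = [([1.0, 0.0], [1.0, 3.0]),     # iterator of (x, y) batches
            ([0.0, 1.0], [2.0, 4.0])]

# The three steps from the sanity check message.
(x, y), _ = iterate(training)
ŷ = model(x)
loss = lossfn(ŷ, y)
println("loss on first batch: ", loss)
```

If any of the three steps throws, the model, loss function, and data are not compatible, which is the condition this check reports.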
  • How not to have printing callbacks?

    Is there a way of constructing a Learner without certain callbacks?

    julia> Learner(predict, lossfn; callbacks = [Metrics(accuracy)]).callbacks
    FluxTraining.Callbacks(FluxTraining.SafeCallback[Metrics(Loss(), Metric(Accuracy)), ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder()], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)
    
    julia> Learner(predict, lossfn).callbacks
    FluxTraining.Callbacks(FluxTraining.SafeCallback[ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder(), Metrics(Loss())], FluxTraining.LinearRunner(), {5, 5} directed simple Int64 graph, false)
    

    Or at least, to remove callbacks after construction?

    opened by KronosTheLate 2
  • docs oddities

    At the very top of this doc page https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2Fdocs%2Fcallbacks%2Fusage.md

    using FluxTraining
    using FluxTraining: Callback, Read, Write, stateaccess
    model, data, lossfn = nothing, (), nothing, nothing
    

    Is that intended?

    Also, if I follow a couple of links from that page and land on https://fluxml.ai/FluxTraining.jl/dev/i/?id=documents%2FREADME.md&id=documents%2Fdocs%2Fcallbacks%2Fcustom.md&id=documents%2Fdocs%2Ftutorials%2Ftraining.md, I don't see any buttons for closing all those panes or for going back. Even the browser back button has no effect.

    One last thing is that I don't see a link when visualizing the docstring of types/methods to jump to the source code.

    opened by CarloLucibello 2
  • (Documentation) Document `Learner` components

    This adds better documentation for the components of a Learner, i.e. model, data iterator, optimizer and loss function.

    Also makes the README clearer and fixes some broken links.

    opened by lorenzoh 2
  • Error displaying EarlyStopper

    Show method for EarlyStopping throws an error:

    julia> using FluxTraining
    
    julia> FluxTraining.EarlyStopping(1)
    Error showing value of type EarlyStopping:
    ERROR: type EarlyStopping has no field stopper
    Stacktrace:
      [1] getproperty(x::EarlyStopping, f::Symbol)
        @ Base ./Base.jl:33
      [2] show(io::IOContext{Base.TTY}, cb::EarlyStopping)
        @ FluxTraining ~/.julia/packages/FluxTraining/LfCE3/src/callbacks/earlystopping.jl:56
      [3] show(io::IOContext{Base.TTY}, #unused#::MIME{Symbol("text/plain")}, x::EarlyStopping)
        @ Base.Multimedia ./multimedia.jl:47
      [4] (::REPL.var"#38#39"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:220
      [5] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
      [6] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:213
      [7] display(d::REPL.REPLDisplay, x::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:225
      [8] display(x::Any)
        @ Base.Multimedia ./multimedia.jl:328
      [9] (::Media.var"#15#16"{EarlyStopping})()
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:28
     [10] hookless(f::Media.var"#15#16"{EarlyStopping})
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:14
     [11] render(#unused#::Media.NoDisplay, x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:27
     [12] render(x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/system.jl:160
     [13] display(#unused#::Media.DisplayHook, x::EarlyStopping)
        @ Media ~/.julia/packages/Media/ItEPc/src/compat.jl:9
     [14] display(x::Any)
        @ Base.Multimedia ./multimedia.jl:328
     [15] #invokelatest#2
        @ ./essentials.jl:708 [inlined]
     [16] invokelatest
        @ ./essentials.jl:706 [inlined]
     [17] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:247
     [18] (::REPL.var"#40#41"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:231
     [19] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:462
     [20] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:229
     [21] (::REPL.var"#do_respond#61"{Bool, Bool, REPL.var"#72#82"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:798
     [22] #invokelatest#2
        @ ./essentials.jl:708 [inlined]
     [23] invokelatest
        @ ./essentials.jl:706 [inlined]
     [24] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
        @ REPL.LineEdit /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/LineEdit.jl:2441
     [25] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
        @ REPL /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/REPL/src/REPL.jl:1126
     [26] (::REPL.var"#44#49"{REPL.LineEditREPL, REPL.REPLBackendRef})()
        @ REPL ./task.jl:411
    
    

    Thanks for the awesome package! Should I open a PR to fix this?

    opened by awadell1 2
  • Question regarding ProgressPrinter

    Hi first off: wonderful package :)

    I have some issues with the ProgressPrinter not showing up even when using the defaultcallbacks.

    learner = Learner(model, loss; optimizer=opt, callbacks=[ToGPU()], usedefaultcallbacks=true)
    FluxTraining.fit!(learner, epochs, (dl, val_dl)) # where dl and val_dl are both Flux.DataLoader objects
    

    Do I need to do something specific when constructing the Learner which I have missed? From the code it seems like I would need to give it a Progress object; do I have to construct that myself? What requirements does my data iterator have to fulfill for the progress bar to show up with the default callbacks?

    opened by hv10 5
  • CompatHelper: bump compat for PrettyTables to 2, (keep existing compat)

    This pull request changes the compat entry for the PrettyTables package from 1, 1.1, 1.2 to 1, 1.1, 1.2, 2. This keeps the compat entries for earlier versions.

    Note: I have not tested your package with this new compat entry. It is your responsibility to make sure that your package tests pass before you merge this pull request.

    opened by github-actions[bot] 0
  • `accuracy` doesn't work with `DataLoader`

    since there's an example of using accuracy in Metric, it probably:

    1. should be documented
    2. should be made working with either learner or (model, dataloader)
    opened by Moelf 0
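
A rough sketch of what the second option could look like: accuracy over a whole data loader is the number of correct predictions summed across batches, divided by the number of samples. The code is plain Julia. `classes`, `dataset_accuracy`, and the toy `model` are hypothetical and not part of FluxTraining's API.

```julia
# Class index per column of a (classes × batch) matrix of scores or one-hot labels.
classes(m::AbstractMatrix) = [argmax(col) for col in eachcol(m)]

# Hypothetical helper: accuracy of `model` over an iterator of (x, y) batches.
function dataset_accuracy(model, batches)
    correct, total = 0, 0
    for (x, y) in batches
        correct += count(classes(model(x)) .== classes(y))
        total += size(y, 2)
    end
    return correct / total
end

# Toy check: an identity "model" on one-hot inputs, with one wrong label.
model(x) = x
x1 = [1.0 0.0; 0.0 1.0]; y1 = [1.0 0.0; 0.0 1.0]   # both columns correct
x2 = [1.0 0.0; 0.0 1.0]; y2 = [0.0 0.0; 1.0 1.0]   # first column wrong
println(dataset_accuracy(model, [(x1, y1), (x2, y2)]))
```

Summing counts rather than averaging per-batch accuracies keeps the result correct when the last batch is smaller than the others.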
  • Improve printing during training

    With the single metric accuracy, the output-tables (which I love) look like this:

    Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
    ┌───────────────┬───────┬─────────┬──────────┐
    │         Phase │ Epoch │    Loss │ Accuracy │
    ├───────────────┼───────┼─────────┼──────────┤
    │ TrainingPhase │  11.0 │ 0.25969 │  0.92827 │
    └───────────────┴───────┴─────────┴──────────┘
    ┌─────────────────┬───────┬─────────┬──────────┐
    │           Phase │ Epoch │    Loss │ Accuracy │
    ├─────────────────┼───────┼─────────┼──────────┤
    │ ValidationPhase │  11.0 │ 0.26323 │  0.92731 │
    └─────────────────┴───────┴─────────┴──────────┘
    

    I suggest putting them into the same table, and making the Epoch vector of element type Int64, to make it look like this:

    Epoch 11 TrainingPhase(): 100%|███████████████████████████████████████████| Time: 0:00:00
    ┌─────────────────┬───────┬─────────┬──────────┐
    │         Phase   │ Epoch │    Loss │ Accuracy │
    ├─────────────────┼───────┼─────────┼──────────┤
    │ TrainingPhase   │  11   │ 0.25969 │  0.92827 │
    │ ValidationPhase │  11   │ 0.26323 │  0.92731 │
    └─────────────────┴───────┴─────────┴──────────┘
    
    opened by KronosTheLate 1
  • Quickstart tutorial broken

    The example Training an image classifier currently uses the following code:

    xs, ys = (
        # convert each image into h*w*1 array of floats 
        [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
        # one-hot encode the labels
        [Float32.(Flux.onehot(y, 0:9)) for y in Flux.Data.MNIST.labels()],
    )
    

    However,

    (Project) pkg> st Flux
          Status `C:\Users\Dennis Bal\ProjectFolder\Project.toml`
      [587475ba] Flux v0.13.0
    
    julia> using Flux
    
    julia> Flux.Data.MNIST
    ERROR: UndefVarError: MNIST not defined
    Stacktrace:
     [1] getproperty(x::Module, f::Symbol)
       @ Base .\Base.jl:35
     [2] top-level scope
       @ REPL[16]:1
    
    

    So the example is broken. As a side note, I think the example would do great by using MLUtils instead of DataLoaders.jl and MLDataPattern. Also, Flux imports DataLoader so no need to explicitly import it.

    Still, I took a look at the docs and tried to get started. I wrote the following code, which works with Flux's basic capabilities:

    julia> using Flux
    
    julia> using Flux: onehotbatch, onecold
    
    julia> using FluxTraining
    
    julia> using MLUtils: flatten, unsqueeze
    
    julia> using MLDatasets
    
    julia> labels = 0:9
    0:9
    
    julia> traindata = MNIST.traindata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels));
    
    julia> size.(traindata)
    ((28, 28, 1, 60000), (10, 60000))
    
    julia> trainloader = DataLoader(traindata, batchsize=128);
    
    julia> validdata = MNIST.testdata(Float32) |> x->(unsqueeze(x[1], 3), onehotbatch(x[2], labels)); 
    
    julia> size.(validdata)
    ((28, 28, 1, 10000), (10, 10000))
    
    julia> validloader = DataLoader(validdata, batchsize=128);
    
    julia> predict = Chain(flatten, Dense(28^2, 10))
    Chain(
      MLUtils.flatten,
      Dense(784 => 10),                     # 7_850 parameters
    )
    
    julia> lossfunc(x, y) = Flux.Losses.logitbinarycrossentropy(predict(x), y)
    lossfunc (generic function with 1 method)
    
    julia> optimizer=ADAM()
    ADAM(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())
    
    julia> callbacks = [Metrics(accuracy)]
    1-element Vector{Metrics}:
     Metrics(Loss(), Metric(Accuracy))
    
    julia> learner = Learner(predict, lossfunc; optimizer, callbacks)
    Learner()
    
    

    At this point, I start checking loss and training with Flux's train!:

    julia> lossfunc(validdata...)
    0.7624986f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.11266354f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.08880948f0
    
    julia> Flux.train!(lossfunc, Flux.params(predict), trainloader, optimizer)
    
    julia> lossfunc(validdata...)
    0.0801171f0
    

    Training works without problems. However, when I try to train my learner, it seems like a single float is passed to predict, and not an array:

    julia> fit!(learner, 1, (traindata, validdata))
    Epoch 1 TrainingPhase() ...
    ERROR: MethodError: no method matching flatten(::Float32)
    Closest candidates are:
      flatten(::AbstractArray) at C:\Users\usrname\.julia\packages\MLUtils\QTRw7\src\utils.jl:424  
    Stacktrace:
      [1] macro expansion
        @ C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0 [inlined]     
      [2] _pullback(ctx::Zygote.Context, f::typeof(flatten), args::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:9        
      [3] macro expansion
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
      [4] _pullback
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:53 [inlined]
      [5] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
      [6] _pullback
        @ C:\Users\usrname\.julia\packages\Flux\18YZE\src\layers\basic.jl:51 [inlined]
      [7] _pullback(ctx::Zygote.Context, f::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Float32)
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
      [8] _pullback
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:54 [inlined]
      [9] _pullback(ctx::Zygote.Context, f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, args::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
     [10] _pullback
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70 [inlined]
     [11] _pullback(::Zygote.Context, ::FluxTraining.var"#73#74"{FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface2.jl:0        
     [12] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:352       
     [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ Zygote C:\Users\usrname\.julia\packages\Zygote\Y6SC4\src\compiler\interface.jl:75        
     [14] _gradient(f::FluxTraining.var"#70#72"{FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, FluxTraining.PropDict{Any}, Learner}, #unused#::ADAM, m::Chain{Tuple{typeof(flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:70      
     [15] (::FluxTraining.var"#69#71"{Learner})(handle::FluxTraining.var"#handlefn#78"{Learner, TrainingPhase}, state::FluxTraining.PropDict{Any})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:53      
     [16] runstep(stepfn::FluxTraining.var"#69#71"{Learner}, learner::Learner, phase::TrainingPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Float32, Float32}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:133     
     [17] step!
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:51 [inlined]
     [18] (::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})(#unused#::Function)
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:24      
     [19] runepoch(epochfn::FluxTraining.var"#67#68"{Learner, TrainingPhase, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}}, learner::Learner, phase::TrainingPhase)     
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:105     
     [20] epoch!
        @ C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:22 [inlined]
     [21] fit!(learner::Learner, nepochs::Int64, ::Tuple{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}})
        @ FluxTraining C:\Users\usrname\.julia\packages\FluxTraining\iBFSd\src\training.jl:168     
     [22] top-level scope
        @ REPL[51]:1
    

    I am completely stuck as to what goes wrong. Pointers in that regard would be appreciated, but the main issue is making the example functional, and updating the packages used to load data and the utility functions that I take from MLUtils.

    To improve the reliability of this package, could doc testing be used to ensure that in the future, the documentation examples actually run?

    opened by KronosTheLate 8
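
A likely cause, judging from the stack trace: `fit!` received the raw `(traindata, validdata)` tuples rather than the `trainloader` and `validloader` built earlier, and `initialstate` shows `xs` and `ys` arriving as `Float32` scalars. Iterating a plain `(features, labels)` tuple yields the feature array as the first "batch", and destructuring that array into `(xs, ys)` yields two scalars. The sketch below reproduces this in plain Julia with random stand-in data. Passing the two data loaders to `fit!` would presumably avoid it, though that was not verified here.

```julia
# Plain-Julia sketch (no Flux needed): a raw (features, labels) tuple
# is not an iterator of batches.
X = rand(Float32, 28, 28, 1, 8)     # stand-in for images
Y = rand(Float32, 10, 8)            # stand-in for one-hot labels
traindata = (X, Y)

# Iterating the tuple yields X itself as the first "batch"...
firstbatch = first(traindata)
# ...and destructuring that array yields its first two scalar elements.
xs, ys = firstbatch
println(typeof(xs), " ", typeof(ys))

# An iterator of batches yields (xs, ys) tuples of arrays instead.
batches = [(X[:, :, :, i:i+3], Y[:, i:i+3]) for i in 1:4:8]
bx, by = first(batches)
println(size(bx), " ", size(by))
```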
  • Add metadata field to `Learner`

    This adds a "metadata" PropDict to Learner for storing information that is required for training but extraneous to the training state or callback state. This is useful for unconventional training methods (an issue that I am currently dealing with). In the same way that the loss function is a "parameter" that needs to be specified for standard supervised training, the metadata field holds parameters that need to be specified for unconventional training. Unlike with standard training, we can't know in advance what these parameters will be, so instead of explicit names, we provide a container to hold them.

    opened by darsnack 7
Releases(v0.3.5)
  • v0.3.5(Dec 17, 2022)

  • v0.3.4(Oct 22, 2022)

  • v0.3.3(Sep 14, 2022)

    FluxTraining v0.3.3

    Diff since v0.3.2

    Closed issues:

    • docs oddities (#113)
    • Record time trained, training loss, validation loss and performance (#119)
    • How not to have printing callbacks? (#126)
    • ignore(f) is deprecated (#128)

    Merged pull requests:

    • How not to print docstring (#127) (@KronosTheLate)
    • Fix deprecate warnings (#129) (@yuehhua)
    • Fix the quickstart tutorial (#130) (@christiangnrd)
    • fix typo in docs (#133) (@Moelf)
    • Fix README links (#134) (@lorenzoh)
    • Add ability to record and log arbitrary learner values (#136) (@darsnack)
    • Use show method for phase in MetricsPrinter (#137) (@darsnack)
  • v0.3.2(Jun 10, 2022)

  • v0.3.1(May 27, 2022)

    FluxTraining v0.3.1

    Diff since v0.3.0

    Closed issues:

    • Simpler Learner API (#104)
    • Use Optimisers.jl (#112)
    • Add callback support for Optimisers.jl (#115)

    Merged pull requests:

    • Update Pollen documentation to new PkgTemplate workflow (#111) (@lorenzoh)
    • Add support for Optimisers.jl (#114) (@lorenzoh)
    • CompatHelper: add new compat entry for Optimisers at version 0.2, (keep existing compat) (#116) (@github-actions[bot])
    • Add callback support for Optimisers.jl (#117) (@lorenzoh)
    • CompatHelper: add new compat entry for Setfield at version 0.8, (keep existing compat) (#118) (@github-actions[bot])
  • v0.3.0(Apr 16, 2022)

    FluxTraining v0.3.0

    Diff since v0.2.4

    Closed issues:

    • Switch to ParameterSchedulers.jl (#106)

    Merged pull requests:

    • Add more convenient Learner method (#105) (@lorenzoh)
    • Switch to ParameterSchedulers.jl (#107) (@rejuvyesh)
    • Prepare v0.3.0 (#108) (@lorenzoh)
  • v0.2.4(Mar 9, 2022)

    FluxTraining v0.2.4

    Diff since v0.2.3

    Closed issues:

    • Allow restricting phases during which a Metric runs (#84)

    Merged pull requests:

    • CompatHelper: bump compat for EarlyStopping to 0.2, (keep existing compat) (#92) (@github-actions[bot])
    • CompatHelper: bump compat for EarlyStopping to 0.3, (keep existing compat) (#94) (@github-actions[bot])
    • doc: fix link to TensorBoardLogger.jl (#95) (@visr)
    • Move documentation system to Pollen.jl (#96) (@lorenzoh)
    • Use ReTest.jl to run tests (#97) (@lorenzoh)
    • Add and improve a lot of docstrings and a few test cases (#99) (@lorenzoh)
    • Add phase argument to Metric (#100) (@lorenzoh)
    • CompatHelper: add new compat entry for InlineTest at version 0.2, (keep existing compat) (#101) (@github-actions[bot])
    • (Documentation) Document Learner components (#102) (@lorenzoh)
    • Add Flux 0.13 compatibility (#103) (@lorenzoh)
  • v0.2.3(Oct 21, 2021)

  • v0.2.2(Oct 12, 2021)

  • v0.2.1(Jul 26, 2021)

  • v0.2.0(Jul 3, 2021)

    FluxTraining v0.2.0

    Diff since v0.1.3

    Added

    Changed

    • Batch* renamed to Step*:
      • events: BatchBegin now StepBegin, BatchEnd now StepEnd
      • CancelBatchException now CancelStepException.
      • field Learner.batch now Learner.step
    • Learner.step/batch is no longer a special struct but now a PropDict, allowing you to set arbitrary fields.
    • Learner.model can now be a NamedTuple/Tuple of models for use in custom training loops. Likewise, learner.params now resembles the structure of learner.model, allowing separate access to parameters of different models.
    • Callbacks
      • Added init! method for callback initialization, replacing the Init event which required a Phase to implement.
      • Scheduler now has internal step counter and no longer relies on Recorder's history. This makes it easier to replace the scheduler without needing to offset the new schedules.
      • EarlyStopping callback now uses criteria from EarlyStopping.jl

    Removed

    • Removed old training API. Methods fitbatch!, fitbatchphase!, fitepoch!, fitepochphase! have all been removed.

    Closed issues:

    • Scheduler applies schedules per batch by default (#68)
    • Recorder does not work with models with non-Array inputs. (#80)

    Merged pull requests:

    • CompatHelper: bump compat for "BSON" to "0.3" (#69) (@github-actions[bot])
    • use EarlyStopping.jl for stopping criteria (#72) (@lorenzoh)
    • CompatHelper: bump compat for "PrettyTables" to "0.12" (#73) (@github-actions[bot])
    • Move documentation to Pollen.jl (#77) (@lorenzoh)
    • Revert onecycle (#78) (@lorenzoh)
    • Remove samplesfield from History (#81) (@lorenzoh)
    • New training API and QoL improvements (v0.2.0) (#83) (@lorenzoh)
  • v0.1.3(Feb 21, 2021)

    FluxTraining v0.1.3

    Diff since v0.1.2

    Closed issues:

    • Metrics wraps Metric(s) in Metric(s) (#66)
    • Unnecessary softmax in accuracy? (#67)

    Merged pull requests:

    • Internal callback changes (#65) (@lorenzoh)
    • fix #66 make Metric subtype AbstractMetric (#70) (@lorenzoh)
    • Lorenzoh/fix/67 (#71) (@lorenzoh)
  • v0.1.2(Jan 28, 2021)

  • v0.1.1(Jan 18, 2021)

    FluxTraining v0.1.1

    Diff since v0.1.0

    Closed issues:

    • Link to Loss broken (#40)
    • Early stopping link broken (#41)
    • BatchEnd link broken (#42)
    • Break out Schedule (#44)
    • How to verify GPU is working? (#47)
    • What is the role of the test data? (#48)
    • What do you think of Data Modules? (#53)

    Merged pull requests:

    • Add testing as CI step (#36) (@lorenzoh)
    • Better docstrings (#37) (@lorenzoh)
    • Typo (#43) (@drozzy)
    • WRONG order of arguments to the Learner. (#45) (@drozzy)
    • Fix issue #42 - missing docstrings (#49) (@lorenzoh)
    • add EarlyStopping docstring (#50) (@lorenzoh)
    • Update callback reference section on metrics (#51) (@lorenzoh)
    • Remove no longer needed dependencies from Project.toml (#52) (@lorenzoh)
    • Add SanityCheck callback (#56) (@lorenzoh)
    • Fix GarbageCollect callback (#57) (@lorenzoh)
    • Fix TensorBoard image serialization (#58) (@lorenzoh)
    • Sanitycheck (#60) (@lorenzoh)
    • update compat bounds (#61) (@lorenzoh)
  • v0.1.0(Nov 7, 2020)

    FluxTraining v0.1.0

    Closed issues:

    • Collaborating on a FastAI port? (#13)

    Merged pull requests:

    • Remove unused package LearnBase (#26) (@nrhodes)
    • Fix protected tests (#28) (@nrhodes)
    • Log model histograms in TensorBoard (#30) (@ToucheSir)
Owner
CompSci student from Berlin, currently working on FastAI.jl