Efficient and Modular Implicit Differentiation (Machine Learning Research Paper Explained)

  • Published 31 May 2024
  • #implicitfunction #jax #autodiff
    Many problems in machine learning involve loops of inner and outer optimization. Finding update steps for the outer loop is usually difficult because of the need to differentiate through the inner loop's procedure over multiple steps. Such loop unrolling is very limited and constrained to very few steps. Other papers have found ways around unrolling, but only for specific, individual problems. This paper proposes a unified framework for implicit differentiation of inner optimization procedures without unrolling and provides implementations that integrate seamlessly into JAX.
    OUTLINE:
    0:00 - Intro & Overview
    2:05 - Automatic Differentiation of Inner Optimizations
    4:30 - Example: Meta-Learning
    7:45 - Unrolling Optimization
    13:00 - Unified Framework Overview & Pseudocode
    21:10 - Implicit Function Theorem
    25:45 - More Technicalities
    28:45 - Experiments
    ERRATA:
    - Dataset Distillation is done with respect to the training set, not the validation or test set.
    Paper: arxiv.org/abs/2105.15183
    Code coming soon
    Abstract:
    Automatic differentiation (autodiff) has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted widespread attention with applications such as optimization as a layer, and in bi-level problems such as hyper-parameter optimization and meta-learning. However, the formulas for these derivatives often involve case-by-case tedious mathematical derivations. In this paper, we propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines (in Python in the case of our implementation) a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of F and implicit differentiation to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient as it can be added on top of any state-of-the-art solver and modular as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow to recover many recently proposed implicit differentiation methods and create new ones easily. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.
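    The core identity behind this approach, stated here only as a sketch in the abstract's notation (x* is the inner solution, θ the outer variables): differentiating the optimality condition F(x*(θ), θ) = 0 via the implicit function theorem gives
    \[
    \partial_1 F(x^\star(\theta), \theta)\,\partial x^\star(\theta) + \partial_2 F(x^\star(\theta), \theta) = 0
    \quad\Longrightarrow\quad
    \partial x^\star(\theta) = -\big[\partial_1 F(x^\star(\theta), \theta)\big]^{-1}\,\partial_2 F(x^\star(\theta), \theta),
    \]
    so only F needs to be differentiated, never the solver itself; in practice the inverse is not formed and a (possibly matrix-free) linear system is solved instead.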
    Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert
    Links:
    TabNine Code Completion (Referral): bit.ly/tabnine-yannick
    UA-cam: / yannickilcher
    Twitter: / ykilcher
    Discord: / discord
    BitChute: www.bitchute.com/channel/yann...
    Minds: www.minds.com/ykilcher
    Parler: parler.com/profile/YannicKilcher
    LinkedIn: / yannic-kilcher-488534136
    BiliBili: space.bilibili.com/1824646584
    If you want to support me, the best thing to do is to share out the content :)
    If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
    SubscribeStar: www.subscribestar.com/yannick...
    Patreon: / yannickilcher
    Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
    Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
    Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
    Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n
  • Science & Technology

COMMENTS • 49

  • @ChaiTimeDataScience
    @ChaiTimeDataScience 3 years ago +43

    I've been loving the new speed at which Dr. Kilcher is putting all of these videos out! So much to learn!

  • @paxdriver
    @paxdriver 3 years ago +15

    Dr lightspeed, you are a hero

  • @theoboyer3812
    @theoboyer3812 3 years ago +17

    Now I want to learn jax

  • @joedalton77
    @joedalton77 3 years ago +4

    I have the feeling this one is going to be big

  • @MrMIB983
    @MrMIB983 3 years ago +3

    Nice to give credit to an underrated area

  • @theoboyer3812
    @theoboyer3812 2 years ago +1

    This is really cool, but if I understood it correctly, the applications are still quite limited. The problem is that you need to compute a d x d Jacobian matrix (and then solve a linear system involving this matrix), d being the dimension of the output of your inner optimization algorithm and the input of the optimality condition function.
    So, for any application involving neural networks, for example, unless I'm wrong, your little d would be the number of parameters of your neural network. Before even talking about solving the linear system, you would need to store a matrix of size "the number of parameters of the neural network SQUARED".
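    (The d x d Jacobian need not be stored explicitly, though: the linear system can be solved matrix-free from Jacobian-vector and vector-Jacobian products, e.g. with conjugate gradients. A minimal sketch below, assuming a user-written optimality-condition function F(w, theta); the names are illustrative and this is not the paper's released code.)

    import jax
    import jax.numpy as jnp

    def implicit_vjp(F, w_star, theta, outer_grad):
        # Pull an outer-loss gradient (w.r.t. w*) back to theta through
        # F(w*(theta), theta) = 0, without materializing dF/dw.
        # Matrix-free matvec v -> (dF/dw) v via a JVP at (w_star, theta).
        matvec = lambda v: jax.jvp(lambda w: F(w, theta), (w_star,), (v,))[1]
        # Solve (dF/dw)^T u = outer_grad; when F is the gradient of an inner
        # loss, dF/dw is a symmetric Hessian, so the same matvec can be reused.
        u, _ = jax.scipy.sparse.linalg.cg(matvec, outer_grad)
        # Gradient w.r.t. theta is -(dF/dtheta)^T u, again matrix-free via a VJP.
        _, vjp_theta = jax.vjp(lambda t: F(w_star, t), theta)
        return jax.tree_util.tree_map(lambda g: -g, vjp_theta(u)[0])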

  • @kimchi_taco
    @kimchi_taco 3 years ago +1

    Thank you for introducing an interesting paper!
    18:00 The equation of ridge_solver is the same as the optimality condition F=0. I'm a little bit confused...
    23:45 Instead of the optimization procedure having to be differentiable, only the optimality condition now needs to be differentiable. Best summary ever.

  • @victorrielly4588
    @victorrielly4588 3 years ago +5

    Recurrent neural networks also struggled with the problem of computing gradients by unrolling the recurrence. I wonder if this technique could be applied in that instance as well.

    • @chuby35
      @chuby35 3 years ago +2

      hmm. Interesting idea, but what would be the optimality condition for the inner loop of an RNN layer?

    • @victorrielly4588
      @victorrielly4588 3 years ago +1

      @@chuby35 I’d have to think about it

  • @JTMoustache
    @JTMoustache 3 years ago

    Powerful stuff 💪🏼

  • @tchlux
    @tchlux 3 years ago

    In your expression at 8:22 I think you've got the transposes reversed. Notice that in the code, the terms all have `_tr` at the end. It should be $(X X^t + \lambda I) w = X y^t$. The way it's written in the video, it looks like you're solving for what weights should be applied to every *point,* but instead we want weights applied to each *component* (i.e., a linear function).

  • @herp_derpingson
    @herp_derpingson 3 years ago +4

    20:00 What does the custom_root decorator do? What does it return?
    .
    20:15 I don't understand, are we using the ridge solver as an approximation for the inner SGD? Why?
    .
    I don't understand what is going on in this paper, but if it works, it's gonna be big.

    • @mgostIH
      @mgostIH 2 years ago +2

      In general, decorators in Python are a way to apply properties to a function by redefining it internally; for example, you could have a decorator that adds a timer to the function you apply it to so you can benchmark things easily.
      In this case, the custom_root decorator takes in the differentiable optimality condition this paper talks about, i.e. a function F that is differentiable and is zero when we have found the right solution. When applied to a function like ridge_solver, it redefines its gradient with respect to lambda using this paper's method rather than JAX's built-in autodiff.
      The ridge solver is just an example and doesn't have anything to do with SGD; this method, however, provides you with the gradient needed to find the optimal value of lambda.
      Essentially, the ridge regression function solves the problem of finding the w minimizing this loss: ||w^T * X - y|| + λ ||w||, but we'd also like to find the value of λ that minimizes this loss even further! This entire paper solves the problem **WITHOUT** us having to look inside the details of the ridge regression solver; internally we could have used whatever crazy method, as long as it satisfies that minimization task, and by specifying a differentiable F we get the gradient with respect to λ for free, which allows us to minimize the loss even further.
      A way to imagine this, as Yannic mentions, is to think of λ as a hyperparameter we have to search over, but instead of just doing a black-box search we actually get gradients efficiently. Imagine if instead of minimizing ||w^T * X - y|| + λ ||w||, you were to minimize ||NN(x) - y|| + λ ||w||, where NN is a neural network (again, regression but with a regularization term dependent on lambda). Then the solver would be all the gradient steps we take to train our model for a fixed λ, **which might internally be so complicated that backpropagating through it wrt. λ is simply impossible**. But with this method you only care that you have a routine, which internally may use SGD, that optimizes the weights of the network you defined, and you still get the gradient with respect to the hyperparameters by defining the optimization as a root of another differentiable function F!
      Even nicer, this method allows us to differentiate optimization problems directly by just stating the optimality conditions, which before required entire papers (like OptNet) to derive.
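      (To make this concrete, here is a self-contained sketch of the ridge example doing "by hand" what the custom_root decorator is described as doing: the hypergradient w.r.t. lam comes from the implicit function theorem applied to the optimality condition F = 0, not from unrolling or differentiating the solver. The function names and the validation loss are illustrative, not the paper's code.)

      import jax
      import jax.numpy as jnp

      # Inner problem: ridge regression, min_w ||X w - y||^2 + lam * ||w||^2
      def ridge_solver(lam, X, y):
          d = X.shape[1]
          return jnp.linalg.solve(X.T @ X + lam * jnp.eye(d), X.T @ y)

      # Optimality condition: F(w, lam) = grad_w objective = 0 at the solution.
      def F(w, lam, X, y):
          return 2.0 * X.T @ (X @ w - y) + 2.0 * lam * w

      # Hypergradient of an outer (validation) loss w.r.t. lam:
      # dw*/dlam = -(dF/dw)^{-1} dF/dlam, then chain rule with the outer gradient.
      def hypergradient(lam, X_tr, y_tr, X_val, y_val):
          w_star = ridge_solver(lam, X_tr, y_tr)
          outer_grad = jax.grad(lambda w: jnp.sum((X_val @ w - y_val) ** 2))(w_star)
          A = jax.jacobian(F, argnums=0)(w_star, lam, X_tr, y_tr)  # dF/dw
          B = jax.jacobian(F, argnums=1)(w_star, lam, X_tr, y_tr)  # dF/dlam
          return outer_grad @ (-jnp.linalg.solve(A, B))

      (The outer loop can then run plain gradient descent on lam; ridge_solver itself is never differentiated through, so it could just as well be an iterative solver.)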

  • @piotr780
    @piotr780 2 years ago

    Is it now implemented in the JAX library?

  • @Phenix66
    @Phenix66 3 years ago +12

    Dude, I wanna go to sleep... Damn it :D

  • @nx6803
    @nx6803 3 years ago +1

    GradSlam uses unrolling, perhaps it can utilize this!

  • @paulcurry8383
    @paulcurry8383 3 years ago +1

    I'm a bit confused by the toy example: couldn't you differentiate ||w^T X - y|| + theta ||w|| as your loss function, treating theta as another weight?
    Does this not work because the norm is nonlinear or something?

    • @YannicKilcher
      @YannicKilcher 3 years ago +1

      The two losses are with respect to different datasets. The outer optimization is over the validation set and is conditional on having solved the inner problem to completion.

    • @paulcurry8383
      @paulcurry8383 3 years ago

      @@YannicKilcher Ah, so now you'd want 3 datasets so you can validate your hyperparameter training
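      (In symbols, the bi-level problem being discussed here is, roughly,
      \[
      \min_{\theta}\; L_{\text{val}}\big(w^\star(\theta)\big)
      \quad\text{s.t.}\quad
      w^\star(\theta) = \arg\min_{w}\; \|X_{\text{tr}} w - y_{\text{tr}}\|_2^2 + \theta \|w\|_2^2 ,
      \]
      so treating θ as just another weight would merge the two objectives into one; the outer loss is evaluated only at the inner optimum w*(θ), and a third split can indeed be held out to evaluate the tuned θ.)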

  • @chuby35
    @chuby35 3 years ago +1

    I'd love to see the optimal dataset for ffhq with some classifier, but I don't want to learn jax just for that. :) I hope someone will create that just for the laughs. :)

  • @mohammadmahdirahimi7452
    @mohammadmahdirahimi7452 3 years ago +2

    Now we can have meta-meta-...-meta learning

  • @scottmiller2591
    @scottmiller2591 3 years ago +1

    If X is N x p, where N is the number of data points and p is the dimension of the features in X, and similarly y is N x 1, then you want X w = y, not w' X = y, so the number of rows has to equal N on both sides. Also, it's the L2 norms SQUARED, not just the L2 norms, at least for Tikhonov ridge regression.
    This method seems interesting; I need to look at the proximal gradient stuff.
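    (For reference, with X of size N x p and y of size N x 1, Tikhonov ridge regression reads
    \[
    \min_{w \in \mathbb{R}^{p}} \; \|X w - y\|_2^2 + \lambda \|w\|_2^2
    \;\Longrightarrow\;
    F(w, \lambda) = X^\top (X w - y) + \lambda w = 0
    \;\Longrightarrow\;
    w^\star = (X^\top X + \lambda I)^{-1} X^\top y ,
    \]
    which is the squared-norm, X w = y convention this comment points out.)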

  • @MaeglinLiao
    @MaeglinLiao 3 years ago +6

    Does this framework also apply to training GANs? Or is it a tri-level optimization problem if hyperparameter optimization is involved 🤣.

    • @YannicKilcher
      @YannicKilcher 3 years ago +9

      Yes it does. And yes, this actually supports any depth of inner loops 😁

    • @MaeglinLiao
      @MaeglinLiao 3 years ago +4

      @@YannicKilcher it’s so cool to actually optimize the theoretical max-min problem

  • @scottmiller2591
    @scottmiller2591 3 years ago +3

    I need to define an alias for jax.jacobian so I can just write jax.ocbian.

  • @barberb
    @barberb 3 years ago +19

    This is just another example of plagiarizing Schmidhuber

    • @nilsrethmeier8280
      @nilsrethmeier8280 3 years ago +1

      Could you point out the JS paper(s) pls? Much appreciated.

    • @barberb
      @barberb 3 years ago

      For those who don't understand, this is a joke.

    • @scottmiller2591
      @scottmiller2591 3 years ago +2

      frey_squinting.jpg - Not sure if sarcasm or historically accurate, but odds are historically accurate.

    • @G12GilbertProduction
      @G12GilbertProduction 3 years ago

      It hurts my grad(init) headbutt too.

  • @XOPOIIIO
    @XOPOIIIO 3 years ago +1

    But why wasn't autodiff used for this before?

  • @drdca8263
    @drdca8263 3 years ago +4

    This sounds like it could be a big deal. Does this primarily make multi-level optimization things easier to code, or does it also make these things notably faster when running?
    (I guess rather than "notably faster", what I really mean is a better big-O time, compared to how they would usually have been implemented previously)
    It still has to compute the inner optimizations I suppose..
    Does this make some computational tasks that were previously infeasible, now feasible?
    (feasible in the sense of "we know how to write a program within a reasonable amount of time and effort, which will run within a reasonable amount of time, and produce the answer")
    Not that, if the answer is no, this wouldn't still be important;
    just trying to tell whether, if I understood it better, it would seem *very* important, or just not quite that important but still quite cool.

    • @aspergale9836
      @aspergale9836 3 years ago +1

      Doing 2-level optim through, say, SGD with enough steps has been practically impossible for methods that do naive auto-diff then GD. This makes it possible.

    • @drdca8263
      @drdca8263 3 years ago

      @@aspergale9836 Thanks!

  • @MadlipzMarathi
    @MadlipzMarathi 3 years ago +1

    4:52 ok so, no one noticed it.

  • @al8-.W
    @al8-.W 3 years ago

    Are we getting closer to the free lunch theorem?

  • @surf8168
    @surf8168 2 years ago

    "My nips are NP Hard!"

  • @sieyk
    @sieyk 2 years ago

    Doesn't this mean that you can now feasibly backpropagate spiking neural networks?

    • @YannicKilcher
      @YannicKilcher 2 years ago

      I'm not sure, is the inner loop run to optimum?

    • @sieyk
      @sieyk 2 years ago

      @@YannicKilcher Ah, I did not realise this attempted to optimise two separate problems simultaneously. I was thinking for the spiking network you could just solve the dictionary example where the dictionary elements are the connected node pairs. Perhaps this has some good applications in reinforcement learning!

    • @joshuasmith5782
      @joshuasmith5782 2 years ago +1

      From what I saw in another comment this doesn’t really work for non-differentiable optimization like discrete node communication. Afaik that is theoretically impossible unless you produce a differentiable approximation like a value function in RL.

  • @ankitaharwal5886
    @ankitaharwal5886 3 years ago

    In the implementation we pass theta=10.0, so it's like passing a weight initialization in normal deep learning, and we would get a new optimized theta at the end.

  • @Pmaisterify
    @Pmaisterify 3 years ago +1

    First

  • @scottmiller2591
    @scottmiller2591 3 years ago

    "I'm not saying inner and outer loop, but 'Inner and Outer Loop'"

  • @dhruvpatel4948
    @dhruvpatel4948 3 years ago +1

    Can we officially name this a Bayesian optimization killer?