Efficient and Modular Implicit Differentiation (Machine Learning Research Paper Explained)
- Published May 31, 2024
- #implicitfunction #jax #autodiff
Many problems in Machine Learning involve loops of inner and outer optimization. Finding update steps for the outer loop is usually difficult, because of the need to differentiate through the inner loop's optimization procedure over multiple steps. Such loop unrolling is very limited and constrained to very few steps. Other papers have found solutions around unrolling in very specific, individual problems. This paper proposes a unified framework for implicit differentiation of inner optimization procedures without unrolling and provides implementations that integrate seamlessly into JAX.
OUTLINE:
0:00 - Intro & Overview
2:05 - Automatic Differentiation of Inner Optimizations
4:30 - Example: Meta-Learning
7:45 - Unrolling Optimization
13:00 - Unified Framework Overview & Pseudocode
21:10 - Implicit Function Theorem
25:45 - More Technicalities
28:45 - Experiments
ERRATA:
- Dataset Distillation is done with respect to the training set, not the validation or test set.
Paper: arxiv.org/abs/2105.15183
Code coming soon
Abstract:
Automatic differentiation (autodiff) has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted widespread attention with applications such as optimization as a layer, and in bi-level problems such as hyper-parameter optimization and meta-learning. However, the formulas for these derivatives often involve case-by-case tedious mathematical derivations. In this paper, we propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines (in Python in the case of our implementation) a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of F and implicit differentiation to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient as it can be added on top of any state-of-the-art solver and modular as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow to recover many recently proposed implicit differentiation methods and create new ones easily. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.
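The mathematical core of the abstract's recipe (covered at 21:10 in the video) is the implicit function theorem: if the inner solution $w^\star(\theta)$ satisfies the optimality condition $F(w^\star(\theta), \theta) = 0$ and $\partial_1 F$ (the Jacobian with respect to the first argument) is invertible at the solution, then

```latex
\partial w^\star(\theta) = -\left[\partial_1 F(w^\star(\theta), \theta)\right]^{-1} \partial_2 F(w^\star(\theta), \theta)
```

so the Jacobian of the solution falls out of autodiff of $F$ alone, without differentiating through the solver's iterations.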
Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert
Links:
TabNine Code Completion (Referral): bit.ly/tabnine-yannick
UA-cam: / yannickilcher
Twitter: / ykilcher
Discord: / discord
BitChute: www.bitchute.com/channel/yann...
Minds: www.minds.com/ykilcher
Parler: parler.com/profile/YannicKilcher
LinkedIn: / yannic-kilcher-488534136
BiliBili: space.bilibili.com/1824646584
If you want to support me, the best thing to do is to share out the content :)
If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: www.subscribestar.com/yannick...
Patreon: / yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n
I've been loving the new speed at which Dr. Kilcher is putting all of these videos out! So much to learn!
!!!
Dr lightspeed, you are a hero
Now I want to learn jax
I have the feeling this one is going to be big
Nice to give credit to an underrated area
This is really cool, but if I understood it correctly, the applications are still quite limited. The problem is that you need to compute a d*d jacobian matrix (and then solve a linear system involving this matrix), d being the dimension of the output of your inner optimization algorithm and input of the optimality condition function.
So, for any application involving neural networks for example, unless I'm wrong your little d would be the number of parameters of your neural network. Before even talking about solving the linear system, you would need to store a matrix of size "the number of parameters of the neural network SQUARED"
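For what it's worth, the d x d Jacobian never needs to be materialized: the implicit linear system only requires products of dF/dw with vectors, so it can be solved with a matrix-free iterative method (the paper discusses iterative solvers for exactly this reason, if I remember right). A rough NumPy sketch for the ridge case (`A_matvec` and `conjugate_gradient` are my own illustrative helpers, not the paper's API):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 200          # d is the inner problem's dimension
X, y = rng.normal(size=(n, d)), rng.normal(size=n)
lam = 0.5

# For ridge regression, F(w, lam) = X^T (X w - y) + lam * w, so the
# d x d Jacobian dF/dw = X^T X + lam * I never has to be formed:
# its action on a vector is available as two cheap matrix-vector products.
def A_matvec(v):
    return X.T @ (X @ v) + lam * v

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=1000):
    """Solve A v = b using only matvec products (A symmetric positive definite)."""
    v = np.zeros_like(b)
    r = b - matvec(v)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        v += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return v

# Inner solution, then the implicit-gradient system (dF/dw) u = -(dF/dlam) = -w*
w_star = conjugate_gradient(A_matvec, X.T @ y)
dw_dlam = conjugate_gradient(A_matvec, -w_star)
```

For a neural network the same trick works with Hessian-vector products from autodiff, so per-iteration storage stays O(d) instead of O(d^2).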
Thank you for introducing interesting paper!
18:00 the equation of ridge_solver is the same as the optimality condition F = 0. I'm a little bit confused...
23:45 instead of the optimization procedure having to be differentiable, only the optimality condition now needs to be differentiable. Best summary ever.
Recurrent neural networks also struggled with the problem of computing gradients by unrolling the recurrence. I wonder if this technique could be applied in that instance as well.
hmm. Interesting idea, but what would be the optimality condition for the inner loop of an RNN layer?
@@chuby35 I’d have to think about it
Powerful stuff 💪🏼
In your expression at 8:22 I think you've reversed which terms are transposed. Notice in the code, the terms all have `_tr` at the end. It should be $(X X^t + \lambda I) w = X y^t$. The way it's written in the video, it looks like you're solving for what weights should be applied to every *point,* but instead we want weights applied to each *component* (i.e., a linear function).
20:00 What does the custom_root decorator do? What does it return?
20:15 I don't understand, are we using the ridge solver as an approximation for the inner SGD? Why?
I don't understand what is going on in this paper, but if it works, it's gonna be big.
In general, decorators in Python are a way to apply properties to a function by redefining it internally; for example, you could have a decorator that adds a timer to the function you apply it to, so you can benchmark things easily.
In this case, the custom_root decorator takes in the differentiable optimality condition this paper talks about, so a function F that is differentiable and is zero when we have found the right solution. When applied to a function like the ridge_solver, it redefines its gradient in terms of lambda by instead using this paper's method, rather than JAX's built-in autodiff.
The ridge solver is just an example that doesn't have anything to do with SGD, this method however provides you with the gradient to find the optimal value for lambda.
Essentially, the ridge regression function solves the problem of finding w minimizing this loss: ||w^T * X - y|| + λ ||w||, but we'd also like to find the optimal value of λ that minimizes this loss even further! This entire paper solves the problem **WITHOUT** us having to look inside the details of the ridge regression solver, internally we could've used whatever crazy method as long as it satisfies that minimization task, and by specifying a differentiable F we get for free the gradient with respect to λ, which allows us to minimize the loss even further.
A way to imagine this, as Yannic mentions, is to think of λ as a hyperparameter we have to search, but instead of just doing a black-box search we actually get gradients efficiently. Imagine if instead of minimizing ||w^T * X - y|| + λ ||w||, you were to minimize ||NN(x) - y|| + λ ||w||, where NN is a neural network (again, regression but with a regularization term dependent on λ). Then the solver would be all the gradient steps we take for training our model for a fixed λ, **which might internally be so complicated that backpropagating through it wrt. λ is simply impossible**. But with this method you only care that you have a routine, which internally may use SGD, that optimizes the weights of the network you defined, and you still get the gradient with respect to the hyperparameters by defining the optimization as a root of another differentiable function F!
Even nicer, this method allows to directly solve optimization problems by just stating the optimal conditions, which before required entire papers (like OptNet) to derive.
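To make the thread above concrete, here is a rough NumPy sketch of the same bi-level ridge setup (my own illustrative code, not the paper's JAX implementation, whose code was not yet released; `ridge_solver`, `F` and the helper names are made up for the example):

```python
import numpy as np

rng = np.random.default_rng(42)
n, d = 40, 5
X_tr, y_tr = rng.normal(size=(n, d)), rng.normal(size=n)
X_val, y_val = rng.normal(size=(n, d)), rng.normal(size=n)

def ridge_solver(lam):
    """Inner solver: any black box returning the w that zeroes F(w, lam)."""
    return np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(d), X_tr.T @ y_tr)

def F(w, lam):
    """Optimality condition: gradient of the inner ridge objective w.r.t. w."""
    return X_tr.T @ (X_tr @ w - y_tr) + lam * w

def implicit_dw_dlam(w, lam):
    # Implicit function theorem: dw/dlam = -(dF/dw)^{-1} (dF/dlam),
    # with dF/dw = X^T X + lam * I and dF/dlam = w.
    return np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(d), -w)

def outer_loss(lam):
    return 0.5 * np.sum((X_val @ ridge_solver(lam) - y_val) ** 2)

def outer_grad(lam):
    # Chain rule through the inner solution, without unrolling the solver.
    w = ridge_solver(lam)
    dL_dw = X_val.T @ (X_val @ w - y_val)
    return dL_dw @ implicit_dw_dlam(w, lam)

# Sanity check the hypergradient against a central finite difference in lam:
g = outer_grad(1.0)
eps = 1e-5
fd = (outer_loss(1.0 + eps) - outer_loss(1.0 - eps)) / (2 * eps)
```

Notice the outer gradient only ever calls the solver as a black box plus one linear solve built from F; swap the closed-form `ridge_solver` for SGD and nothing else changes.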
Is it now implemented in the JAX library?
Dude, I wanna go to sleep... Damn it :D
GradSlam uses unrolling, perhaps it can utilize this!
I'm a bit confused by the toy example, couldn't you differentiate ||wTX - y|| + theta||w|| as your loss function, treating theta as another weight?
Does this not work because the norm is nonlinear or something?
The two losses are with respect to different datasets. The outer optimization is over the validation set and conditional on having solved the inner problem to completion
@@YannicKilcher ah so now you’d want 3 datasets so you can validate your hyper parameter training
I'd love to see the optimal dataset for ffhq with some classifier, but I don't want to learn jax just for that. :) I hope someone will create that just for the laughs. :)
Now we can have meta-meta-...-meta learning
If X is N x p, where N is the number of data points and p is the dimension of the features in X, and similarly y is N x 1, then you want X w = y, not w' X = y, so that both sides have N rows. Also, it's the L2 norms SQUARED, not just the L2 norms, at least for Tikhonov ridge regression.
This method seems interesting, I need to look at the proximal gradient stuff.
Does this framework also apply to training GANs ? Or it is a tri-level optimization problem if hyperparameter-optimization is involved🤣.
Yes it does. And yes, this actually supports any depth of inner loops 😁
@@YannicKilcher it’s so cool to actually optimize the theoretical max-min problem
I need to define an alias for jax.jacobian so I can just write jax.ocbian.
This is just another example of plagiarizing Schmidhuber
Could you point out the JS paper(s) pls? Much appreciated.
For those who don't understand, this is a joke.
frey_squinting.jpg - Not sure if sarcasm or historically accurate, but odds are historically accurate.
It hurts my grad(init) headbutt too.
But why autodiff wasn't used before?
This sounds like it could be a big deal. Does this primarily make multi-level optimization things easier to code, or does it also make these things notably faster when running?
(I guess rather than "notably faster", what I really mean is like, a better big O time, compared to how they would usually have been implemented previously)
It still has to compute the inner optimizations I suppose..
Does this make some computations tasks that were previously infeasible, now feasible?
(feasible in the sense of "we know how to write a program within a reasonable amount of time and effort, which will run within a reasonable amount of time, and produce the answer")
Not that, if the answer is no, this wouldn't still be important;
just trying to tell if, if I understood it better, it would seem *very* important, or not quite that important but still quite cool.
Doing 2-level optim through, say, SGD with enough steps has been practically impossible for methods that do naive auto-diff then GD. This makes it possible.
@@aspergale9836 Thanks!
4:52 ok so, no one noticed it.
Are we getting closer to the free lunch theorem ?
"My nips are NP Hard!"
Doesn't this mean that you can now feasibly backpropagate spiking neural networks?
I'm not sure, is the inner loop run to optimum?
@@YannicKilcher Ah, I did not realise this attempted to optimise two separate problems simultaneously. I was thinking for the spiking network you could just solve the dictionary example where the dictionary elements are the connected node pairs. Perhaps this has some good applications in reinforcement learning!
From what I saw in another comment this doesn’t really work for non-differentiable optimization like discrete node communication. Afaik that is theoretically impossible unless you produce a differentiable approximation like a value function in RL.
In the implementation we passed theta=10.0, so it's like passing a weight initialization in normal deep learning. And we would get a new optimized theta at the end.
First
"I'm not saying inner and outer loop, but 'Inner and Outer Loop'"
Can we officially name this a Bayesian optimization killer?