- Hands-On Meta Learning with Python
- Sudharsan Ravichandiran
Optimization as a model for few-shot learning
We know that, in few-shot learning, we learn from fewer data points, but how can we apply gradient descent in a few-shot setting? In a few-shot setting, gradient descent fails because it has only a handful of data points to work with: gradient descent needs many data points and many update steps to reach convergence and minimize the loss. So, we need a better optimization technique in the few-shot regime. Let's say we have a model parameterized by some parameter $\theta$. We initialize this parameter $\theta$ with some random values and try to find the optimal value using gradient descent. Let's recall the update equation of our gradient descent:

$$\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} L_t$$
In the previous equation, the following applies:
- $\theta_t$ is the updated parameter
- $\theta_{t-1}$ is the parameter value at the previous time step
- $\alpha_t$ is the learning rate
- $\nabla_{\theta_{t-1}} L_t$ is the gradient of the loss function with respect to $\theta_{t-1}$
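To make this concrete, here is a minimal NumPy sketch of the vanilla gradient descent update on a toy quadratic loss; the loss, learning rate, and number of steps are placeholders chosen only for illustration:

```python
import numpy as np

# Toy loss L(theta) = (theta - 3)^2 and its gradient, used only to illustrate
# the update theta_t = theta_{t-1} - alpha * gradient.
def loss(theta):
    return (theta - 3.0) ** 2

def gradient(theta):
    return 2.0 * (theta - 3.0)

theta = np.random.randn()   # random initialization of the parameter
alpha = 0.1                 # learning rate

for t in range(100):
    theta = theta - alpha * gradient(theta)   # gradient descent update

print(theta)   # converges close to the minimizer 3.0
```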
Doesn't the update equation of gradient descent look familiar? Yes, you guessed it right: it resembles the cell state update equation of the LSTM, which can be written as follows:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
We can directly relate our LSTM cell update equation to gradient descent. Let's say $f_t = 1$, then the following applies:

$$c_{t-1} = \theta_{t-1}, \quad \tilde{c}_t = -\nabla_{\theta_{t-1}} L_t, \quad i_t = \alpha_t$$

With these substitutions, the cell state update becomes exactly the gradient descent update:

$$c_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} L_t = \theta_t$$
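We can verify this correspondence with a tiny numerical check (the numbers below are made up purely for illustration): substituting $f_t = 1$, $c_{t-1} = \theta_{t-1}$, $i_t = \alpha_t$, and $\tilde{c}_t = -\nabla_{\theta_{t-1}} L_t$ into the cell state update reproduces the gradient descent step exactly:

```python
# Made-up numbers purely for illustration
theta_prev = 0.5    # c_{t-1} = theta_{t-1}
alpha = 0.1         # i_t = alpha_t (learning rate)
grad = 2.0          # gradient of the loss at theta_{t-1}

f_t = 1.0           # forget gate fixed to 1
c_tilde = -grad     # candidate cell state = negative gradient

# LSTM cell state update: c_t = f_t * c_{t-1} + i_t * c~_t
c_t = f_t * theta_prev + alpha * c_tilde

# Vanilla gradient descent step: theta_t = theta_{t-1} - alpha * grad
theta_t = theta_prev - alpha * grad

print(c_t, theta_t)   # both print 0.3 -- the two updates coincide
```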
So, instead of using gradient descent as the optimizer in the few-shot regime, we can use an LSTM as the optimizer. Our meta learner is the LSTM, which learns the update rule for training our model. So we use two networks: our base learner, which learns to perform the task, and our meta learner, which tries to find the optimal parameters. But how does this work?
We know that, in an LSTM, we use a forget gate to discard information that is no longer required in memory, and it can be represented as follows:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$
How can this forget gate be useful in our optimization setting? Let's say we are in a position where the loss is high and the gradient is close to zero. How can we escape from this position? In this case, we can shrink the parameters of our model and forget part of their previous value. We can use our forget gate to do that: it takes the current parameter value $\theta_{t-1}$, the current loss $L_t$, the current gradient $\nabla_{\theta_{t-1}} L_t$, and the previous forget gate $f_{t-1}$ as input, and it can be represented as follows:

$$f_t = \sigma(W_f \cdot [\nabla_{\theta_{t-1}} L_t, L_t, \theta_{t-1}, f_{t-1}] + b_f)$$
Now let's come to the input gate. We know that the input gate in an LSTM is used to decide what value to update, and it can be represented as follows:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
In our few-shot learning setting, we can use this input gate to tune our learning rate so that the model learns quickly while avoiding divergence:

$$i_t = \sigma(W_i \cdot [\nabla_{\theta_{t-1}} L_t, L_t, \theta_{t-1}, i_{t-1}] + b_i)$$
So, our meta learner learns the optimal values of $f_t$ and $i_t$ after several updates.
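As a rough sketch of what such a learned update rule looks like in code, the snippet below implements both gates as tiny sigmoid units over $[\nabla_{\theta_{t-1}} L_t, L_t, \theta_{t-1}, \text{gate}_{t-1}]$ and applies the cell state update. The helper names (`gate`, `meta_update`), the flat weight vectors, and the per-parameter scalar treatment are simplifying assumptions made for illustration, not the exact architecture used later in the chapter:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gate(W, b, grad, loss, theta_prev, prev_gate):
    # Both gates share the same form:
    # sigma(W . [grad_t, L_t, theta_{t-1}, gate_{t-1}] + b)
    features = np.array([grad, loss, theta_prev, prev_gate])
    return sigmoid(W @ features + b)

def meta_update(theta_prev, grad, loss, f_prev, i_prev, phi):
    """One step of the learned update rule: theta_t = f_t * theta_{t-1} - i_t * grad_t."""
    f_t = gate(phi["W_f"], phi["b_f"], grad, loss, theta_prev, f_prev)   # forget gate
    i_t = gate(phi["W_i"], phi["b_i"], grad, loss, theta_prev, i_prev)   # input gate (learned step size)
    theta_t = f_t * theta_prev + i_t * (-grad)   # cell state update with c~_t = -grad
    return theta_t, f_t, i_t

# Toy call with randomly initialized meta learner parameters phi
phi = {"W_f": np.random.randn(4), "b_f": 0.0,
       "W_i": np.random.randn(4), "b_i": 0.0}
theta_t, f_t, i_t = meta_update(theta_prev=0.5, grad=0.8, loss=2.3,
                                f_prev=1.0, i_prev=0.1, phi=phi)
print(theta_t, f_t, i_t)
```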
But still, how does this work?
Let's say we have a base network parameterized by $\theta$ and our LSTM meta learner $R$ parameterized by $\phi$. Assume that we have a dataset $D$. We split our dataset as $D_{train}$ and $D_{test}$ for training and testing respectively. First, we randomly initialize our meta learner parameter $\phi$.
For some $T$ number of iterations, we randomly sample data points from $D_{train}$, calculate the loss, and then calculate the gradient of the loss with respect to our model parameter $\theta_{t-1}$. Now we feed this gradient, the loss, and the meta learner parameter $\phi_{t-1}$ to our meta learner. Our meta learner $R$ will return a cell state $c_t$, and then we update our base network parameter $\theta_t$ at time $t$ as $\theta_t = c_t$. We repeat this for some $N$ number of times, as shown in the following diagram:

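Continuing the sketch above (it assumes the hypothetical `meta_update` helper is already in scope), the inner loop for a toy one-parameter base learner could look like the following; the quadratic loss stands in for sampling a batch from $D_{train}$, and the initial gate values are placeholders:

```python
def inner_loop(phi, T=5):
    """Train the base learner for T steps using the meta learner's update rule."""
    theta = np.random.randn()     # random initialization of the base parameter
    f_prev, i_prev = 1.0, 0.1     # placeholder initial gate values
    for t in range(T):
        loss = (theta - 3.0) ** 2       # toy training loss (stands in for a batch from D_train)
        grad = 2.0 * (theta - 3.0)      # its gradient w.r.t. theta_{t-1}
        # The meta learner consumes (grad, loss, theta_{t-1}) and returns the new
        # cell state c_t, which becomes the updated base parameter theta_t.
        theta, f_prev, i_prev = meta_update(theta, grad, loss, f_prev, i_prev, phi)
    return theta                   # theta_T, to be evaluated on D_test
```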
So, after $T$ iterations, we will have an optimal parameter $\theta_T$. But how can we check the performance of $\theta_T$, and how can we update our meta learner parameter? We take the test set and compute the loss on the test set with parameter $\theta_T$. Then, we calculate the gradients of this loss with respect to our meta learner parameter $\phi$, and we update $\phi$, as shown here:

$$\phi_n = \phi_{n-1} - \beta \nabla_{\phi_{n-1}} L_{test}$$
We do this for some n number of iterations and update our meta learner. The overall algorithm is shown here:

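To tie the steps together, here is a compact, self-contained PyTorch sketch of the whole procedure under strong simplifying assumptions (a one-parameter base learner, a toy random-target task standing in for $D_{train}$/$D_{test}$, and a single pair of gates in place of a full LSTM meta learner); the meta gradient $\nabla_{\phi} L_{test}$ is obtained by backpropagating through the unrolled inner loop:

```python
import torch

# Meta learner parameters phi: one weight vector and bias per gate (a deliberately
# tiny stand-in for a full LSTM meta learner).
phi = {
    "W_f": torch.randn(4, requires_grad=True), "b_f": torch.zeros(1, requires_grad=True),
    "W_i": torch.randn(4, requires_grad=True), "b_i": torch.zeros(1, requires_grad=True),
}
beta = 0.01                                           # meta learning rate
meta_opt = torch.optim.SGD(phi.values(), lr=beta)     # phi_n = phi_{n-1} - beta * grad_phi(L_test)

def gate(W, b, grad, loss, theta, prev):
    # sigma(W . [grad_t, L_t, theta_{t-1}, gate_{t-1}] + b)
    return torch.sigmoid(W @ torch.cat([grad, loss, theta, prev]) + b)

for episode in range(200):                      # outer loop: n meta iterations
    target = torch.randn(1)                     # a new toy "task": fit theta to a random target
    loss_fn = lambda th: (th - target) ** 2     # stands in for both the train and test losses
    theta = torch.randn(1, requires_grad=True)  # base learner initialization
    f_prev, i_prev = torch.ones(1), 0.1 * torch.ones(1)
    for t in range(5):                          # inner loop: T steps of the learned update rule
        loss = loss_fn(theta)
        grad = torch.autograd.grad(loss, theta, create_graph=True)[0]  # keep graph for meta grads
        f_t = gate(phi["W_f"], phi["b_f"], grad, loss, theta, f_prev)
        i_t = gate(phi["W_i"], phi["b_i"], grad, loss, theta, i_prev)
        theta = f_t * theta - i_t * grad        # theta_t = c_t
        f_prev, i_prev = f_t, i_t
    test_loss = loss_fn(theta)                  # evaluate theta_T on the test data
    meta_opt.zero_grad()
    test_loss.backward()                        # d(test loss)/d(phi) through the unrolled inner loop
    meta_opt.step()
```

The `create_graph=True` flag is what allows the test loss gradient to flow back through every inner loop update into the meta learner parameters $\phi$.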