- Hands-On Meta Learning with Python
- Sudharsan Ravichandiran
- 2021-07-02 14:29:20
Prototypical networks
Prototypical networks are yet another simple and efficient few-shot learning algorithm. Like siamese networks, a prototypical network learns a metric space in which to perform classification. The basic idea of prototypical networks is to create a prototypical representation of each class and classify a query point (that is, a new point) based on the distance between each class prototype and the query point.
Let's say we have a support set comprising images of lions, elephants, and dogs, as shown in the following diagram:

So, we have three classes: {lion, elephant, dog}. Now we need to create a prototypical representation for each of these three classes. How can we build these prototypes? First, we learn the embeddings of each data point using an embedding function. The embedding function can be any function that extracts features from the input. Since our input is an image, we can use a convolutional network as our embedding function, which will extract features from the input image:

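To make the role of the embedding function concrete, here is a minimal sketch in plain Python. It assumes each input is already a flat list of numbers and uses a fixed linear projection as a stand-in for the convolutional network; the names `embed` and `WEIGHTS` are hypothetical, and in a real prototypical network the weights would be learned rather than hard-coded.

```python
from typing import List

Vector = List[float]

# Hypothetical fixed projection weights standing in for a trained CNN.
# Each row maps a 4-dimensional input to one embedding dimension.
WEIGHTS: List[Vector] = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
]


def embed(x: Vector, weights: List[Vector] = WEIGHTS) -> Vector:
    """Map a raw input vector to an embedding with a linear projection.

    This is only a stand-in for the embedding function; a real
    prototypical network would use a learned convolutional network.
    """
    if any(len(row) != len(x) for row in weights):
        raise ValueError("input size does not match projection weights")
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


# The same function embeds every support point.
support = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]]
support_embeddings = [embed(x) for x in support]
print(support_embeddings)  # [[4.0, 6.0], [1.0, 1.0]]
```

Whatever the embedding function is, the rest of the method only needs it to return fixed-length vectors, which is why a CNN can be swapped in without changing the later steps.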
Once we have the embedding of each data point, we take the mean of the embeddings of the data points in each class to form the class prototype, as shown in the following diagram. So, a class prototype is simply the mean embedding of the data points in that class:

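A minimal sketch of prototype construction, assuming the embeddings have already been computed and are plain Python lists. The helper name `compute_prototypes` is hypothetical, and the numbers are made-up toy embeddings rather than real image features.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

Vector = List[float]


def compute_prototypes(
    embeddings: Sequence[Vector], labels: Sequence[str]
) -> Dict[str, Vector]:
    """Return one prototype per class: the mean of that class's embeddings."""
    if len(embeddings) != len(labels):
        raise ValueError("embeddings and labels must have the same length")
    grouped: Dict[str, List[Vector]] = defaultdict(list)
    for emb, label in zip(embeddings, labels):
        grouped[label].append(emb)
    prototypes: Dict[str, Vector] = {}
    for label, members in grouped.items():
        dim = len(members[0])
        # Element-wise mean over all support embeddings of this class.
        prototypes[label] = [
            sum(m[d] for m in members) / len(members) for d in range(dim)
        ]
    return prototypes


# Toy 2-D embeddings for a 3-class, 2-points-per-class support set.
support_embeddings = [
    [1.0, 1.0], [3.0, 1.0],   # lion
    [8.0, 8.0], [8.0, 10.0],  # elephant
    [0.0, 6.0], [2.0, 6.0],   # dog
]
support_labels = ["lion", "lion", "elephant", "elephant", "dog", "dog"]
prototypes = compute_prototypes(support_embeddings, support_labels)
print(prototypes)
# {'lion': [2.0, 1.0], 'elephant': [8.0, 9.0], 'dog': [1.0, 6.0]}
```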
Similarly, when a new data point comes in, that is, a query point whose label we want to predict, we generate its embedding using the same embedding function that we used to create the class prototypes; that is, we generate the embedding of our query point using the same convolutional network:

Once we have the embedding of our query point, we compute the distance between each class prototype and the query point embedding to find which class the query point belongs to. We can use the Euclidean distance as the measure of distance between a class prototype and the query point embedding, as shown here:

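Written out with notation introduced here for illustration (the section itself does not name these symbols): $f_\phi$ is the embedding function, $S_k$ is the set of support points of class $k$, $c_k$ is that class's prototype, and $x$ is a query point. The prototype, the Euclidean distance, and the softmax over negative distances are:

```latex
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)

d\big(f_\phi(x), c_k\big) = \lVert f_\phi(x) - c_k \rVert_2 = \sqrt{\sum_{j} \big(f_\phi(x)_j - c_{k,j}\big)^2}

p(y = k \mid x) = \frac{\exp\big(-d(f_\phi(x), c_k)\big)}{\sum_{k'} \exp\big(-d(f_\phi(x), c_{k'})\big)}
```

Many implementations use the squared Euclidean distance in place of $d$; either way, a smaller distance to $c_k$ gives a larger $p(y = k \mid x)$.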
After finding the distance between each class prototype and the query point embedding, we apply softmax to the negative of these distances to get probabilities; negating the distances ensures that the closest prototype receives the highest probability. Since we have three classes, that is, lion, elephant, and dog, we get three probabilities, and the class with the highest probability is predicted as the class of our query point.
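A minimal sketch of this classification step in plain Python, assuming toy 2-D prototypes (made-up numbers) and hypothetical helper names `euclidean` and `class_probabilities`. The softmax is taken over the negative distances, so the nearest prototype gets the highest probability.

```python
import math
from typing import Dict, List

Vector = List[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Euclidean distance between two embeddings of equal length."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def class_probabilities(query: Vector, prototypes: Dict[str, Vector]) -> Dict[str, float]:
    """Softmax over the negative distances from the query to each prototype."""
    logits = {label: -euclidean(query, proto) for label, proto in prototypes.items()}
    # Subtracting the max logit avoids overflow and leaves the softmax unchanged.
    m = max(logits.values())
    exps = {label: math.exp(v - m) for label, v in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}


prototypes = {"lion": [2.0, 1.0], "elephant": [8.0, 9.0], "dog": [1.0, 6.0]}
query = [2.0, 2.0]
probs = class_probabilities(query, prototypes)
prediction = max(probs, key=probs.get)
print(prediction)  # lion
```

Here the query is 1.0 away from the lion prototype and much farther from the other two, so lion receives most of the probability mass.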
Since we want our network to learn from a few data points, that is, to perform few-shot learning, we train it in the same way. So, we use episodic training: in each episode, we randomly sample a few data points from each class in our dataset, call that the support set, and train the network using only the support set instead of the whole dataset. Similarly, we randomly sample a query point from those classes, one that is not part of the support set, and try to predict its class. In this way, the network learns how to learn from a small set of data points.
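A minimal sketch of sampling one training episode. Many implementations also restrict each episode to a random subset of N classes with K support points each (the "N-way K-shot" setup); that choice, the helper name `sample_episode`, and the toy item names are assumptions for illustration.

```python
import random
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def sample_episode(
    dataset: Dict[str, List[T]], n_way: int, k_shot: int, rng: random.Random
) -> Tuple[List[Tuple[T, str]], Tuple[T, str]]:
    """Sample an N-way K-shot support set and one query point.

    The query comes from one of the sampled classes and is never one of
    the support points. random.sample raises ValueError if a class has
    fewer than k_shot + 1 points or there are fewer than n_way classes.
    """
    classes = rng.sample(sorted(dataset), n_way)
    support: List[Tuple[T, str]] = []
    candidates: List[Tuple[T, str]] = []
    for label in classes:
        # Draw k_shot + 1 distinct points: k_shot for support, 1 spare.
        points = rng.sample(dataset[label], k_shot + 1)
        support.extend((p, label) for p in points[:k_shot])
        candidates.append((points[k_shot], label))
    query = rng.choice(candidates)
    return support, query


toy_data = {
    "lion": ["lion_1", "lion_2", "lion_3"],
    "elephant": ["elephant_1", "elephant_2", "elephant_3"],
    "dog": ["dog_1", "dog_2", "dog_3"],
    "cat": ["cat_1", "cat_2", "cat_3"],
}
support, query = sample_episode(toy_data, n_way=3, k_shot=2, rng=random.Random(0))
print(len(support))  # 6
```

Passing an explicit `random.Random` instance keeps episodes reproducible for a given seed without touching the global random state.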
The overall flow of our prototypical network is shown in the following diagram. First, we generate the embeddings for all of the data points in our support set and build each class prototype by taking the mean embedding of the data points in that class. We also generate the embedding of our query point. Then, we compute the distance between each class prototype and the query point embedding, using the Euclidean distance as the distance measure. Finally, we apply softmax to the negative distances to get the probabilities. As the following diagram shows, since our query point is a lion, the probability for lion is high (0.9):

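The flow above can be condensed into a single forward pass over one episode. The text does not state the training loss; this sketch assumes the standard choice for prototypical networks, the negative log-probability of the query's true class. It takes precomputed toy embeddings, and `episode_forward` is a hypothetical name.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def episode_forward(
    support: Sequence[Tuple[Vector, str]], query: Vector, query_label: str
) -> Tuple[str, float]:
    """One episode: prototypes -> distances -> softmax -> (prediction, loss)."""
    # 1. Class prototypes: mean embedding per class.
    by_class: Dict[str, List[Vector]] = {}
    for emb, label in support:
        by_class.setdefault(label, []).append(emb)
    prototypes = {
        label: [sum(col) / len(embs) for col in zip(*embs)]
        for label, embs in by_class.items()
    }
    # 2. Negative Euclidean distances act as logits.
    logits = {label: -math.dist(query, proto) for label, proto in prototypes.items()}
    # 3. Log-softmax via a numerically stable log-sum-exp.
    m = max(logits.values())
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    log_probs = {label: v - log_norm for label, v in logits.items()}
    prediction = max(log_probs, key=log_probs.get)
    # 4. Assumed training loss: negative log-probability of the true class.
    loss = -log_probs[query_label]
    return prediction, loss


support = [
    ([1.0, 1.0], "lion"), ([3.0, 1.0], "lion"),
    ([8.0, 8.0], "elephant"), ([8.0, 10.0], "elephant"),
    ([0.0, 6.0], "dog"), ([2.0, 6.0], "dog"),
]
prediction, loss = episode_forward(support, [2.0, 2.0], "lion")
print(prediction)  # lion
```

In actual training, this loss would be backpropagated into the embedding network's weights with an automatic-differentiation framework; the plain-Python version only shows the forward computation.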
Prototypical networks are used not only for one-shot and few-shot learning but also for zero-shot learning. Consider the case where you have no data points for a class, but you do have meta information containing a high-level description of each class. In such cases, we learn embeddings from the meta information of each class to form the class prototypes and then perform classification with those prototypes.
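A minimal zero-shot sketch, assuming each class comes with a small attribute vector as its meta information and using a fixed linear map as a stand-in for the learned encoder that embeds it. `CLASS_ATTRIBUTES`, `G_WEIGHTS`, `embed_meta`, and `zero_shot_predict` are all hypothetical, and the query embedding is assumed to come from the image embedding function, mapped into the same space as the prototypes.

```python
import math
from typing import Dict, List

Vector = List[float]

# Hypothetical meta information per class: [has_mane, is_large, has_trunk].
CLASS_ATTRIBUTES: Dict[str, Vector] = {
    "lion": [1.0, 0.5, 0.0],
    "elephant": [0.0, 1.0, 1.0],
    "dog": [0.0, 0.2, 0.0],
}

# Hypothetical projection standing in for a learned meta-information encoder;
# it maps a 3-D attribute vector to a 2-D embedding.
G_WEIGHTS: List[Vector] = [
    [2.0, 0.0, 0.0],
    [0.0, 1.0, 3.0],
]


def embed_meta(attributes: Vector) -> Vector:
    """Embed a class description with the stand-in encoder."""
    return [sum(w * a for w, a in zip(row, attributes)) for row in G_WEIGHTS]


def zero_shot_predict(query_embedding: Vector) -> str:
    """Assign the query to the class whose meta-information prototype is nearest."""
    prototypes = {label: embed_meta(attrs) for label, attrs in CLASS_ATTRIBUTES.items()}
    return min(prototypes, key=lambda label: math.dist(query_embedding, prototypes[label]))


print(zero_shot_predict([1.8, 0.6]))  # lion
```

No support images are used anywhere: each prototype comes solely from a class description, which is what distinguishes this from the few-shot case.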