Machine Learning with Swift
Alexander Sosnovshchenko
Implementing KNN in Swift
The KNN classifier works with virtually any type of data, since you define the distance metric for your data points yourself. That's why we define it as a generic structure parameterized with types for features and labels. Labels should conform to the Hashable protocol, as we're going to use them as dictionary keys:
struct kNN<X, Y> where Y: Hashable { ... }
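For instance, if the features are vectors of Double values and the labels are strings, a concrete specialization might look like this (the typealias name here is just an illustration, not part of the book's listing):

// Hypothetical specialization: [Double] feature vectors and String labels.
typealias VectorClassifier = kNN<[Double], String>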
KNN has two hyperparameters: k, the number of neighbors (var k: Int), and the distance metric. We'll define the metric elsewhere and pass it in during initialization. The metric is a function that returns the distance as a Double for any two samples, x1 and x2:
var distanceMetric: (_ x1: X, _ x2: X) -> Double
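As a concrete example of such a metric (a minimal sketch assuming features are stored as [Double]; this helper is not part of the book's listing), Euclidean distance can be written like this:

import Foundation

// Euclidean distance between two feature vectors of equal length.
// Illustrative helper; any function with this signature can serve as a metric.
func euclideanDistance(_ x1: [Double], _ x2: [Double]) -> Double {
    precondition(x1.count == x2.count, "Vectors must have equal length.")
    let squaredSum = zip(x1, x2)
        .map { ($0 - $1) * ($0 - $1) }
        .reduce(0, +)
    return sqrt(squaredSum)
}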
During initialization, we simply record the hyperparameters inside our structure. The definition of init looks like this:
init(k: Int, distanceMetric: @escaping (_ x1: X, _ x2: X) -> Double) {
    self.k = k
    self.distanceMetric = distanceMetric
}
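With such a metric in hand, constructing a classifier could look like this sketch (assuming the euclideanDistance helper from above):

// A 3-nearest-neighbors classifier over [Double] features and String labels.
var classifier = kNN<[Double], String>(k: 3, distanceMetric: euclideanDistance)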
KNN stores all of its training data points. We use an array of (features, label) pairs for this purpose:
private var data: [(X, Y)] = []
As usual with supervised learning models, we'll stick to an interface with two methods, train and predict, which reflect the two phases of a supervised algorithm's life. In the case of KNN, the train method just saves the data points so they can be used later in the predict method:
mutating func train(X: [X], y: [Y]) {
    data.append(contentsOf: zip(X, y))
}
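For example, training on a small toy dataset might look like this (the feature vectors and labels below are made up purely for illustration):

// Two points per class in a toy 2D feature space.
let features: [[Double]] = [[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]]
let labels = ["A", "A", "B", "B"]
classifier.train(X: features, y: labels)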
The predict method takes a data point and predicts the label for it:
func predict(x: X) -> Y? {
    assert(data.count > 0, "Please use the train() method first to provide training data.")
    assert(k > 0, "Error: k must be greater than 0.")
For this, we iterate through all samples in the training dataset and compare each of them with the input sample x. We use (distance, label) tuples to keep track of the distance to each training sample. After this, we sort all the samples by distance in ascending order and take the first k elements with prefix:
    // prefix(k) safely returns all elements when k exceeds data.count,
    // whereas prefix(upTo: k) would trap in that case.
    let tuples = data
        .map { (distanceMetric(x, $0.0), $0.1) }
        .sorted { $0.0 < $1.0 }
        .prefix(k)
Now we arrange majority voting among the top k samples: we count the frequency of each label and sort the labels in descending order of frequency:
    // NSCountedSet requires import Foundation at the top of the file.
    let countedSet = NSCountedSet(array: tuples.map { $0.1 })
    let result = countedSet.allObjects.sorted {
        countedSet.count(for: $0) > countedSet.count(for: $1)
    }.first
    return result as? Y
}
The result variable holds the predicted class label.
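Putting it all together, a prediction call on the toy classifier sketched above might look like this:

// Query a point near class "A"; with k = 3 the two nearby "A" samples
// outvote the single nearest "B" sample in the majority vote.
if let label = classifier.predict(x: [0.2, 0.1]) {
    print("Predicted label: \(label)") // expected to print "A"
}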