diff --git a/CoreMLBert.xcodeproj/project.pbxproj b/CoreMLBert.xcodeproj/project.pbxproj index a9c8b35..f331df4 100644 --- a/CoreMLBert.xcodeproj/project.pbxproj +++ b/CoreMLBert.xcodeproj/project.pbxproj @@ -68,6 +68,8 @@ 79F2CCA022C666C7009F8551 /* question_tokens.json in Resources */ = {isa = PBXBuildFile; fileRef = 79F2CC9F22C666C7009F8551 /* question_tokens.json */; }; 79F2CCA222C6717E009F8551 /* LoaderView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CCA122C6717D009F8551 /* LoaderView.swift */; }; 79F7060E22EA0CA900C4432C /* BERTSQUADFP16.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CC9022C5590C009F8551 /* BERTSQUADFP16.mlmodel */; }; + DA3628A82989BB04007A3BE6 /* CLIPTokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = DA3628A72989BB04007A3BE6 /* CLIPTokenizer.swift */; }; + DA3628A92989BB04007A3BE6 /* CLIPTokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = DA3628A72989BB04007A3BE6 /* CLIPTokenizer.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -136,6 +138,7 @@ 79F2CC9D22C57825009F8551 /* BertForQATests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BertForQATests.swift; sourceTree = ""; }; 79F2CC9F22C666C7009F8551 /* question_tokens.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = question_tokens.json; sourceTree = ""; }; 79F2CCA122C6717D009F8551 /* LoaderView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LoaderView.swift; sourceTree = ""; }; + DA3628A72989BB04007A3BE6 /* CLIPTokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CLIPTokenizer.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -273,6 +276,7 @@ 79F2CC8A22C549C1009F8551 /* Utils.swift */, 79F2CC9622C56890009F8551 /* Math.swift */, 79F2CC9922C57132009F8551 /* MLMultiArray+Utils.swift */, + DA3628A72989BB04007A3BE6 /* CLIPTokenizer.swift */, 796DF50F22E0EB1D00140C02 /* GPT2Tokenizer.swift */, 796DF51922E0FF7A00140C02 /* GPT2ByteEncoder.swift */, 796DF58622E2727000140C02 /* GPT2.swift */, @@ -457,6 +461,7 @@ buildActionMask = 2147483647; files = ( 796DF55422E1026700140C02 /* ViewController.swift in Sources */, + DA3628A92989BB04007A3BE6 /* CLIPTokenizer.swift in Sources */, 796DF58722E2727000140C02 /* GPT2.swift in Sources */, 796DF56F22E1039C00140C02 /* SquadDataset.swift in Sources */, 796DF55022E1026700140C02 /* AppDelegate.swift in Sources */, @@ -486,6 +491,7 @@ buildActionMask = 2147483647; files = ( 79F2CC9A22C57132009F8551 /* MLMultiArray+Utils.swift in Sources */, + DA3628A82989BB04007A3BE6 /* CLIPTokenizer.swift in Sources */, 79F2CCA222C6717E009F8551 /* LoaderView.swift in Sources */, 79F2CC6022C50078009F8551 /* ViewController.swift in Sources */, 79F2CC8122C5041C009F8551 /* SquadDataset.swift in Sources */, diff --git a/Sources/CLIPTokenizer.swift b/Sources/CLIPTokenizer.swift new file mode 100644 index 0000000..30122c1 --- /dev/null +++ b/Sources/CLIPTokenizer.swift @@ -0,0 +1,132 @@ +// +// CLIPTokenizer.swift +// CoreMLBert +// +// Created by Matthew Waller on 1/31/23. +// Copyright © 2023 Hugging Face. All rights reserved. +// + +import Foundation + +class CLIPTokenizer { + let bpeRanks: Dictionary + private let encoder: [String: Int] + private let decoder: [Int: String] + + init() { + let url = Bundle.main.url(forResource: "merges", withExtension: "txt")! + let bpeMergesTxt = try! String(contentsOf: url) + let arr = bpeMergesTxt.split(separator: "\n").map { String($0) } + var bpeRanks: Dictionary = [:] + for i in 1.. [String] { + let RE = "<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+" + let tokens = text.ranges(of: RE).map { String(text[$0]) } + return tokens.map { (token) -> String in + return Array(token.utf8).map { byteEncoder[$0]! }.joined() + } + } + + private func getPairs(word: [String]) -> Set { + var s = Set() + for i in 0.. String { + if token.count <= 1 { + return token + "" + } + + var word = Array(token).map { String($0)} + let last = (word.last ?? "") + "" + word.removeLast() + word.append(last) + var pairs = Array(getPairs(word: word)) + if pairs.isEmpty { + return token + "" + } + + while true { + let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil } + if bigrams.count == 0 { + break + } + let bigram = bigrams.min { (bp1, bp2) -> Bool in + return bpeRanks[bp1]! < bpeRanks[bp2]! + }! + let first = bigram.a + let second = bigram.b + var newWord: [String] = [] + var i = 0 + while i < word.count { + if let j = word[i.. [String] { + var tokens: [String] = [] + let lowercased = text.lowercased() + for token in self.byteEncode(text: lowercased) { + let xx = self.bpe(token: token).split(separator: " ").map { String($0) } + tokens.append(contentsOf: xx) + } + return tokens + } + + /// Main entry point + func encode(text: String) -> [Int] { + return tokenize(text: text).map { encoder[$0]! } + } + + /// Decode + func decode(tokens: [Int]) -> String { + let text = tokens.map { decoder[$0]! }.joined(separator: "") + let utfCodepoints = text.map { byteDecoder[String($0)]! } + return String(decoding: utfCodepoints, as: UTF8.self) + } +}