From c3abf4db8f001f79e32a7dfacfd984f02d0a6678 Mon Sep 17 00:00:00 2001 From: Rojan Date: Thu, 5 Mar 2020 12:28:38 +0100 Subject: [PATCH 1/3] API for NN usage. --- FastText.NetWrapper/FastTextWrapper.Api.cs | 5 +++- FastText.NetWrapper/FastTextWrapper.cs | 34 ++++++++++++++++++++-- TestUtil/Program.cs | 21 ++++++++++--- 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/FastText.NetWrapper/FastTextWrapper.Api.cs b/FastText.NetWrapper/FastTextWrapper.Api.cs index 0028f8e..8feabc8 100644 --- a/FastText.NetWrapper/FastTextWrapper.Api.cs +++ b/FastText.NetWrapper/FastTextWrapper.Api.cs @@ -73,10 +73,13 @@ private struct TrainingArgsStruct private static extern void LoadModel(IntPtr hPtr, string path); [DllImport(FastTextDll)] - private static extern int GetMaxLabelLenght(IntPtr hPtr); + private static extern int GetMaxLabelLength(IntPtr hPtr); [DllImport(FastTextDll)] private static extern int GetLabels(IntPtr hPtr, IntPtr labels); + + [DllImport(FastTextDll)] + private static extern int GetNN(IntPtr hPtr, byte[] input, IntPtr predictedLabels, float[] predictedProbabilities, int n); [DllImport(FastTextDll)] private static extern float PredictSingle(IntPtr hPtr, byte[] input, IntPtr predicted); diff --git a/FastText.NetWrapper/FastTextWrapper.cs b/FastText.NetWrapper/FastTextWrapper.cs index 499ba7d..8d0437b 100644 --- a/FastText.NetWrapper/FastTextWrapper.cs +++ b/FastText.NetWrapper/FastTextWrapper.cs @@ -47,7 +47,7 @@ public FastTextWrapper() public void LoadModel(string path) { LoadModel(_fastText, path); - _maxLabelLen = GetMaxLabelLenght(_fastText); + _maxLabelLen = GetMaxLabelLength(_fastText); _modelLoaded = true; } @@ -74,6 +74,34 @@ public unsafe string[] GetLabels() return result; } + /// + /// Calculate nearest neighbor from input text. + /// + /// Text to calculating the nearest neighbor from. + /// Number of neighbors. + /// Nearest neighbor predictions. + public unsafe Prediction[] GetNN(string text, int number) + { + CheckModelLoaded(); + + var probs = new float[number]; + IntPtr labelsPtr; + + int cnt = GetNN(_fastText, _utf8.GetBytes(text), new IntPtr(&labelsPtr), probs, number); + var result = new Prediction[cnt]; + + for (int i = 0; i < cnt; i++) + { + var ptr = Marshal.ReadIntPtr(labelsPtr, i * IntPtr.Size); + string label = _utf8.GetString(GetStringBytes(ptr)); + result[i] = new Prediction(probs[i], label); + } + + DestroyStrings(labelsPtr, cnt); + + return result; + } + /// /// Predicts a single label from input text. /// @@ -169,7 +197,7 @@ public void Train(string inputPath, string outputPath, SupervisedArgs args) }; TrainSupervised(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix); - _maxLabelLen = GetMaxLabelLenght(_fastText); + _maxLabelLen = GetMaxLabelLength(_fastText); _modelLoaded = true; } @@ -214,7 +242,7 @@ public void Train(string inputPath, string outputPath, FastTextArgs args) }; Train(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix, args.PretrainedVectors); - _maxLabelLen = GetMaxLabelLenght(_fastText); + _maxLabelLen = GetMaxLabelLength(_fastText); _modelLoaded = true; } diff --git a/TestUtil/Program.cs b/TestUtil/Program.cs index 4769a02..9a1dcea 100644 --- a/TestUtil/Program.cs +++ b/TestUtil/Program.cs @@ -10,11 +10,11 @@ namespace TestUtil { class Program { - private const string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file"; + private static string Usage = $"Usage: tesutil [train|trainlowlevel|load] train_file model_file{Environment.NewLine}Usage: tesutil nn model_file"; static void Main(string[] args) { - if (args.Length < 3) + if ((args.FirstOrDefault() == "nn" && args.Length < 2) || (args.FirstOrDefault() != "nn" && args.Length < 3)) { Console.WriteLine(Usage); return; @@ -34,8 +34,16 @@ static void Main(string[] args) fastText.LoadModel(args[2]); break; } - - Test(fastText); + + if (args[0] != "nn") + { + Test(fastText); + } + else + { + fastText.LoadModel(args[1]); + TestNN(fastText); + } } } @@ -67,5 +75,10 @@ private static void Test(FastTextWrapper fastText) var predictions = fastText.PredictMultiple("Can I use a larger crockpot than the recipe calls for?", 4); var vector = fastText.GetSentenceVector("Can I use a larger crockpot than the recipe calls for?"); } + + private static void TestNN(FastTextWrapper fastText) + { + fastText.GetNN("train", 5); + } } } \ No newline at end of file From ebe71ae64c283b3064823aec4786e36383c436bc Mon Sep 17 00:00:00 2001 From: Rojan Date: Thu, 5 Mar 2020 18:18:15 +0100 Subject: [PATCH 2/3] This makes it possible to load a model from memory as a byte array. --- FastText.NetWrapper/FastTextWrapper.Api.cs | 3 +++ FastText.NetWrapper/FastTextWrapper.cs | 11 +++++++++++ TestUtil/Program.cs | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/FastText.NetWrapper/FastTextWrapper.Api.cs b/FastText.NetWrapper/FastTextWrapper.Api.cs index 8feabc8..d671926 100644 --- a/FastText.NetWrapper/FastTextWrapper.Api.cs +++ b/FastText.NetWrapper/FastTextWrapper.Api.cs @@ -71,6 +71,9 @@ private struct TrainingArgsStruct [DllImport(FastTextDll)] private static extern void LoadModel(IntPtr hPtr, string path); + + [DllImport(FastTextDll)] + private static extern void LoadModelData(IntPtr hPtr, byte[] data, long length); [DllImport(FastTextDll)] private static extern int GetMaxLabelLength(IntPtr hPtr); diff --git a/FastText.NetWrapper/FastTextWrapper.cs b/FastText.NetWrapper/FastTextWrapper.cs index 8d0437b..78a5f88 100644 --- a/FastText.NetWrapper/FastTextWrapper.cs +++ b/FastText.NetWrapper/FastTextWrapper.cs @@ -40,6 +40,17 @@ public FastTextWrapper() _fastText = CreateFastText(); } + /// + /// Loads a trained model from a byte array. + /// + /// Bytes array containing the model (.bin file). + public void LoadModel(byte[] bytes) + { + LoadModelData(_fastText, bytes, bytes.Length); + _maxLabelLen = GetMaxLabelLength(_fastText); + _modelLoaded = true; + } + /// /// Loads a trained model from a file. /// diff --git a/TestUtil/Program.cs b/TestUtil/Program.cs index 9a1dcea..39d153c 100644 --- a/TestUtil/Program.cs +++ b/TestUtil/Program.cs @@ -41,7 +41,7 @@ static void Main(string[] args) } else { - fastText.LoadModel(args[1]); + fastText.LoadModel(File.ReadAllBytes(args[1])); TestNN(fastText); } } From 83b77372c0ca1b13020c26c2227fe70e0f16e7e1 Mon Sep 17 00:00:00 2001 From: Oleg Tarasov Date: Mon, 9 Mar 2020 18:47:49 +0300 Subject: [PATCH 3/3] Updated nuget packages, fixed some comments --- FastText.NetWrapper/FastText.NetWrapper.csproj | 10 +++++----- FastText.NetWrapper/FastTextWrapper.cs | 4 ++-- TestUtil/Program.cs | 5 +++-- TestUtil/TestUtil.csproj | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/FastText.NetWrapper/FastText.NetWrapper.csproj b/FastText.NetWrapper/FastText.NetWrapper.csproj index 9660c17..5e81a71 100644 --- a/FastText.NetWrapper/FastText.NetWrapper.csproj +++ b/FastText.NetWrapper/FastText.NetWrapper.csproj @@ -19,11 +19,11 @@ - - - - - + + + + + \ No newline at end of file diff --git a/FastText.NetWrapper/FastTextWrapper.cs b/FastText.NetWrapper/FastTextWrapper.cs index 78a5f88..af273f9 100644 --- a/FastText.NetWrapper/FastTextWrapper.cs +++ b/FastText.NetWrapper/FastTextWrapper.cs @@ -86,9 +86,9 @@ public unsafe string[] GetLabels() } /// - /// Calculate nearest neighbor from input text. + /// Calculate nearest neighbors from input text. /// - /// Text to calculating the nearest neighbor from. + /// Text to calculate nearest neighbors from. /// Number of neighbors. /// Nearest neighbor predictions. public unsafe Prediction[] GetNN(string text, int number) diff --git a/TestUtil/Program.cs b/TestUtil/Program.cs index 39d153c..f407de4 100644 --- a/TestUtil/Program.cs +++ b/TestUtil/Program.cs @@ -10,7 +10,8 @@ namespace TestUtil { class Program { - private static string Usage = $"Usage: tesutil [train|trainlowlevel|load] train_file model_file{Environment.NewLine}Usage: tesutil nn model_file"; + private static string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file\n" + + "Usage: testutil nn model_file"; static void Main(string[] args) { @@ -78,7 +79,7 @@ private static void Test(FastTextWrapper fastText) private static void TestNN(FastTextWrapper fastText) { - fastText.GetNN("train", 5); + var predictions = fastText.GetNN("train", 5); } } } \ No newline at end of file diff --git a/TestUtil/TestUtil.csproj b/TestUtil/TestUtil.csproj index 835e72d..4a42bf9 100644 --- a/TestUtil/TestUtil.csproj +++ b/TestUtil/TestUtil.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp2.2 + netcoreapp3.1