Skip to content

Commit

Permalink
Merge branch 'NoorDigitalAgency-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
olegtarasov committed Mar 9, 2020
2 parents 63d5656 + 83b7737 commit ab971d0
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 14 deletions.
10 changes: 5 additions & 5 deletions FastText.NetWrapper/FastText.NetWrapper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="FastText.Native.Linux" Version="1.0.73" />
<PackageReference Include="FastText.Native.MacOs" Version="1.0.74" />
<PackageReference Include="FastText.Native.Windows" Version="1.0.72" />
<PackageReference Include="LibLog" Version="5.0.6" />
<PackageReference Include="NativeLibraryManager" Version="1.0.14" />
<PackageReference Include="FastText.Native.Linux" Version="1.0.84" />
<PackageReference Include="FastText.Native.MacOs" Version="1.0.84" />
<PackageReference Include="FastText.Native.Windows" Version="1.0.84" />
<PackageReference Include="LibLog" Version="5.0.8" />
<PackageReference Include="NativeLibraryManager" Version="1.0.18" />
</ItemGroup>

</Project>
8 changes: 7 additions & 1 deletion FastText.NetWrapper/FastTextWrapper.Api.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,18 @@ private struct TrainingArgsStruct

[DllImport(FastTextDll)]
private static extern void LoadModel(IntPtr hPtr, string path);

[DllImport(FastTextDll)]
private static extern void LoadModelData(IntPtr hPtr, byte[] data, long length);

[DllImport(FastTextDll)]
private static extern int GetMaxLabelLenght(IntPtr hPtr);
private static extern int GetMaxLabelLength(IntPtr hPtr);

[DllImport(FastTextDll)]
private static extern int GetLabels(IntPtr hPtr, IntPtr labels);

[DllImport(FastTextDll)]
private static extern int GetNN(IntPtr hPtr, byte[] input, IntPtr predictedLabels, float[] predictedProbabilities, int n);

[DllImport(FastTextDll)]
private static extern float PredictSingle(IntPtr hPtr, byte[] input, IntPtr predicted);
Expand Down
45 changes: 42 additions & 3 deletions FastText.NetWrapper/FastTextWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,25 @@ public FastTextWrapper()
_fastText = CreateFastText();
}

/// <summary>
/// Loads a trained model from a byte array.
/// </summary>
/// <param name="bytes">Bytes array containing the model (.bin file).</param>
public void LoadModel(byte[] bytes)
{
LoadModelData(_fastText, bytes, bytes.Length);
_maxLabelLen = GetMaxLabelLength(_fastText);
_modelLoaded = true;
}

/// <summary>
/// Loads a trained model from a file.
/// </summary>
/// <param name="path">Path to a model (.bin file).</param>
public void LoadModel(string path)
{
LoadModel(_fastText, path);
_maxLabelLen = GetMaxLabelLenght(_fastText);
_maxLabelLen = GetMaxLabelLength(_fastText);
_modelLoaded = true;
}

Expand All @@ -74,6 +85,34 @@ public unsafe string[] GetLabels()
return result;
}

/// <summary>
/// Calculate nearest neighbors from input text.
/// </summary>
/// <param name="text">Text to calculate nearest neighbors from.</param>
/// <param name="number">Number of neighbors.</param>
/// <returns>Nearest neighbor predictions.</returns>
public unsafe Prediction[] GetNN(string text, int number)
{
CheckModelLoaded();

var probs = new float[number];
IntPtr labelsPtr;

int cnt = GetNN(_fastText, _utf8.GetBytes(text), new IntPtr(&labelsPtr), probs, number);
var result = new Prediction[cnt];

for (int i = 0; i < cnt; i++)
{
var ptr = Marshal.ReadIntPtr(labelsPtr, i * IntPtr.Size);
string label = _utf8.GetString(GetStringBytes(ptr));
result[i] = new Prediction(probs[i], label);
}

DestroyStrings(labelsPtr, cnt);

return result;
}

/// <summary>
/// Predicts a single label from input text.
/// </summary>
Expand Down Expand Up @@ -169,7 +208,7 @@ public void Train(string inputPath, string outputPath, SupervisedArgs args)
};

TrainSupervised(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix);
_maxLabelLen = GetMaxLabelLenght(_fastText);
_maxLabelLen = GetMaxLabelLength(_fastText);
_modelLoaded = true;
}

Expand Down Expand Up @@ -214,7 +253,7 @@ public void Train(string inputPath, string outputPath, FastTextArgs args)
};

Train(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix, args.PretrainedVectors);
_maxLabelLen = GetMaxLabelLenght(_fastText);
_maxLabelLen = GetMaxLabelLength(_fastText);
_modelLoaded = true;
}

Expand Down
22 changes: 18 additions & 4 deletions TestUtil/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ namespace TestUtil
{
class Program
{
private const string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file";
private static string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file\n" +
"Usage: testutil nn model_file";

static void Main(string[] args)
{
if (args.Length < 3)
if ((args.FirstOrDefault() == "nn" && args.Length < 2) || (args.FirstOrDefault() != "nn" && args.Length < 3))
{
Console.WriteLine(Usage);
return;
Expand All @@ -34,8 +35,16 @@ static void Main(string[] args)
fastText.LoadModel(args[2]);
break;
}

Test(fastText);

if (args[0] != "nn")
{
Test(fastText);
}
else
{
fastText.LoadModel(File.ReadAllBytes(args[1]));
TestNN(fastText);
}
}
}

Expand Down Expand Up @@ -67,5 +76,10 @@ private static void Test(FastTextWrapper fastText)
var predictions = fastText.PredictMultiple("Can I use a larger crockpot than the recipe calls for?", 4);
var vector = fastText.GetSentenceVector("Can I use a larger crockpot than the recipe calls for?");
}

private static void TestNN(FastTextWrapper fastText)
{
var predictions = fastText.GetNN("train", 5);
}
}
}
2 changes: 1 addition & 1 deletion TestUtil/TestUtil.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
<TargetFramework>netcoreapp3.1</TargetFramework>
</PropertyGroup>

<ItemGroup>
Expand Down

0 comments on commit ab971d0

Please sign in to comment.