Skip to content

Commit

Permalink
Rename NameEntity to NamedEntity (#6917)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericstj authored Dec 21, 2023
1 parent b8f71b9 commit a60be5f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 16 deletions.
6 changes: 5 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;

namespace Microsoft.ML.TorchSharp.NasBert
Expand All @@ -17,7 +18,10 @@ public enum BertTaskType
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3,
NameEntityRecognition = 4,
NamedEntityRecognition = 4,
[Obsolete("Please use NamedEntityRecognition instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
NameEntityRecognition = NamedEntityRecognition,
QuestionAnswering = 5
}
}
12 changes: 6 additions & 6 deletions src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ private protected override Module CreateModule(IChannel ch, IDataView input)
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();

NasBertModel model;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
else
model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
Expand Down Expand Up @@ -268,7 +268,7 @@ private protected override torch.Tensor PrepareRowTensor()
private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
{
Tensor logits = default;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
int[,] lengthArray = new int[inputTensors.Count, 1];
for (int i = 0; i < inputTensors.Count; i++)
Expand All @@ -293,7 +293,7 @@ private protected override void RunModelAndBackPropagate(ref List<Tensor> inputT
torch.Tensor loss;
if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
targetsTensor = targetsTensor.@long().view(-1);
logits = logits.view(-1, logits.size(-1));
Expand Down Expand Up @@ -338,7 +338,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var metadata = new List<SchemaShape.Column>();
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
Expand Down Expand Up @@ -387,7 +387,7 @@ private protected override void CheckInputSchema(SchemaShape inputSchema)
TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
}
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
if (labelCol.ItemType != NumberDataViewType.UInt32)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
Expand Down Expand Up @@ -535,7 +535,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
return info;
}
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var info = new DataViewSchema.DetachedColumn[1];
var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NameEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NamedEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
///
/// ### Input and Output Columns
/// The input label column data must be a Vector of [string](xref:Microsoft.ML.Data.TextDataViewType) type and the sentence columns must be of type<xref:Microsoft.ML.Data.TextDataViewType>.
Expand All @@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// | Exportable to ONNX | No |
///
/// ### Training Algorithm Details
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition.
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of named entity recognition.
/// ]]>
/// </format>
/// </remarks>
Expand Down Expand Up @@ -93,7 +93,7 @@ internal NerTrainer(IHostEnvironment env,
BatchSize = batchSize,
MaxEpoch = maxEpochs,
ValidationSet = validationSet,
TaskType = BertTaskType.NameEntityRecognition
TaskType = BertTaskType.NamedEntityRecognition
})
{
}
Expand Down Expand Up @@ -295,7 +295,7 @@ private static NerTransformer Create(IHostEnvironment env, ModelLoadContext ctx)

options.Sentence1ColumnName = ctx.LoadString();
options.Sentence2ColumnName = ctx.LoadStringOrNull();
options.TaskType = BertTaskType.NameEntityRecognition;
options.TaskType = BertTaskType.NamedEntityRecognition;

BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
DataViewType type;
Expand Down
47 changes: 43 additions & 4 deletions src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp.AutoFormerV2;
Expand Down Expand Up @@ -161,7 +162,45 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
}

/// <summary>
/// Fine tune a NAS-BERT model for Name Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int, int, BertArchitecture, IDataView)"/> method instead
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="labelColumnName">Name of the label column. Column should be a key type.</param>
/// <param name="outputColumnName">Name of the output column. It will be a key type. It is the predicted label.</param>
/// <param name="sentence1ColumnName">Name of the column for the first sentence.</param>
/// <param name="batchSize">Number of rows in the batch.</param>
/// <param name="maxEpochs">Maximum number of times to loop through your training set.</param>
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns></returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumnName = DefaultColumnNames.Label,
string outputColumnName = DefaultColumnNames.PredictedLabel,
string sentence1ColumnName = "Sentence",
int batchSize = 32,
int maxEpochs = 10,
BertArchitecture architecture = BertArchitecture.Roberta,
IDataView validationSet = null)
=> NamedEntityRecognition(catalog, labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, architecture, validationSet);

/// <summary>
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, NerTrainer.NerOptions)"/> method instead
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns></returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NerTrainer.NerOptions options)
=> NamedEntityRecognition(catalog, options);

/// <summary>
/// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
/// will map to a single token, and we automatically add 2 specical tokens (a start token and a separator token)
/// so in general this limit will be 510 words for all sentences.
/// </summary>
Expand All @@ -174,7 +213,7 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumnName = DefaultColumnNames.Label,
string outputColumnName = DefaultColumnNames.PredictedLabel,
Expand All @@ -186,12 +225,12 @@ public static NerTrainer NameEntityRecognition(
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture);

/// <summary>
/// Fine tune a Name Entity Recognition model.
/// Fine tune a Named Entity Recognition model.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NerTrainer.NerOptions options)
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), options);
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/NerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void TestSimpleNer()
}));
var chain = new EstimatorChain<ITransformer>();
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
.Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));

var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
Expand Down

0 comments on commit a60be5f

Please sign in to comment.