From a60be5f215216d1e1c8065f604d1ed9caa284375 Mon Sep 17 00:00:00 2001 From: Eric StJohn Date: Thu, 21 Dec 2023 10:08:50 -0800 Subject: [PATCH] Rename NameEntity to NamedEntity (#6917) --- .../NasBert/BertTaskType.cs | 6 ++- .../NasBert/NasBertTrainer.cs | 12 ++--- .../NasBert/NerTrainer.cs | 8 ++-- .../TorchSharpCatalog.cs | 47 +++++++++++++++++-- test/Microsoft.ML.Tests/NerTests.cs | 2 +- 5 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs b/src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs index d362be7e3f..dc2597e606 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.ComponentModel; using System.Text; namespace Microsoft.ML.TorchSharp.NasBert @@ -17,7 +18,10 @@ public enum BertTaskType MaskedLM = 1, TextClassification = 2, SentenceRegression = 3, - NameEntityRecognition = 4, + NamedEntityRecognition = 4, + [Obsolete("Please use NamedEntityRecognition instead", false)] + [EditorBrowsable(EditorBrowsableState.Never)] + NameEntityRecognition = NamedEntityRecognition, QuestionAnswering = 5 } } diff --git a/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs index 7c044a5c0c..535b228d85 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs @@ -204,7 +204,7 @@ private protected override Module CreateModule(IChannel ch, IDataView input) EnglishRoberta tokenizerModel = Tokenizer.RobertaModel(); NasBertModel model; - if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition) + if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); else model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); @@ -268,7 +268,7 @@ private protected override torch.Tensor PrepareRowTensor() private protected override void RunModelAndBackPropagate(ref List inputTensors, ref Tensor targetsTensor) { Tensor logits = default; - if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition) + if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) { int[,] lengthArray = new int[inputTensors.Count, 1]; for (int i = 0; i < inputTensors.Count; i++) @@ -293,7 +293,7 @@ private protected override void RunModelAndBackPropagate(ref List inputT torch.Tensor loss; if (Parent.BertOptions.TaskType == BertTaskType.TextClassification) loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor); - else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition) + else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) { targetsTensor = targetsTensor.@long().view(-1); logits = logits.view(-1, logits.size(-1)); @@ -338,7 +338,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))); } - else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition) + else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition) { var metadata = new List(); metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, @@ -387,7 +387,7 @@ private protected override void CheckInputSchema(SchemaShape inputSchema) TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString()); } } - else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition) + else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition) { if (labelCol.ItemType != NumberDataViewType.UInt32) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName, @@ -535,7 +535,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations()); return info; } - else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition) + else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) { var info = new DataViewSchema.DetachedColumn[1]; var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType; diff --git a/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs index 71cdcdef02..7f9c6901a9 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs @@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert /// /// /// . @@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert /// | Exportable to ONNX | No | /// /// ### Training Algorithm Details - /// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition. + /// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of named entity recognition. /// ]]> /// /// @@ -93,7 +93,7 @@ internal NerTrainer(IHostEnvironment env, BatchSize = batchSize, MaxEpoch = maxEpochs, ValidationSet = validationSet, - TaskType = BertTaskType.NameEntityRecognition + TaskType = BertTaskType.NamedEntityRecognition }) { } @@ -295,7 +295,7 @@ private static NerTransformer Create(IHostEnvironment env, ModelLoadContext ctx) options.Sentence1ColumnName = ctx.LoadString(); options.Sentence2ColumnName = ctx.LoadStringOrNull(); - options.TaskType = BertTaskType.NameEntityRecognition; + options.TaskType = BertTaskType.NamedEntityRecognition; BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments()); DataViewType type; diff --git a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs index a0d9970ea1..7451f7b82b 100644 --- a/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs +++ b/src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.ComponentModel; using System.Text; using Microsoft.ML.Data; using Microsoft.ML.TorchSharp.AutoFormerV2; @@ -161,7 +162,45 @@ public static ObjectDetectionMetrics EvaluateObjectDetection( } /// - /// Fine tune a NAS-BERT model for Name Entity Recognition. The limit for any sentence is 512 tokens. Each word typically + /// Obsolete: please use the method instead + /// + /// The transform's catalog. + /// Name of the label column. Column should be a key type. + /// Name of the output column. It will be a key type. It is the predicted label. + /// Name of the column for the first sentence. + /// Number of rows in the batch. + /// Maximum number of times to loop through your training set. + /// Architecture for the model. Defaults to Roberta. + /// The validation set used while training to improve model quality. + /// + [Obsolete("Please use NamedEntityRecognition method instead", false)] + [EditorBrowsable(EditorBrowsableState.Never)] + public static NerTrainer NameEntityRecognition( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + string labelColumnName = DefaultColumnNames.Label, + string outputColumnName = DefaultColumnNames.PredictedLabel, + string sentence1ColumnName = "Sentence", + int batchSize = 32, + int maxEpochs = 10, + BertArchitecture architecture = BertArchitecture.Roberta, + IDataView validationSet = null) + => NamedEntityRecognition(catalog, labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, architecture, validationSet); + + /// + /// Obsolete: please use the method instead + /// + /// The transform's catalog. + /// The full set of advanced options. + /// + [Obsolete("Please use NamedEntityRecognition method instead", false)] + [EditorBrowsable(EditorBrowsableState.Never)] + public static NerTrainer NameEntityRecognition( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + NerTrainer.NerOptions options) + => NamedEntityRecognition(catalog, options); + + /// + /// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically /// will map to a single token, and we automatically add 2 specical tokens (a start token and a separator token) /// so in general this limit will be 510 words for all sentences. /// @@ -174,7 +213,7 @@ public static ObjectDetectionMetrics EvaluateObjectDetection( /// Architecture for the model. Defaults to Roberta. /// The validation set used while training to improve model quality. /// - public static NerTrainer NameEntityRecognition( + public static NerTrainer NamedEntityRecognition( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string outputColumnName = DefaultColumnNames.PredictedLabel, @@ -186,12 +225,12 @@ public static NerTrainer NameEntityRecognition( => new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture); /// - /// Fine tune a Name Entity Recognition model. + /// Fine tune a Named Entity Recognition model. /// /// The transform's catalog. /// The full set of advanced options. /// - public static NerTrainer NameEntityRecognition( + public static NerTrainer NamedEntityRecognition( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, NerTrainer.NerOptions options) => new NerTrainer(CatalogUtils.GetEnvironment(catalog), options); diff --git a/test/Microsoft.ML.Tests/NerTests.cs b/test/Microsoft.ML.Tests/NerTests.cs index a8aee92f87..4e402a2132 100644 --- a/test/Microsoft.ML.Tests/NerTests.cs +++ b/test/Microsoft.ML.Tests/NerTests.cs @@ -54,7 +54,7 @@ public void TestSimpleNer() })); var chain = new EstimatorChain(); var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels)) - .Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn")) + .Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn")) .Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn")); var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));