From 9248f5fd580fae528c8bb83145adeb10c298483b Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Wed, 29 Jan 2025 13:11:21 -0800
Subject: [PATCH 1/8] Adding capability to use Cognitive Service Language
 Service asynchronously

The transformer calls the async service and polls for results. The polling
delay and the maximum number of retry attempts are controlled by parameters.
Request creation for each task is extracted into separate traits to make the
code more readable and maintainable. Changes to the AnalyzeText class are
minimal.
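
A minimal usage sketch (the subscription key, location, and column names are
illustrative; the setters follow those introduced in this patch and the async
base traits it mixes in):

    val analyzer = new AnalyzeTextLongRunningOperations()
      .setSubscriptionKey(key)
      .setLocation("eastus")
      .setKind("SentimentAnalysis")
      .setTextCol("text")
      .setOutputCol("response")
      .setPollingDelay(1000)      // milliseconds between polling requests
      .setMaxPollingRetries(100)  // maximum number of polling attempts
    val results = analyzer.transform(df)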
---
.../ml/services/language/AnalyzeText.scala | 100 +--
.../language/AnalyzeTextJobSchema.scala | 655 ++++++++++++++++++
.../AnalyzeTextJobServiceTraits.scala | 537 ++++++++++++++
.../AnalyzeTextLongRunningOperations.scala | 230 ++++++
.../language/AnalyzeTextLROSuite.scala | 630 +++++++++++++++++
5 files changed, 2110 insertions(+), 42 deletions(-)
create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala
create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala
index ecc7228406..d357765789 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala
@@ -3,33 +3,32 @@
package com.microsoft.azure.synapse.ml.services.language
-import com.microsoft.azure.synapse.ml.services._
-import com.microsoft.azure.synapse.ml.services.text.{TADocument, TextAnalyticsAutoBatch}
-import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
+import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging }
import com.microsoft.azure.synapse.ml.param.ServiceParam
-import com.microsoft.azure.synapse.ml.stages.{FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer}
-import org.apache.http.entity.{AbstractHttpEntity, StringEntity}
+import com.microsoft.azure.synapse.ml.services._
+import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch }
+import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer }
+import org.apache.http.entity.{ AbstractHttpEntity, StringEntity }
import org.apache.spark.injections.UDFUtils
-import org.apache.spark.ml.param.{Param, ParamValidators}
+import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel }
+import org.apache.spark.ml.param.{ Param, ParamValidators }
import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
-import spray.json.DefaultJsonProtocol._
+import org.apache.spark.sql.types.{ ArrayType, DataType, StructType }
import spray.json._
+import spray.json.DefaultJsonProtocol._
-trait AnalyzeTextTaskParameters extends HasServiceParams {
+trait HasAnalyzeTextServiceBaseParams extends HasServiceParams {
val modelVersion = new ServiceParam[String](
this, name = "modelVersion", "Version of the model")
+ def getModelVersion: String = getScalarParam(modelVersion)
def setModelVersion(v: String): this.type = setScalarParam(modelVersion, v)
-
+ def getModelVersionCol: String = getVectorParam(modelVersion)
def setModelVersionCol(v: String): this.type = setVectorParam(modelVersion, v)
- def getModelVersion: String = getScalarParam(modelVersion)
- def getModelVersionCol: String = getVectorParam(modelVersion)
val loggingOptOut = new ServiceParam[Boolean](
this, "loggingOptOut", "loggingOptOut for task"
@@ -44,13 +43,15 @@ trait AnalyzeTextTaskParameters extends HasServiceParams {
def getLoggingOptOutCol: String = getVectorParam(loggingOptOut)
val stringIndexType = new ServiceParam[String](this, "stringIndexType",
- "Specifies the method used to interpret string offsets. " +
- "Defaults to Text Elements (Graphemes) according to Unicode v8.0.0. " +
- "For additional information see https://aka.ms/text-analytics-offsets",
- isValid = {
- case Left(s) => Set("TextElements_v8", "UnicodeCodePoint", "Utf16CodeUnit")(s)
- case _ => true
- })
+ "Specifies the method used to interpret string offsets. " +
+      "Defaults to Text Elements (Graphemes) according to Unicode v8.0.0. " +
+      "For more information see https://aka.ms/text-analytics-offsets",
+ isValid = {
+ case Left(s) => Set("TextElements_v8",
+ "UnicodeCodePoint",
+ "Utf16CodeUnit")(s)
+ case _ => true
+ })
def setStringIndexType(v: String): this.type = setScalarParam(stringIndexType, v)
@@ -60,6 +61,36 @@ trait AnalyzeTextTaskParameters extends HasServiceParams {
def getStringIndexTypeCol: String = getVectorParam(stringIndexType)
+ val showStats = new ServiceParam[Boolean](
+ this, name = "showStats", "Whether to include detailed statistics in the response",
+ isURLParam = true)
+
+ def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v)
+
+ def getShowStats: Boolean = getScalarParam(showStats)
+
+  // We don't support setKindCol here because output schemas for different kinds are different
+ val kind = new Param[String](
+ this, "kind", "Enumeration of supported Text Analysis tasks",
+ isValid = validKinds.contains(_)
+ )
+
+ protected def validKinds: Set[String]
+
+ def setKind(v: String): this.type = set(kind, v)
+
+ def getKind: String = $(kind)
+
+ setDefault(
+ showStats -> Left(false),
+ modelVersion -> Left("latest"),
+ loggingOptOut -> Left(false),
+ stringIndexType -> Left("TextElements_v8")
+ )
+}
+
+
+trait AnalyzeTextTaskParameters extends HasAnalyzeTextServiceBaseParams {
val opinionMining = new ServiceParam[Boolean](
this, name = "opinionMining", "opinionMining option for SentimentAnalysisTask")
@@ -98,9 +129,6 @@ trait AnalyzeTextTaskParameters extends HasServiceParams {
def getPiiCategoriesCol: String = getVectorParam(piiCategories)
setDefault(
- modelVersion -> Left("latest"),
- loggingOptOut -> Left(false),
- stringIndexType -> Left("TextElements_v8"),
opinionMining -> Left(false),
domain -> Left("none")
)
@@ -131,33 +159,21 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid)
def this() = this(Identifiable.randomUID("AnalyzeText"))
- val showStats = new ServiceParam[Boolean](
- this, name = "showStats", "Whether to include detailed statistics in the response",
- isURLParam = true)
-
- def setShowStats(v: Boolean): this.type = setScalarParam(showStats, v)
-
- def getShowStats: Boolean = getScalarParam(showStats)
+ override protected def validKinds: Set[String] = Set("EntityLinking",
+ "EntityRecognition",
+ "KeyPhraseExtraction",
+ "LanguageDetection",
+ "PiiEntityRecognition",
+ "SentimentAnalysis")
setDefault(
- apiVersion -> Left("2022-05-01"),
- showStats -> Left(false)
+ apiVersion -> Left("2022-05-01")
)
override def urlPath: String = "/language/:analyze-text"
override private[ml] def internalServiceType: String = "textanalytics"
- // We don't support setKindCol here because output schemas for different kind are different
- val kind = new Param[String](
- this, "kind", "Enumeration of supported Text Analysis tasks",
- isValid = ParamValidators.inArray(Array("EntityLinking", "EntityRecognition", "KeyPhraseExtraction",
- "LanguageDetection", "PiiEntityRecognition", "SentimentAnalysis"))
- )
-
- def setKind(v: String): this.type = set(kind, v)
-
- def getKind: String = $(kind)
override protected def shouldSkip(row: Row): Boolean = if (emptyParamData(row, text)) {
true
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
new file mode 100644
index 0000000000..a84d2f3dcc
--- /dev/null
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
@@ -0,0 +1,655 @@
+package com.microsoft.azure.synapse.ml.services.language
+
+import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
+import spray.json.RootJsonFormat
+
+case class DocumentWarning(code: String,
+ message: String,
+ targetRef: Option[String])
+
+object DocumentWarning extends SparkBindings[DocumentWarning]
+
+
+case class SummaryContext(offset: Int,
+ length: Int)
+
+object SummaryContext extends SparkBindings[SummaryContext]
+
+//------------------------------------------------------------------------------------------------------
+// Extractive Summarization
+//------------------------------------------------------------------------------------------------------
+case class ExtractiveSummarizationTaskParameters(loggingOptOut: Boolean,
+ modelVersion: String,
+ sentenceCount: Option[Int],
+ sortBy: Option[String],
+ stringIndexType: String)
+
+case class ExtractiveSummarizationLROTask(parameters: ExtractiveSummarizationTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class ExtractiveSummarizationJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[ExtractiveSummarizationLROTask])
+
+case class ExtractedSummarySentence(text: String,
+ rankScore: Double,
+ offset: Int,
+ length: Int)
+
+object ExtractedSummarySentence extends SparkBindings[ExtractedSummarySentence]
+
+case class ExtractedSummaryDocumentResult(id: String,
+ warnings: Seq[DocumentWarning],
+ statistics: Option[RequestStatistics],
+ sentences: Seq[ExtractedSummarySentence])
+
+object ExtractedSummaryDocumentResult extends SparkBindings[ExtractedSummaryDocumentResult]
+
+case class ExtractiveSummarizationResult(errors: Seq[ATError],
+ statistics: Option[RequestStatistics],
+ modelVersion: String,
+ documents: Seq[ExtractedSummaryDocumentResult])
+
+object ExtractiveSummarizationResult extends SparkBindings[ExtractiveSummarizationResult]
+
+case class ExtractiveSummarizationLROResult(results: ExtractiveSummarizationResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object ExtractiveSummarizationLROResult extends SparkBindings[ExtractiveSummarizationLROResult]
+
+case class ExtractiveSummarizationTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[ExtractiveSummarizationLROResult]])
+
+object ExtractiveSummarizationTaskResult extends SparkBindings[ExtractiveSummarizationTaskResult]
+
+case class ExtractiveSummarizationJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: ExtractiveSummarizationTaskResult,
+ statistics: Option[RequestStatistics])
+
+object ExtractiveSummarizationJobState extends SparkBindings[ExtractiveSummarizationJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Abstractive Summarization
+//------------------------------------------------------------------------------------------------------
+object SummaryLength extends Enumeration {
+ type SummaryLength = Value
+ val Short, Medium, Long = Value
+}
+
+case class AbstractiveSummarizationTaskParameters(loggingOptOut: Boolean,
+ modelVersion: String,
+ sentenceCount: Option[Int],
+ stringIndexType: String,
+ summaryLength: Option[String])
+
+case class AbstractiveSummarizationLROTask(parameters: AbstractiveSummarizationTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class AbstractiveSummarizationJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[AbstractiveSummarizationLROTask])
+
+case class AbstractiveSummary(text: String,
+ contexts: Option[Seq[SummaryContext]])
+
+object AbstractiveSummary extends SparkBindings[AbstractiveSummary]
+
+case class AbstractiveSummaryDocumentResult(id: String,
+ warnings: Seq[DocumentWarning],
+ statistics: Option[RequestStatistics],
+ summaries: Seq[AbstractiveSummary])
+
+object AbstractiveSummaryDocumentResult extends SparkBindings[AbstractiveSummaryDocumentResult]
+
+case class AbstractiveSummarizationResult(errors: Seq[ATError],
+ statistics: Option[RequestStatistics],
+ modelVersion: String,
+ documents: Seq[AbstractiveSummaryDocumentResult])
+
+object AbstractiveSummarizationResult extends SparkBindings[AbstractiveSummarizationResult]
+
+case class AbstractiveSummarizationLROResult(results: AbstractiveSummarizationResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object AbstractiveSummarizationLROResult extends SparkBindings[AbstractiveSummarizationLROResult]
+
+
+case class AbstractiveSummarizationTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[AbstractiveSummarizationLROResult]])
+
+object AbstractiveSummarizationTaskResult extends SparkBindings[AbstractiveSummarizationTaskResult]
+
+case class AbstractiveSummarizationJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: AbstractiveSummarizationTaskResult,
+ statistics: Option[RequestStatistics])
+
+object AbstractiveSummarizationJobState extends SparkBindings[AbstractiveSummarizationJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// HealthCare
+//------------------------------------------------------------------------------------------------------
+case class HealthcareTaskParameters(loggingOptOut: Boolean,
+ modelVersion: String,
+ stringIndexType: String)
+
+case class HealthcareLROTask(parameters: HealthcareTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class HealthcareJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[HealthcareLROTask])
+
+case class HealthcareAssertion(conditionality: Option[String],
+ certainty: Option[String],
+ association: Option[String],
+ temporality: Option[String])
+
+object HealthcareAssertion extends SparkBindings[HealthcareAssertion]
+
+case class HealthcareEntitiesDocumentResult(id: String,
+ warnings: Seq[DocumentWarning],
+ statistics: Option[RequestStatistics],
+ entities: Seq[HealthcareEntity],
+ relations: Seq[HealthcareRelation],
+ fhirBundle: Option[String])
+
+object HealthcareEntitiesDocumentResult extends SparkBindings[HealthcareEntitiesDocumentResult]
+
+case class HealthcareEntity(text: String,
+ category: String,
+ subcategory: Option[String],
+ offset: Int,
+ length: Int,
+ confidenceScore: Double,
+ assertion: Option[HealthcareAssertion],
+ name: Option[String],
+ links: Option[Seq[HealthcareEntityLink]])
+
+object HealthcareEntity extends SparkBindings[HealthcareEntity]
+
+case class HealthcareEntityLink(dataSource: String,
+ id: String)
+
+object HealthcareEntityLink extends SparkBindings[HealthcareEntityLink]
+
+case class HealthcareLROResult(results: HealthcareResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object HealthcareLROResult extends SparkBindings[HealthcareLROResult]
+
+
+case class HealthcareRelation(relationType: String,
+ entities: Seq[HealthcareRelationEntity],
+ confidenceScore: Option[Double])
+
+object HealthcareRelation extends SparkBindings[HealthcareRelation]
+
+case class HealthcareRelationEntity(ref: String,
+ role: String)
+
+object HealthcareRelationEntity extends SparkBindings[HealthcareRelationEntity]
+
+case class HealthcareResult(errors: Seq[DocumentError],
+ statistics: Option[RequestStatistics],
+ modelVersion: String,
+ documents: Seq[HealthcareEntitiesDocumentResult])
+
+object HealthcareResult extends SparkBindings[HealthcareResult]
+
+case class HealthcareTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[HealthcareLROResult]])
+
+object HealthcareTaskResult extends SparkBindings[HealthcareTaskResult]
+
+case class HealthcareJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: HealthcareTaskResult,
+ statistics: Option[RequestStatistics])
+
+object HealthcareJobState extends SparkBindings[HealthcareJobState]
+
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Sentiment Analysis
+//------------------------------------------------------------------------------------------------------
+case class SentimentAnalysisLROTask(parameters: SentimentAnalysisTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class SentimentAnalysisJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[SentimentAnalysisLROTask])
+
+case class SentimentAnalysisLROResult(results: SentimentResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object SentimentAnalysisLROResult extends SparkBindings[SentimentAnalysisLROResult]
+
+case class SentimentAnalysisTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[SentimentAnalysisLROResult]])
+
+object SentimentAnalysisTaskResult extends SparkBindings[SentimentAnalysisTaskResult]
+
+case class SentimentAnalysisJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: SentimentAnalysisTaskResult,
+ statistics: Option[RequestStatistics])
+
+object SentimentAnalysisJobState extends SparkBindings[SentimentAnalysisJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Key Phrase Extraction
+//------------------------------------------------------------------------------------------------------
+case class KeyPhraseExtractionLROTask(parameters: KPnLDTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class KeyPhraseExtractionJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[KeyPhraseExtractionLROTask])
+
+case class KeyPhraseExtractionLROResult(results: KeyPhraseExtractionResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object KeyPhraseExtractionLROResult extends SparkBindings[KeyPhraseExtractionLROResult]
+
+case class KeyPhraseExtractionTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[KeyPhraseExtractionLROResult]])
+
+object KeyPhraseExtractionTaskResult extends SparkBindings[KeyPhraseExtractionTaskResult]
+
+case class KeyPhraseExtractionJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: KeyPhraseExtractionTaskResult,
+ statistics: Option[RequestStatistics])
+
+object KeyPhraseExtractionJobState extends SparkBindings[KeyPhraseExtractionJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// PII Entity Recognition
+//------------------------------------------------------------------------------------------------------
+object PiiDomain extends Enumeration {
+ type PiiDomain = Value
+ val None, Phi = Value
+}
+
+case class PiiEntityRecognitionLROTask(parameters: PiiTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class PiiEntityRecognitionJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[PiiEntityRecognitionLROTask])
+
+
+case class PiiEntityRecognitionLROResult(results: PIIResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object PiiEntityRecognitionLROResult extends SparkBindings[PiiEntityRecognitionLROResult]
+
+case class PiiEntityRecognitionTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[PiiEntityRecognitionLROResult]])
+
+object PiiEntityRecognitionTaskResult extends SparkBindings[PiiEntityRecognitionTaskResult]
+
+case class PiiEntityRecognitionJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: PiiEntityRecognitionTaskResult,
+ statistics: Option[RequestStatistics])
+
+object PiiEntityRecognitionJobState extends SparkBindings[PiiEntityRecognitionJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Entity Linking
+//------------------------------------------------------------------------------------------------------
+case class EntityLinkingLROTask(parameters: EntityTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+
+case class EntityLinkingJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[EntityLinkingLROTask])
+
+
+case class EntityLinkingLROResult(results: EntityLinkingResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object EntityLinkingLROResult extends SparkBindings[EntityLinkingLROResult]
+
+case class EntityLinkingTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[EntityLinkingLROResult]])
+
+object EntityLinkingTaskResult extends SparkBindings[EntityLinkingTaskResult]
+
+case class EntityLinkingJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: EntityLinkingTaskResult,
+ statistics: Option[RequestStatistics])
+
+object EntityLinkingJobState extends SparkBindings[EntityLinkingJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Entity Recognition
+//------------------------------------------------------------------------------------------------------
+
+case class EntityRecognitionTaskParameters(loggingOptOut: Boolean,
+ modelVersion: String,
+ stringIndexType: String,
+ inclusionList: Option[Seq[String]],
+ exclusionList: Option[Seq[String]],
+ overlapPolicy: Option[EntityOverlapPolicy],
+ inferenceOptions: Option[EntityInferenceOptions])
+
+case class EntityOverlapPolicy(policyKind: String)
+
+case class EntityInferenceOptions(excludeNormalizedValues: Boolean)
+
+case class EntityRecognitionLROTask(parameters: EntityRecognitionTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class EntityRecognitionJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[EntityRecognitionLROTask])
+
+case class EntityRecognitionLROResult(results: EntityRecognitionResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object EntityRecognitionLROResult extends SparkBindings[EntityRecognitionLROResult]
+
+case class EntityRecognitionTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[EntityRecognitionLROResult]])
+
+object EntityRecognitionTaskResult extends SparkBindings[EntityRecognitionTaskResult]
+
+case class EntityRecognitionJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: EntityRecognitionTaskResult,
+ statistics: Option[RequestStatistics])
+
+object EntityRecognitionJobState extends SparkBindings[EntityRecognitionJobState]
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Custom Entity Recognition
+//------------------------------------------------------------------------------------------------------
+case class CustomEntitiesTaskParameters(loggingOptOut: Boolean,
+ stringIndexType: String,
+ deploymentName: String,
+ projectName: String)
+
+case class CustomEntityRecognitionLROTask(parameters: CustomEntitiesTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class CustomEntitiesJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[CustomEntityRecognitionLROTask])
+//------------------------------------------------------------------------------------------------------
+
+//------------------------------------------------------------------------------------------------------
+// Custom Label Classification
+//------------------------------------------------------------------------------------------------------
+case class CustomLabelTaskParameters(loggingOptOut: Boolean,
+ deploymentName: String,
+ projectName: String)
+
+case class CustomLabelLROTask(parameters: CustomLabelTaskParameters,
+ taskName: Option[String],
+ kind: String)
+
+case class CustomLabelJobsInput(displayName: Option[String],
+ analysisInput: MultiLanguageAnalysisInput,
+ tasks: Seq[CustomLabelLROTask])
+
+case class ClassificationDocumentResult(id: String,
+ warnings: Seq[DocumentWarning],
+ statistics: Option[RequestStatistics],
+ classes: Seq[ClassificationResult])
+
+object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult]
+
+case class ClassificationResult(category: String,
+ confidenceScore: Double)
+
+object ClassificationResult extends SparkBindings[ClassificationResult]
+
+case class CustomLabelResult(errors: Seq[DocumentError],
+ statistics: Option[RequestStatistics],
+ modelVersion: String,
+ documents: Seq[ClassificationDocumentResult])
+
+object CustomLabelResult extends SparkBindings[CustomLabelResult]
+
+case class CustomLabelLROResult(results: CustomLabelResult,
+ lastUpdateDateTime: String,
+ status: String,
+ taskName: Option[String],
+ kind: String)
+
+object CustomLabelLROResult extends SparkBindings[CustomLabelLROResult]
+
+case class CustomLabelTaskResult(completed: Int,
+ failed: Int,
+ inProgress: Int,
+ total: Int,
+ items: Option[Seq[CustomLabelLROResult]])
+
+object CustomLabelTaskResult extends SparkBindings[CustomLabelTaskResult]
+
+case class CustomLabelJobState(displayName: Option[String],
+ createdDateTime: String,
+ expirationDateTime: Option[String],
+ jobId: String,
+ lastUpdatedDateTime: String,
+ status: String,
+ errors: Option[Seq[String]],
+ nextLink: Option[String],
+ tasks: CustomLabelTaskResult,
+ statistics: Option[RequestStatistics])
+
+object CustomLabelJobState extends SparkBindings[CustomLabelJobState]
+//------------------------------------------------------------------------------------------------------
+
+
+object ATLROJSONFormat {
+
+ import spray.json.DefaultJsonProtocol._
+ import ATJSONFormat._
+
+ implicit val DocumentWarningFormat: RootJsonFormat[DocumentWarning] =
+ jsonFormat3(DocumentWarning.apply)
+
+ implicit val ExtractiveSummarizationTaskParametersF: RootJsonFormat[ExtractiveSummarizationTaskParameters] =
+ jsonFormat5(ExtractiveSummarizationTaskParameters.apply)
+
+ implicit val ExtractiveSummarizationLROTaskF: RootJsonFormat[ExtractiveSummarizationLROTask] =
+ jsonFormat3(ExtractiveSummarizationLROTask.apply)
+
+ implicit val ExtractiveSummarizationJobsInputF: RootJsonFormat[ExtractiveSummarizationJobsInput] =
+ jsonFormat3(ExtractiveSummarizationJobsInput.apply)
+
+ implicit val AbstractiveSummarizationTaskParametersF: RootJsonFormat[AbstractiveSummarizationTaskParameters] =
+ jsonFormat5(AbstractiveSummarizationTaskParameters.apply)
+
+ implicit val AbstractiveSummarizationLROTaskF: RootJsonFormat[AbstractiveSummarizationLROTask] =
+ jsonFormat3(AbstractiveSummarizationLROTask.apply)
+
+ implicit val AbstractiveSummarizationJobsInputF: RootJsonFormat[AbstractiveSummarizationJobsInput] =
+ jsonFormat3(AbstractiveSummarizationJobsInput.apply)
+
+ implicit val HealthcareTaskParametersF: RootJsonFormat[HealthcareTaskParameters] =
+ jsonFormat3(HealthcareTaskParameters.apply)
+
+ implicit val HealthcareLROTaskF: RootJsonFormat[HealthcareLROTask] =
+ jsonFormat3(HealthcareLROTask.apply)
+
+ implicit val HealthcareJobsInputF: RootJsonFormat[HealthcareJobsInput] =
+ jsonFormat3(HealthcareJobsInput.apply)
+
+ implicit val SentimentAnalysisLROTaskF: RootJsonFormat[SentimentAnalysisLROTask] =
+ jsonFormat3(SentimentAnalysisLROTask.apply)
+
+ implicit val SentimentAnalysisJobsInputF: RootJsonFormat[SentimentAnalysisJobsInput] =
+ jsonFormat3(SentimentAnalysisJobsInput.apply)
+
+ implicit val KeyPhraseExtractionLROTaskF: RootJsonFormat[KeyPhraseExtractionLROTask] =
+ jsonFormat3(KeyPhraseExtractionLROTask.apply)
+
+ implicit val KeyPhraseExtractionJobsInputF: RootJsonFormat[KeyPhraseExtractionJobsInput] =
+ jsonFormat3(KeyPhraseExtractionJobsInput.apply)
+
+ implicit val PiiEntityRecognitionLROTaskF: RootJsonFormat[PiiEntityRecognitionLROTask] =
+ jsonFormat3(PiiEntityRecognitionLROTask.apply)
+
+ implicit val PiiEntityRecognitionJobsInputF: RootJsonFormat[PiiEntityRecognitionJobsInput] =
+ jsonFormat3(PiiEntityRecognitionJobsInput.apply)
+
+ implicit val EntityLinkingLROTaskF: RootJsonFormat[EntityLinkingLROTask] =
+ jsonFormat3(EntityLinkingLROTask.apply)
+
+ implicit val EntityLinkingJobsInputF: RootJsonFormat[EntityLinkingJobsInput] =
+ jsonFormat3(EntityLinkingJobsInput.apply)
+
+ implicit val EntityOverlapPolicyF: RootJsonFormat[EntityOverlapPolicy] =
+ jsonFormat1(EntityOverlapPolicy.apply)
+
+ implicit val EntityInferenceOptionsF: RootJsonFormat[EntityInferenceOptions] =
+ jsonFormat1(EntityInferenceOptions.apply)
+
+ implicit val EntityRecognitionTaskParametersF: RootJsonFormat[EntityRecognitionTaskParameters] =
+ jsonFormat7(EntityRecognitionTaskParameters.apply)
+
+ implicit val EntityRecognitionLROTaskF: RootJsonFormat[EntityRecognitionLROTask] =
+ jsonFormat3(EntityRecognitionLROTask.apply)
+
+ implicit val EntityRecognitionJobsInputF: RootJsonFormat[EntityRecognitionJobsInput] =
+ jsonFormat3(EntityRecognitionJobsInput.apply)
+
+ implicit val CustomEntitiesTaskParametersF: RootJsonFormat[CustomEntitiesTaskParameters] =
+ jsonFormat4(CustomEntitiesTaskParameters.apply)
+
+ implicit val CustomEntityRecognitionLROTaskF: RootJsonFormat[CustomEntityRecognitionLROTask] =
+ jsonFormat3(CustomEntityRecognitionLROTask.apply)
+
+ implicit val CustomEntitiesJobsInputF: RootJsonFormat[CustomEntitiesJobsInput] =
+ jsonFormat3(CustomEntitiesJobsInput.apply)
+
+ implicit val CustomSingleLabelTaskParametersF: RootJsonFormat[CustomLabelTaskParameters] =
+ jsonFormat3(CustomLabelTaskParameters.apply)
+
+ implicit val CustomSingleLabelLROTaskF: RootJsonFormat[CustomLabelLROTask] =
+ jsonFormat3(CustomLabelLROTask.apply)
+
+ implicit val CustomSingleLabelJobsInputF: RootJsonFormat[CustomLabelJobsInput] =
+ jsonFormat3(CustomLabelJobsInput.apply)
+}
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala
new file mode 100644
index 0000000000..ce06e964ce
--- /dev/null
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala
@@ -0,0 +1,537 @@
+package com.microsoft.azure.synapse.ml.services.language
+
+import com.microsoft.azure.synapse.ml.param.ServiceParam
+import com.microsoft.azure.synapse.ml.services.HasServiceParams
+import com.microsoft.azure.synapse.ml.services.language.ATLROJSONFormat._
+import com.microsoft.azure.synapse.ml.services.language.PiiDomain.PiiDomain
+import com.microsoft.azure.synapse.ml.services.language.SummaryLength.SummaryLength
+import org.apache.spark.ml.param.ParamValidators
+import org.apache.spark.sql.Row
+import spray.json.DefaultJsonProtocol._
+import spray.json.enrichAny
+
+object AnalysisTaskKind extends Enumeration {
+ type AnalysisTaskKind = Value
+ val SentimentAnalysis,
+ EntityRecognition,
+ PiiEntityRecognition,
+ KeyPhraseExtraction,
+ EntityLinking,
+ Healthcare,
+ CustomEntityRecognition,
+ CustomSingleLabelClassification,
+ CustomMultiLabelClassification,
+ ExtractiveSummarization,
+ AbstractiveSummarization = Value
+
+ def getKindFromString(kind: String): AnalysisTaskKind = {
+ AnalysisTaskKind.values.find(_.toString == kind).getOrElse(
+ throw new IllegalArgumentException(s"Invalid kind: $kind")
+ )
+ }
+}
+
+trait HasSummarizationBaseParameter extends HasServiceParams {
+ val sentenceCount = new ServiceParam[Int](
+ this,
+ name = "sentenceCount",
+ doc = "Specifies the number of sentences in the extracted summary.",
+ isValid = { case Left(value) => value >= 1 case Right(_) => true }
+ )
+
+ def getSentenceCount: Int = getScalarParam(sentenceCount)
+
+ def setSentenceCount(value: Int): this.type = setScalarParam(sentenceCount, value)
+
+ def getSentenceCountCol: String = getVectorParam(sentenceCount)
+
+ def setSentenceCountCol(value: String): this.type = setVectorParam(sentenceCount, value)
+}
+
+/**
+ * This trait is used to handle the extractive summarization request. It provides the necessary
+ * parameters to create the request and the method to create the request. There are two
+ * parameters for extractive summarization: sentenceCount and sortBy. Both of them are optional.
+ * If the user does not provide any value for sentenceCount, the service will return the default
+ * number of sentences in the summary. If the user does not provide any value for sortBy, the
+ * service will return the summary in the order of the sentences in the input text. The possible values
+ * for sortBy are "Rank" and "Offset". If the user provides an invalid value for sortBy, the service
+ * will return an error. SentenceCount is an integer value, and it should be greater than 0. This parameter
+ * specifies the number of sentences in the extracted summary. If the user provides an invalid value for
+ * sentenceCount, the service will return an error. For more details about the parameters, please refer to
+ * the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]]
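+ *
+ * A hypothetical configuration sketch (the transformer mixing in this trait is the one
+ * added in this patch; the values shown are illustrative):
+ * {{{
+ *   val summarizer = new AnalyzeTextLongRunningOperations()
+ *     .setKind("ExtractiveSummarization")
+ *     .setSentenceCount(3)  // request a three-sentence summary
+ *     .setSortBy("Rank")    // order sentences by rank score
+ * }}}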
+ */
+trait HandleExtractiveSummarization extends HasServiceParams
+ with HasSummarizationBaseParameter {
+ val sortBy = new ServiceParam[String](
+ this,
+ name = "sortBy",
+ doc = "Specifies how to sort the extracted summaries. This can be either 'Rank' or 'Offset'.",
+ isValid = {
+ case Left(value) => ParamValidators.inArray(Array("Rank", "Offset"))(value)
+ case Right(_) => true
+ })
+
+ def getSortBy: String = getScalarParam(sortBy)
+
+ def setSortBy(value: String): this.type = setScalarParam(sortBy, value)
+
+ def getSortByCol: String = getVectorParam(sortBy)
+
+ def setSortByCol(value: String): this.type = setVectorParam(sortBy, value)
+
+ def createExtractiveSummarizationRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = ExtractiveSummarizationLROTask(
+ parameters = ExtractiveSummarizationTaskParameters(
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ sentenceCount = getValueOpt(row, sentenceCount),
+ sortBy = getValueOpt(row, sortBy),
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.ExtractiveSummarization.toString
+ )
+
+ ExtractiveSummarizationJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the abstractive summarization request. It provides the necessary
+ * parameters to create the request and the method to create the request. There are two
+ * parameters for abstractive summarization: sentenceCount and summaryLength. Both of them are optional.
+ * It is recommended to use summaryLength over sentenceCount; the service may ignore the sentenceCount parameter.
+ * SummaryLength is a string value, and it should be one of "short", "medium", or "long". This parameter
+ * controls the approximate length of the output summaries. If the user provides an invalid value for
+ * summaryLength, the service will return an error. For more details about the parameters, please refer to
+ * the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]]
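+ *
+ * A hypothetical configuration sketch (values are illustrative):
+ * {{{
+ *   val summarizer = new AnalyzeTextLongRunningOperations()
+ *     .setKind("AbstractiveSummarization")
+ *     .setSummaryLength(SummaryLength.Medium)  // preferred over sentenceCount
+ * }}}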
+ */
+trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizationBaseParameter {
+ val summaryLength = new ServiceParam[String](
+ this,
+ name = "summaryLength",
+ doc = "(NOTE: Recommended to use summaryLength over sentenceCount) Controls the"
+ + " approximate length of the output summaries.",
+ isValid = {
+ case Left(value) => ParamValidators.inArray(Array("short", "medium", "long"))(value)
+ case Right(_) => true
+ }
+
+ )
+
+ def getSummaryLength: String = getScalarParam(summaryLength)
+
+ def setSummaryLength(value: String): this.type = setScalarParam(summaryLength, value)
+
+ def setSummaryLength(value: SummaryLength): this.type = setScalarParam(summaryLength, value.toString.toLowerCase)
+
+ def getSummaryLengthCol: String = getVectorParam(summaryLength)
+
+ def setSummaryLengthCol(value: String): this.type = setVectorParam(summaryLength, value)
+
+ def createAbstractiveSummarizationRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+    val taskParameter = AbstractiveSummarizationLROTask(
+ parameters = AbstractiveSummarizationTaskParameters(
+ sentenceCount = getValueOpt(row, sentenceCount),
+ summaryLength = getValueOpt(row, summaryLength),
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ stringIndexType = stringIndexType),
+ taskName = None,
+ kind = AnalysisTaskKind.AbstractiveSummarization.toString
+ )
+ AbstractiveSummarizationJobsInput(displayName = None,
+ analysisInput = analysisInput,
+                                      tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the healthcare text analytics request. It provides the necessary parameters
+ * to create the request and the method to create the request. There are three parameters for healthcare text
+ * analytics: modelVersion, stringIndexType, and loggingOptOut. All of them are optional. For more details about
+ * the parameters, please refer to the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/overview]]
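+ *
+ * A hypothetical configuration sketch (values are illustrative):
+ * {{{
+ *   val healthcare = new AnalyzeTextLongRunningOperations()
+ *     .setKind("Healthcare")
+ *     .setModelVersion("latest")
+ *     .setStringIndexType("TextElements_v8")
+ * }}}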
+ */
+trait HandleHealthcareTextAnalytics extends HasServiceParams {
+ def createHealthcareTextAnalyticsRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = HealthcareLROTask(
+ parameters = HealthcareTaskParameters(
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.Healthcare.toString
+ )
+ HealthcareJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the sentiment analysis request. It provides the necessary parameters to create
+ * the request and the method to create the request. There are four parameters for sentiment analysis:
+ * modelVersion, stringIndexType, opinionMining, and loggingOptOut. All of them are optional. For more details
+ * about the parameters, please refer
+ * to the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/sentiment-opinion-mining/overview]]
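+ *
+ * A hypothetical configuration sketch (values are illustrative):
+ * {{{
+ *   val sentiment = new AnalyzeTextLongRunningOperations()
+ *     .setKind("SentimentAnalysis")
+ *     .setOpinionMining(true)  // also return aspect-level opinions
+ * }}}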
+ */
+trait HandleSentimentAnalysis extends HasServiceParams {
+ val opinionMining = new ServiceParam[Boolean](
+ this,
+ name = "opinionMining",
+ doc = "Whether to use opinion mining in the request or not."
+ )
+
+ def getOpinionMining: Boolean = getScalarParam(opinionMining)
+
+ def setOpinionMining(value: Boolean): this.type = setScalarParam(opinionMining, value)
+
+ def getOpinionMiningCol: String = getVectorParam(opinionMining)
+
+ def setOpinionMiningCol(value: String): this.type = setVectorParam(opinionMining, value)
+
+ setDefault(
+ opinionMining -> Left(false)
+ )
+
+ def createSentimentAnalysisRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = SentimentAnalysisLROTask(
+ parameters = SentimentAnalysisTaskParameters(
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ opinionMining = getValue(row, opinionMining),
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.SentimentAnalysis.toString
+ )
+ SentimentAnalysisJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the key phrase extraction request. It provides the necessary parameters to create
+ * the request and the method to create the request. There are two parameters for key phrase extraction: modelVersion
+ * and loggingOptOut. Both of them are optional. For more details about the parameters,
+ * please refer to the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/key-phrase-extraction/overview]]
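+ *
+ * A hypothetical configuration sketch (values are illustrative):
+ * {{{
+ *   val keyPhrases = new AnalyzeTextLongRunningOperations()
+ *     .setKind("KeyPhraseExtraction")
+ *     .setModelVersion("latest")
+ * }}}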
+ */
+trait HandleKeyPhraseExtraction extends HasServiceParams {
+ def createKeyPhraseExtractionRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ // This parameter is not used and only exists for compatibility
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = KeyPhraseExtractionLROTask(
+ parameters = KPnLDTaskParameters(
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.KeyPhraseExtraction.toString
+ )
+ KeyPhraseExtractionJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+trait HandleEntityLinking extends HasServiceParams {
+ def createEntityLinkingRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = EntityLinkingLROTask(
+ parameters = EntityTaskParameters(
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.EntityLinking.toString
+ )
+ EntityLinkingJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the PII entity recognition request. It provides the necessary parameters to create
+ * the request and the method to create the request. There are three parameters for PII entity recognition: domain,
+ * piiCategories, and loggingOptOut. All of them are optional. For more details about the parameters, please refer to
+ * the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/personally-identifiable-information/overview]]
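+ *
+ * A hypothetical configuration sketch (the category name is illustrative):
+ * {{{
+ *   val pii = new AnalyzeTextLongRunningOperations()
+ *     .setKind("PiiEntityRecognition")
+ *     .setDomain(PiiDomain.Phi)
+ *     .setPiiCategories(Seq("Person"))
+ * }}}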
+ */
+trait HandlePiiEntityRecognition extends HasServiceParams {
+ val domain = new ServiceParam[String](
+ this,
+ name = "domain",
+ doc = "The domain of the PII entity recognition request.",
+ isValid = {
+ case Left(value) => PiiDomain.values.map(_.toString.toLowerCase).contains(value)
+ case Right(_) => true
+ }
+ )
+
+ def getDomain: String = getScalarParam(domain)
+
+ def setDomain(value: String): this.type = setScalarParam(domain, value)
+
+ def setDomain(value: PiiDomain): this.type = setScalarParam(domain, value.toString.toLowerCase)
+
+ def getDomainCol: String = getVectorParam(domain)
+
+ def setDomainCol(value: String): this.type = setVectorParam(domain, value)
+
+ val piiCategories = new ServiceParam[Seq[String]](this, "piiCategories",
+ "describes the PII categories to return")
+
+ def setPiiCategories(v: Seq[String]): this.type = setScalarParam(piiCategories, v)
+
+ def getPiiCategories: Seq[String] = getScalarParam(piiCategories)
+
+ def setPiiCategoriesCol(v: String): this.type = setVectorParam(piiCategories, v)
+
+ def getPiiCategoriesCol: String = getVectorParam(piiCategories)
+
+ setDefault(
+ domain -> Left("none"),
+ )
+
+ def createPiiEntityRecognitionRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = PiiEntityRecognitionLROTask(
+ parameters = PiiTaskParameters(
+ domain = getValue(row, domain),
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ piiCategories = getValueOpt(row, piiCategories),
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.PiiEntityRecognition.toString
+ )
+ PiiEntityRecognitionJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+/**
+ * This trait is used to handle the entity recognition request. It provides the necessary parameters to create
+ * the request and the method to create the request. There are five parameters for entity recognition: inclusionList,
+ * exclusionList, overlapPolicy, excludeNormalizedValues, and loggingOptOut. All of them are optional. For more details
+ * about the parameters, please refer to the documentation.
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/named-entity-recognition/overview]]
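+ *
+ * A hypothetical configuration sketch (entity types are illustrative):
+ * {{{
+ *   val ner = new AnalyzeTextLongRunningOperations()
+ *     .setKind("EntityRecognition")
+ *     .setInclusionList(Seq("Person", "Location"))
+ *     .setOverlapPolicy("matchLongest")
+ * }}}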
+ */
+trait HandleEntityRecognition extends HasServiceParams {
+ val inclusionList = new ServiceParam[Seq[String]](
+ this,
+ name = "inclusionList",
+ doc = "(Optional) request parameter that limits the output to the requested entity"
+ + " types included in this list. We will apply inclusionList before"
+ + " exclusionList"
+ )
+
+ def getInclusionList: Seq[String] = getScalarParam(inclusionList)
+
+ def setInclusionList(value: Seq[String]): this.type = setScalarParam(inclusionList, value)
+
+ def getInclusionListCol: String = getVectorParam(inclusionList)
+
+ def setInclusionListCol(value: String): this.type = setVectorParam(inclusionList, value)
+
+ val exclusionList = new ServiceParam[Seq[String]](
+ this,
+ name = "exclusionList",
+ doc = "(Optional) request parameter that filters out any entities that are"
+      + " included in the excludeList. When a user specifies an excludeList, they cannot"
+ + " get a prediction returned with an entity in that list. We will apply"
+ + " inclusionList before exclusionList"
+ )
+
+ def getExclusionList: Seq[String] = getScalarParam(exclusionList)
+
+ def setExclusionList(value: Seq[String]): this.type = setScalarParam(exclusionList, value)
+
+ def getExclusionListCol: String = getVectorParam(exclusionList)
+
+ def setExclusionListCol(value: String): this.type = setVectorParam(exclusionList, value)
+
+ val overlapPolicy = new ServiceParam[String](
+ this,
+ name = "overlapPolicy",
+ doc = "(Optional) describes the type of overlap policy to apply to the ner output.",
+ isValid = {
+ case Left(value) => value == "matchLongest" || value == "allowOverlap"
+ case Right(_) => true
+ }
+ )
+
+ def getOverlapPolicy: String = getScalarParam(overlapPolicy)
+
+ def setOverlapPolicy(value: String): this.type = setScalarParam(overlapPolicy, value)
+
+ def getOverlapPolicyCol: String = getVectorParam(overlapPolicy)
+
+ def setOverlapPolicyCol(value: String): this.type = setVectorParam(overlapPolicy, value)
+
+ val excludeNormalizedValues = new ServiceParam[Boolean](
+ this,
+ name = "inferenceOptions",
+ doc = "(Optional) request parameter that allows the user to provide settings for"
+      + " running the inference. If set to true, the service will exclude normalized"
+      + " values from the output."
+ )
+
+ def getInferenceOptions: Boolean = getScalarParam(excludeNormalizedValues)
+
+ def setInferenceOptions(value: Boolean): this.type = setScalarParam(excludeNormalizedValues, value)
+
+ def getInferenceOptionsCol: String = getVectorParam(excludeNormalizedValues)
+
+ def setInferenceOptionsCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value)
+
+ def createEntityRecognitionRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val serviceOverlapPolicy: Option[EntityOverlapPolicy] = getValueOpt(row, overlapPolicy) match {
+ case Some(policy) => Some(new EntityOverlapPolicy(policy))
+ case None => None
+ }
+
+ val inferenceOptions: Option[EntityInferenceOptions] = getValueOpt(row, excludeNormalizedValues) match {
+ case Some(value) => Some(new EntityInferenceOptions(value))
+ case None => None
+ }
+ val taskParameter = EntityRecognitionLROTask(
+ parameters = EntityRecognitionTaskParameters(
+ exclusionList = getValueOpt(row, exclusionList),
+ inclusionList = getValueOpt(row, inclusionList),
+ loggingOptOut = loggingOptOut,
+ modelVersion = modelVersion,
+ overlapPolicy = serviceOverlapPolicy,
+ stringIndexType = stringIndexType,
+ inferenceOptions = inferenceOptions
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.EntityRecognition.toString
+ )
+ EntityRecognitionJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+trait HasCustomLanguageModelParam extends HasServiceParams {
+ val projectName = new ServiceParam[String](
+ this,
+ name = "projectName",
+ doc = "This field indicates the project name for the model. This is a required field"
+ )
+
+ def getProjectName: String = getScalarParam(projectName)
+
+ def setProjectName(value: String): this.type = setScalarParam(projectName, value)
+
+ def getProjectNameCol: String = getVectorParam(projectName)
+
+ def setProjectNameCol(value: String): this.type = setVectorParam(projectName, value)
+
+ val deploymentName = new ServiceParam[String](
+ this,
+ name = "deploymentName",
+ doc = "This field indicates the deployment name for the model. This is a required field."
+ )
+
+ def getDeploymentName: String = getScalarParam(deploymentName)
+
+ def setDeploymentName(value: String): this.type = setScalarParam(deploymentName, value)
+
+ def getDeploymentNameCol: String = getVectorParam(deploymentName)
+
+ def setDeploymentNameCol(value: String): this.type = setVectorParam(deploymentName, value)
+}
+
+trait HandleCustomEntityRecognition extends HasServiceParams
+ with HasCustomLanguageModelParam {
+
+ def createCustomEntityRecognitionRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+                                           // This parameter is not used and only exists for compatibility
+ modelVersion: String,
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = CustomEntityRecognitionLROTask(
+ parameters = CustomEntitiesTaskParameters(
+ loggingOptOut = loggingOptOut,
+ projectName = getValue(row, projectName),
+ deploymentName = getValue(row, deploymentName),
+ stringIndexType = stringIndexType
+ ),
+ taskName = None,
+ kind = AnalysisTaskKind.CustomEntityRecognition.toString
+ )
+ CustomEntitiesJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
+
+trait HandleCustomLabelClassification extends HasServiceParams
+ with HasCustomLanguageModelParam {
+ def getKind: String
+
+ def createCustomMultiLabelRequest(row: Row,
+ analysisInput: MultiLanguageAnalysisInput,
+                                    // This parameter is not used and only exists for compatibility
+ modelVersion: String,
+                                    // This parameter is not used and only exists for compatibility
+ stringIndexType: String,
+ loggingOptOut: Boolean): String = {
+ val taskParameter = CustomLabelLROTask(
+ parameters = CustomLabelTaskParameters(
+ loggingOptOut = loggingOptOut,
+ projectName = getValue(row, projectName),
+ deploymentName = getValue(row, deploymentName)
+ ),
+ taskName = None,
+ kind = getKind
+ )
+ CustomLabelJobsInput(displayName = None,
+ analysisInput = analysisInput,
+ tasks = Seq(taskParameter)).toJson.compactPrint
+ }
+}
\ No newline at end of file
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
new file mode 100644
index 0000000000..3a3942822a
--- /dev/null
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
@@ -0,0 +1,230 @@
+package com.microsoft.azure.synapse.ml.services.language
+
+import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging }
+import com.microsoft.azure.synapse.ml.services.{
+ CognitiveServicesBaseNoHandler, HasAPIVersion, HasCognitiveServiceInput, HasInternalJsonOutputParser, HasSetLocation }
+import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch }
+import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply
+import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer }
+import org.apache.http.entity.{ AbstractHttpEntity, StringEntity }
+import org.apache.spark.injections.UDFUtils
+import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel }
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.{ ArrayType, DataType, StructType }
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.UserDefinedFunction
+
+import java.net.URI
+
+object AnalyzeTextLongRunningOperations extends ComplexParamsReadable[AnalyzeTextLongRunningOperations]
+ with Serializable
+
+/**
+ *
+ * This transformer is used to analyze text using the Azure AI Language service. It calls the service asynchronously.
+ * For more details please visit
+ * [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/concepts/use-asynchronously]]
+ * For each row, it submits a job to the service and polls until the job completes. The delay between
+ * polling requests is controlled by the [[pollingDelay]] parameter, which is set to 1000 milliseconds by default.
+ * The number of polling attempts is controlled by the [[maxPollingRetries]] parameter, which is set to 1000 by default.
+ *
+ *
+ * This transformer will use the field specified as TextCol to submit the text to the service. The response from the
+ * service will be stored in the field specified as OutputCol. The response will be a struct with the
+ * following fields:
+ *
+ * - statistics: A struct containing statistics about the job.
+ * - documents: An array of structs containing the results for each document.
+ * - errors: An array of structs containing the errors for each document.
+ * - modelVersion: The version of the model used to analyze the text.
+ *
+ *
+ *
+ * This transformer supports only a single task per row. The task to be performed is specified by the [[kind]] parameter.
+ * The supported tasks are:
+ *
+ * - ExtractiveSummarization
+ * - AbstractiveSummarization
+ * - Healthcare
+ * - SentimentAnalysis
+ * - KeyPhraseExtraction
+ * - PiiEntityRecognition
+ * - EntityLinking
+ * - EntityRecognition
+ * - CustomEntityRecognition
+ *
+ * Each task has its own set of parameters that can be set to control the behavior of the service and response
+ * schema.
+ *
+ */
+class AnalyzeTextLongRunningOperations(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
+ with HasAPIVersion
+ with BasicAsyncReply
+ with HasCognitiveServiceInput
+ with HasInternalJsonOutputParser
+ with HasSetLocation
+ with TextAnalyticsAutoBatch
+ with SynapseMLLogging
+ with HasAnalyzeTextServiceBaseParams
+ with HasBatchSize
+ with HandleExtractiveSummarization
+ with HandleAbstractiveSummarization
+ with HandleHealthcareTextAnalystics
+ with HandleSentimentAnalysis
+ with HandleKeyPhraseExtraction
+ with HandlePiiEntityRecognition
+ with HandleEntityLinking
+ with HandleEntityRecognition
+ with HandleCustomEntityRecognition {
+ logClass(FeatureNames.AiServices.Language)
+ def this() = this(Identifiable.randomUID("AnalyzeTextLongRunningOperations"))
+
+ override private[ml] def internalServiceType: String = "textanalytics"
+
+ override def urlPath: String = "/language/analyze-text/jobs"
+
+ override protected def validKinds: Set[String] = responseDataTypeSchemaMap.keySet.map(_.toString)
+
+ setDefault(
+ apiVersion -> Left("2023-04-01"),
+ showStats -> Left(false),
+ batchSize -> 10,
+ pollingDelay -> 1000
+ )
+
+ def setKind(value: AnalysisTaskKind.AnalysisTaskKind): this.type = set(kind, value.toString)
+
+ override protected def shouldSkip(row: Row): Boolean = emptyParamData(row, text) || super.shouldSkip(row)
+
+ /**
+ * Modifies the polling URI to include the showStats parameter if enabled.
+ */
+ override protected def modifyPollingURI(originalURI: URI): URI = {
+ if (getShowStats) {
+ new URI(s"${ originalURI.toString }&showStats=true")
+ } else {
+ originalURI
+ }
+ }
+
+ private val responseDataTypeSchemaMap: Map[AnalysisTaskKind.AnalysisTaskKind, StructType] = Map(
+ AnalysisTaskKind.ExtractiveSummarization -> ExtractiveSummarizationJobState.schema,
+ AnalysisTaskKind.AbstractiveSummarization -> AbstractiveSummarizationJobState.schema,
+ AnalysisTaskKind.Healthcare -> HealthcareJobState.schema,
+ AnalysisTaskKind.SentimentAnalysis -> SentimentAnalysisJobState.schema,
+ AnalysisTaskKind.KeyPhraseExtraction -> KeyPhraseExtractionJobState.schema,
+ AnalysisTaskKind.PiiEntityRecognition -> PiiEntityRecognitionJobState.schema,
+ AnalysisTaskKind.EntityLinking -> EntityLinkingJobState.schema,
+ AnalysisTaskKind.EntityRecognition -> EntityRecognitionJobState.schema,
+ AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema,
+ )
+
+ override protected def responseDataType: DataType = {
+ val taskKind = AnalysisTaskKind.getKindFromString(getKind)
+ responseDataTypeSchemaMap(taskKind)
+ }
+
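+  // Dispatch table: one request builder per task kind. Its key set must stay in
+  // sync with responseDataTypeSchemaMap so every supported kind can be both
+  // requested and parsed.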
+ private val requestCreatorMap: Map[AnalysisTaskKind.AnalysisTaskKind,
+ (Row, MultiLanguageAnalysisInput, String, String, Boolean) => String] = Map(
+ AnalysisTaskKind.ExtractiveSummarization -> createExtractiveSummarizationRequest,
+ AnalysisTaskKind.AbstractiveSummarization -> createAbstractiveSummarizationRequest,
+ AnalysisTaskKind.Healthcare -> createHealthcareTextAnalyticsRequest,
+ AnalysisTaskKind.SentimentAnalysis -> createSentimentAnalysisRequest,
+ AnalysisTaskKind.KeyPhraseExtraction -> createKeyPhraseExtractionRequest,
+ AnalysisTaskKind.PiiEntityRecognition -> createPiiEntityRecognitionRequest,
+ AnalysisTaskKind.EntityLinking -> createEntityLinkingRequest,
+ AnalysisTaskKind.EntityRecognition -> createEntityRecognitionRequest,
+ AnalysisTaskKind.CustomEntityRecognition -> createCustomEntityRecognitionRequest
+ )
+
+ // This method is made package private for testing purposes
+ override protected[language] def prepareEntity: Row => Option[AbstractHttpEntity] = row => {
+ val analysisInput = createMultiLanguageAnalysisInput(row)
+ val taskKind = AnalysisTaskKind.getKindFromString(getKind)
+ val requestString = requestCreatorMap(taskKind)(row,
+ analysisInput,
+ getValue(row, modelVersion),
+ getValue(row, stringIndexType),
+ getValue(row, loggingOptOut))
+ Some(new StringEntity(requestString, "UTF-8"))
+ }
+
+ protected def postprocessResponse(responseOpt: Row): Option[Seq[Row]] = {
+ Option(responseOpt).map { response =>
+ val tasks = response.getAs[Row]("tasks")
+ val items = tasks.getAs[Seq[Row]]("items")
+ items.flatMap(item => {
+ val results = item.getAs[Row]("results")
+ val stats = results.getAs[Row]("statistics")
+ val docs = results.getAs[Seq[Row]]("documents").map(
+ doc => (doc.getAs[String]("id"), doc)).toMap
+ val errors = results.getAs[Seq[Row]]("errors").map(
+ error => (error.getAs[String]("id"), error)).toMap
+ val modelVersion = results.getAs[String]("modelVersion")
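+      // Documents and errors are keyed by document id ("0", "1", ...); together
+      // they cover every input index, so emit one row per index with whichever
+      // side is present for that id.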
+ (0 until (docs.size + errors.size)).map { i =>
+ Row.fromSeq(Seq(
+ stats,
+ docs.get(i.toString),
+ errors.get(i.toString),
+ modelVersion
+ ))
+ }
+ })
+ }
+ }
+
+ protected def postprocessResponseUdf: UserDefinedFunction = {
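+    // Drill into the tasks.items[].results schema so the flattened output reuses
+    // the service's own field types for the selected task kind.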
+ val responseType = responseDataType.asInstanceOf[StructType]
+ val results = responseType("tasks").dataType.asInstanceOf[StructType]("items")
+ .dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]("results")
+ .dataType.asInstanceOf[StructType]
+ val outputType = ArrayType(
+ new StructType()
+ .add("statistics", results("statistics").dataType)
+ .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType)
+ .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType)
+ .add("modelVersion", results("modelVersion").dataType)
+ )
+ UDFUtils.oldUdf(postprocessResponse _, outputType)
+ }
+
+ override protected def getInternalTransformer(schema: StructType): PipelineModel = {
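+    // When auto-batching applies, rows are grouped into mini-batches before the
+    // HTTP call and flattened back into individual rows afterwards; the UDF in
+    // between reshapes the raw job state into the documented output struct.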
+ val batcher = if (shouldAutoBatch(schema)) {
+ Some(new FixedMiniBatchTransformer().setBatchSize(getBatchSize))
+ } else {
+ None
+ }
+ val newSchema = batcher.map(_.transformSchema(schema)).getOrElse(schema)
+
+ val pipe = super.getInternalTransformer(newSchema)
+
+ val postprocess = new UDFTransformer()
+ .setInputCol(getOutputCol)
+ .setOutputCol(getOutputCol)
+ .setUDF(postprocessResponseUdf)
+
+ val flatten = if (shouldAutoBatch(schema)) {
+ Some(new FlattenBatch())
+ } else {
+ None
+ }
+
+ NamespaceInjections.pipelineModel(
+ Array(batcher, Some(pipe), Some(postprocess), flatten).flatten
+ )
+ }
+
+ private def createMultiLanguageAnalysisInput(row: Row): MultiLanguageAnalysisInput = {
+ val validText = getValue(row, text)
+ val langs = getValueOpt(row, language).getOrElse(Seq.fill(validText.length)(""))
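+    // A single language value is broadcast across every text in the batch;
+    // otherwise languages pair positionally with the texts, and nulls become "".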
+ val validLanguages = (if (langs.length == 1) {
+ Seq.fill(validText.length)(langs.head)
+ } else {
+ langs
+ }).map(lang => Option(lang).getOrElse(""))
+ assert(validLanguages.length == validText.length)
+ MultiLanguageAnalysisInput(validText.zipWithIndex.map { case (t, i) =>
+ TADocument(Some(validLanguages(i)), i.toString, Option(t).getOrElse(""))
+ })
+ }
+}
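A minimal usage sketch for reviewers, assuming a live Spark session and a real Language resource: languageKey and the region below are placeholders, and the polling setters inherited from BasicAsyncReply are shown at their default values.

    import spark.implicits._

    val df = Seq("Text Analytics is one of the Azure Cognitive Services.").toDF("text")

    val analyzer = new AnalyzeTextLongRunningOperations()
      .setSubscriptionKey(languageKey)       // placeholder credential
      .setLocation("eastus")                 // placeholder region
      .setTextCol("text")
      .setLanguage("en")
      .setKind(AnalysisTaskKind.KeyPhraseExtraction)
      .setPollingDelay(1000)                 // milliseconds between status polls
      .setMaxPollingRetries(1000)            // upper bound on polls per job
      .setOutputCol("response")

    analyzer.transform(df)
      .select("response.documents.keyPhrases")
      .show(truncate = false)

With a plain string text column the transformer auto-batches rows, submits one job per batch, and flattens the results back to one row per input.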
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
new file mode 100644
index 0000000000..e23d62fe73
--- /dev/null
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
@@ -0,0 +1,630 @@
+package com.microsoft.azure.synapse.ml.services.language
+
+import com.microsoft.azure.synapse.ml.core.test.fuzzing.{ TestObject, TransformerFuzzing }
+import com.microsoft.azure.synapse.ml.services.text.{ SentimentAssessment, TextEndpoint }
+import org.apache.commons.io.IOUtils
+import org.apache.http.entity.AbstractHttpEntity
+import org.apache.spark.ml.util.MLReadable
+import org.apache.spark.sql.{ DataFrame, Row }
+import org.apache.spark.sql.functions.{ col, flatten, map }
+import org.scalactic.{ Equality, TolerantNumerics }
+
+class ExtractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+
+ private val df = Seq(
+ Seq(
+ """At Microsoft, we have been on a quest to advance AI beyond existing techniques, by taking a more holistic,
+ |human-centric approach to learning and understanding. As Chief Technology Officer of Azure AI services,
+ |I have been working with a team of amazing scientists and engineers to turn this quest into a reality.
+ |In my role, I enjoy a unique perspective in viewing the relationship among three attributes of human
+ |cognition: monolingual text (X), audio or visual sensory signals, (Y) and multilingual (Z). At the
+ |intersection of all three, there’s magic—what we call XYZ-code as illustrated in Figure 1—a joint
+ |representation to create more powerful AI that can speak, hear, see, and understand humans better. We
+ |believe XYZ-code enables us to fulfill our long-term vision: cross-domain transfer learning, spanning
+ |modalities and languages. The goal is to have pretrained models that can jointly learn representations to
+ |support a broad range of downstream AI tasks, much in the way humans do today. Over the past five years, we
+ |have achieved human performance on benchmarks in conversational speech recognition, machine translation,
+ |conversational question answering, machine reading comprehension, and image captioning. These five
+ |breakthroughs provided us with strong signals toward our more ambitious aspiration to produce a leap in AI
+ |capabilities, achieving multi-sensory and multilingual learning that is closer in line with how humans learn
+ | and understand. I believe the joint XYZ-code is a foundational component of this aspiration, if grounded
+ | with external knowledge sources in the downstream AI tasks""".stripMargin,
+ "",
+ """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nam ultricies interdum eros, vehicula dignissim
+ |odio dignissim id. Nam sagittis lacinia enim at fringilla. Nunc imperdiet porta ex. Vestibulum quis nisl
+ |feugiat, dapibus nulla nec, dictum lorem. Vivamus ut urna a ante cursus egestas. In vulputate facilisis
+ |nunc, vitae aliquam neque faucibus a. Fusce et venenatis nisi. Duis eleifend condimentum neque. Mauris eu
+ |pulvinar sapien. Nam at nibh sem. Integer sapien ex, viverra vel interdum non, volutpat sed tellus. Aenean
+ | nec maximus nibh. Maecenas sagittis turpis vel nibh condimentum vulputate. Pellentesque viverra
+ | ullamcorper urna vitae rutrum. Nunc fermentum sem vitae commodo efficitur.""".stripMargin
+ )
+ ).toDF("text")
+
+
+ test("Basic usage") {
+ val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind(AnalysisTaskKind.ExtractiveSummarization)
+ .setOutputCol("response")
+ .setErrorCol("error")
+ val responses = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("modelVersion", col("response.modelVersion"))
+ .withColumn("errors", col("response.errors"))
+ .withColumn("statistics", col("response.statistics"))
+ .collect()
+ assert(responses.length == 1)
+ val response = responses.head
+ val documents = response.getAs[Seq[Row]]("documents")
+ val errors = response.getAs[Seq[Row]]("errors")
+ assert(documents.length == errors.length)
+ assert(documents.length == 3)
+ val sentences = documents.head.getAs[Seq[Row]]("sentences")
+ assert(sentences.nonEmpty)
+ sentences.foreach { sentence =>
+ assert(sentence.getAs[String]("text").nonEmpty)
+ assert(sentence.getAs[Double]("rankScore") > 0.0)
+ assert(sentence.getAs[Int]("offset") >= 0)
+ assert(sentence.getAs[Int]("length") > 0)
+ }
+ }
+
+
+ test("show-stats and sentence-count") {
+ val sentenceCount = 10
+ val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind(AnalysisTaskKind.ExtractiveSummarization)
+ .setOutputCol("response")
+ .setErrorCol("error")
+ .setShowStats(true)
+ .setSentenceCount(sentenceCount)
+ val responses = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("modelVersion", col("response.modelVersion"))
+ .withColumn("errors", col("response.errors"))
+ .withColumn("statistics", col("response.statistics"))
+ .collect()
+ assert(responses.length == 1)
+ val response = responses.head
+ val stats = response.getAs[Seq[Row]]("statistics")
+ assert(stats.length == 3)
+ stats.foreach { stat =>
+ assert(stat.getAs[Int]("documentsCount") == 3)
+ assert(stat.getAs[Int]("validDocumentsCount") == 2)
+ assert(stat.getAs[Int]("erroneousDocumentsCount") == 1)
+ assert(stat.getAs[Int]("transactionsCount") == 3)
+ }
+
+ val documents = response.getAs[Seq[Row]]("documents")
+ for (document <- documents) {
+ if (document != null) {
+ val sentences = document.getAs[Seq[Row]]("sentences")
+ assert(sentences.length == sentenceCount)
+ sentences.foreach { sentence =>
+ assert(sentence.getAs[String]("text").nonEmpty)
+ assert(sentence.getAs[Double]("rankScore") > 0.0)
+ assert(sentence.getAs[Int]("offset") >= 0)
+ assert(sentence.getAs[Int]("length") > 0)
+ }
+ }
+ }
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind("ExtractiveSummarization")
+ .setOutputCol("response"),
+ df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+
+class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+
+ private val df = Seq(
+ Seq(
+ """At Microsoft, we have been on a quest to advance AI beyond existing techniques, by taking a more holistic,
+ |human-centric approach to learning and understanding. As Chief Technology Officer of Azure AI services,
+ |I have been working with a team of amazing scientists and engineers to turn this quest into a reality.
+ |In my role, I enjoy a unique perspective in viewing the relationship among three attributes of human
+ |cognition: monolingual text (X), audio or visual sensory signals, (Y) and multilingual (Z). At the
+ |intersection of all three, there’s magic—what we call XYZ-code as illustrated in Figure 1—a joint
+ |representation to create more powerful AI that can speak, hear, see, and understand humans better. We
+ |believe XYZ-code enables us to fulfill our long-term vision: cross-domain transfer learning, spanning
+ |modalities and languages. The goal is to have pretrained models that can jointly learn representations to
+ |support a broad range of downstream AI tasks, much in the way humans do today. Over the past five years, we
+ |have achieved human performance on benchmarks in conversational speech recognition, machine translation,
+ |conversational question answering, machine reading comprehension, and image captioning. These five
+ |breakthroughs provided us with strong signals toward our more ambitious aspiration to produce a leap in AI
+ |capabilities, achieving multi-sensory and multilingual learning that is closer in line with how humans learn
+ | and understand. I believe the joint XYZ-code is a foundational component of this aspiration, if grounded
+ | with external knowledge sources in the downstream AI tasks""".stripMargin
+ )
+ ).toDF("text")
+
+
+ test("Basic usage") {
+ val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind("AbstractiveSummarization")
+ .setOutputCol("response")
+ .setErrorCol("error")
+ val responses = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("modelVersion", col("response.modelVersion"))
+ .withColumn("errors", col("response.errors"))
+ .withColumn("statistics", col("response.statistics"))
+ .collect()
+ assert(responses.length == 1)
+ val response = responses.head
+ val documents = response.getAs[Seq[Row]]("documents")
+ val errors = response.getAs[Seq[Row]]("errors")
+ assert(documents.length == errors.length)
+ assert(documents.length == 1)
+ val summaries = documents.head.getAs[Seq[Row]]("summaries")
+ assert(summaries.nonEmpty)
+ }
+
+
+ test("show-stats and summary-length") {
+ val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind(AnalysisTaskKind.AbstractiveSummarization)
+ .setOutputCol("response")
+ .setErrorCol("error")
+ .setShowStats(true)
+ .setSummaryLength(SummaryLength.Short)
+ val responses = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("modelVersion", col("response.modelVersion"))
+ .withColumn("errors", col("response.errors"))
+ .withColumn("statistics", col("response.statistics"))
+ .collect()
+ assert(responses.length == 1)
+ val response = responses.head
+ val stat = response.getAs[Seq[Row]]("statistics").head
+ assert(stat.getAs[Int]("documentsCount") == 1)
+ assert(stat.getAs[Int]("validDocumentsCount") == 1)
+ assert(stat.getAs[Int]("erroneousDocumentsCount") == 0)
+
+
+ val document = response.getAs[Seq[Row]]("documents").head
+ val summaries = document.getAs[Seq[Row]]("summaries")
+ assert(summaries.length == 1)
+ val summary = summaries.head.getAs[String]("text")
+ assert(summary.length <= 750)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind("AbstractiveSummarization")
+ .setOutputCol("response"),
+ df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+class HealthcareSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ private val df = Seq(
+ "The doctor prescried 200mg Ibuprofen."
+ ).toDF("text")
+
+ test("Basic usage") {
+ val model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind("Healthcare")
+ .setOutputCol("response")
+ .setShowStats(true)
+ .setErrorCol("error")
+ val responses = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("modelVersion", col("response.modelVersion"))
+ .withColumn("errors", col("response.errors"))
+ .withColumn("statistics", col("response.statistics"))
+ .collect()
+ assert(responses.length == 1)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setLanguage("en")
+ .setKind("Healthcare")
+ .setOutputCol("response"),
+ df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+class SentimentAnalysisLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ "Great atmosphere. Close to plenty of restaurants, hotels, and transit! Staff are friendly and helpful.",
+ "What a sad story!"
+ ).toDF("text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setKind(AnalysisTaskKind.SentimentAnalysis)
+ .setOutputCol("response")
+ .setErrorCol("error")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("sentiment", col("documents.sentiment"))
+ .collect()
+ assert(result.head.getAs[String]("sentiment") == "positive")
+ assert(result(1).getAs[String]("sentiment") == "negative")
+ }
+
+ test("api-version 2022-10-01-preview") {
+ val result = model.setApiVersion("2022-10-01-preview").transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("sentiment", col("documents.sentiment"))
+ .collect()
+ assert(result.head.getAs[String]("sentiment") == "positive")
+ assert(result(1).getAs[String]("sentiment") == "negative")
+ }
+
+ test("Show stats") {
+ val result = model.setShowStats(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("sentiment", col("documents.sentiment"))
+ .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount"))
+ .collect()
+ assert(result.head.getAs[String]("sentiment") == "positive")
+ assert(result(1).getAs[String]("sentiment") == "negative")
+ assert(result.head.getAs[Int]("validDocumentsCount") == 1)
+ }
+
+ test("Opinion Mining") {
+ val result = model.setOpinionMining(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("sentiment", col("documents.sentiment"))
+ .withColumn("assessments", flatten(col("documents.sentences.assessments")))
+ .collect()
+ assert(result.head.getAs[String]("sentiment") == "positive")
+ assert(result(1).getAs[String]("sentiment") == "negative")
+ val fromRow = SentimentAssessment.makeFromRowConverter
+ assert(result.head.getAs[Seq[Row]]("assessments").map(fromRow).head.sentiment == "positive")
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+  override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+
+class KeyPhraseLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ ("en", "Microsoft was founded by Bill Gates and Paul Allen."),
+ ("en", "Text Analytics is one of the Azure Cognitive Services."),
+ ("en", "My cat might need to see a veterinarian.")
+ ).toDF("language", "text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setLanguageCol("language")
+ .setTextCol("text")
+ .setKind("KeyPhraseExtraction")
+ .setOutputCol("response")
+ .setErrorCol("error")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("keyPhrases", col("documents.keyPhrases"))
+ val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases")
+ assert(keyPhrases.contains("Azure Cognitive Services"))
+ assert(keyPhrases.contains("Text Analytics"))
+ }
+
+ test("api-version 2022-10-01-preview") {
+ val result = model.setApiVersion("2022-10-01-preview").transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("keyPhrases", col("documents.keyPhrases"))
+ val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases")
+ assert(keyPhrases.contains("Azure Cognitive Services"))
+ assert(keyPhrases.contains("Text Analytics"))
+ }
+
+ test("Show stats") {
+ val result = model.setShowStats(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("keyPhrases", col("documents.keyPhrases"))
+ .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount"))
+ val keyPhrases = result.collect()(1).getAs[Seq[String]]("keyPhrases")
+ assert(keyPhrases.contains("Azure Cognitive Services"))
+ assert(keyPhrases.contains("Text Analytics"))
+ assert(result.head.getAs[Int]("validDocumentsCount") == 1)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+
+class AnalyzeTextPIILROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ "My SSN is 859-98-0987",
+ "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check.",
+ "Is 998.214.865-68 your Brazilian CPF number?"
+ ).toDF("text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setTextCol("text")
+ .setKind("PiiEntityRecognition")
+ .setOutputCol("response")
+ .setErrorCol("error")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("redactedText", col("documents.redactedText"))
+ .withColumn("entities", col("documents.entities.text"))
+ .collect()
+ val entities = result.head.getAs[Seq[String]]("entities")
+ assert(entities.contains("859-98-0987"))
+ val redactedText = result(1).getAs[String]("redactedText")
+ assert(!redactedText.contains("111000025"))
+ }
+
+ test("api-version 2022-10-01-preview") {
+ val result = model.setApiVersion("2022-10-01-preview").transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("redactedText", col("documents.redactedText"))
+ .withColumn("entities", col("documents.entities.text"))
+ .collect()
+ val entities = result.head.getAs[Seq[String]]("entities")
+ assert(entities.contains("859-98-0987"))
+ val redactedText = result(1).getAs[String]("redactedText")
+ assert(!redactedText.contains("111000025"))
+ }
+
+ test("Show stats") {
+ val result = model.setShowStats(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("redactedText", col("documents.redactedText"))
+ .withColumn("entities", col("documents.entities.text"))
+ .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount"))
+ .collect()
+ val entities = result.head.getAs[Seq[String]]("entities")
+ assert(entities.contains("859-98-0987"))
+ val redactedText = result(1).getAs[String]("redactedText")
+ assert(!redactedText.contains("111000025"))
+ assert(result.head.getAs[Int]("validDocumentsCount") == 1)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+
+class EntityLinkingLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ ("en", "Microsoft was founded by Bill Gates and Paul Allen."),
+ ("en", "Pike place market is my favorite Seattle attraction.")
+ ).toDF("language", "text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setLanguageCol("language")
+ .setTextCol("text")
+ .setKind("EntityLinking")
+ .setOutputCol("response")
+ .setErrorCol("error")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.name")))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ }
+
+ test("api-version 2022-10-01-preview") {
+ val result = model.setApiVersion("2022-10-01-preview").transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.name")))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ }
+
+ test("Show stats") {
+ val result = model.setShowStats(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.name")))
+ .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount"))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ assert(result.head.getAs[Int]("validDocumentsCount") == 1)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+ override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+
+class EntityRecognitionLROSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ ("en", "Microsoft was founded by Bill Gates and Paul Allen."),
+ ("en", "Pike place market is my favorite Seattle attraction.")
+ ).toDF("language", "text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setLanguageCol("language")
+ .setTextCol("text")
+ .setKind("EntityRecognition")
+ .setOutputCol("response")
+ .setErrorCol("error")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text")))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ }
+
+ test("api-version 2022-10-01-preview") {
+ val result = model.setApiVersion("2022-10-01-preview").transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text")))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ }
+
+ test("Show stats") {
+ val result = model.setShowStats(true).transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entityNames", map(col("documents.id"), col("documents.entities.text")))
+ .withColumn("validDocumentsCount", col("response.statistics.validDocumentsCount"))
+ val entities = result.head.getAs[Map[String, Seq[String]]]("entityNames")("0")
+ assert(entities.contains("Microsoft"))
+ assert(entities.contains("Bill Gates"))
+ assert(result.head.getAs[Int]("validDocumentsCount") == 1)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+  override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
+
+class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = Seq(
+ Seq("Microsoft was founded by Bill Gates and Paul Allen."),
+ ).toDF("text")
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(textKey)
+ .setLocation(textApiLocation)
+ .setLanguage("en")
+ .setTextCol("text")
+ .setKind(AnalysisTaskKind.CustomEntityRecognition)
+ .setOutputCol("response")
+ .setErrorCol("error")
+ .setDeploymentName("test-deployment")
+ .setProjectName("test-project")
+
+ test("Basic request parsing") {
+ import spray.json._
+ import ATLROJSONFormat.CustomEntitiesJobsInputF
+
+ val stringEntityOption: Option[AbstractHttpEntity] = model.prepareEntity(df.head)
+ assert(stringEntityOption.isDefined)
+ val stringEntity = stringEntityOption.get
+ val stringJson = IOUtils.toString(stringEntity.getContent, "UTF-8")
+ val inputObj = stringJson.parseJson.convertTo[CustomEntitiesJobsInput]
+ assert(inputObj != null)
+ assert(inputObj.tasks.length == 1)
+ val task = inputObj.tasks.head
+ assert(task.kind == "CustomEntityRecognition")
+ assert(task.parameters.projectName == "test-project")
+ assert(task.parameters.deploymentName == "test-deployment")
+ assert(task.parameters.stringIndexType == "TextElements_v8")
+ val doc = inputObj.analysisInput.documents.head
+ assert(doc.language.getOrElse("") == "en")
+ assert(doc.id.nonEmpty)
+ assert(doc.text == "Microsoft was founded by Bill Gates and Paul Allen.")
+ }
+
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+  override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
+}
From 6287e75c2f42e683a5c042d7215d76bf21cb004d Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Thu, 30 Jan 2025 11:04:26 -0800
Subject: [PATCH 2/8] Adding unit test and fixing failing style test
---
.../language/AnalyzeTextJobSchema.scala | 4 ++
...raits.scala => AnalyzeTextLROTraits.scala} | 17 +++---
.../AnalyzeTextLongRunningOperations.scala | 6 +-
.../language/AnalyzeTextLROSuite.scala | 61 +++++++++----------
.../microsoft/azure/synapse/ml/Secrets.scala | 1 +
5 files changed, 49 insertions(+), 40 deletions(-)
rename cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/{AnalyzeTextJobServiceTraits.scala => AnalyzeTextLROTraits.scala} (98%)
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
index a84d2f3dcc..d29bc2efba 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
@@ -1,8 +1,12 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import spray.json.RootJsonFormat
+// scalastyle:off number.of.types
case class DocumentWarning(code: String,
message: String,
targetRef: Option[String])
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
similarity index 98%
rename from cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala
rename to cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
index ce06e964ce..967a2c6fd2 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobServiceTraits.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
@@ -1,3 +1,6 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.param.ServiceParam
@@ -323,7 +326,7 @@ trait HandlePiiEntityRecognition extends HasServiceParams {
def getPiiCategoriesCol: String = getVectorParam(piiCategories)
setDefault(
- domain -> Left("none"),
+ domain -> Left("none")
)
def createPiiEntityRecognitionRequest(row: Row,
@@ -396,8 +399,7 @@ trait HandleEntityRecognition extends HasServiceParams {
isValid = {
case Left(value) => value == "matchLongest" || value == "allowOverlap"
case Right(_) => true
- }
- )
+ })
def getOverlapPolicy: String = getScalarParam(overlapPolicy)
@@ -448,7 +450,7 @@ trait HandleEntityRecognition extends HasServiceParams {
),
taskName = None,
kind = AnalysisTaskKind.EntityRecognition.toString
- )
+ )
EntityRecognitionJobsInput(displayName = None,
analysisInput = analysisInput,
tasks = Seq(taskParameter)).toJson.compactPrint
@@ -502,8 +504,7 @@ trait HandleCustomEntityRecognition extends HasServiceParams
stringIndexType = stringIndexType
),
taskName = None,
- kind = AnalysisTaskKind.CustomEntityRecognition.toString
- )
+ kind = AnalysisTaskKind.CustomEntityRecognition.toString)
CustomEntitiesJobsInput(displayName = None,
analysisInput = analysisInput,
tasks = Seq(taskParameter)).toJson.compactPrint
@@ -511,7 +512,7 @@ trait HandleCustomEntityRecognition extends HasServiceParams
}
trait HandleCustomLabelClassification extends HasServiceParams
- with HasCustomLanguageModelParam {
+ with HasCustomLanguageModelParam {
def getKind: String
def createCustomMultiLabelRequest(row: Row,
@@ -534,4 +535,4 @@ trait HandleCustomLabelClassification extends HasServiceParams
analysisInput = analysisInput,
tasks = Seq(taskParameter)).toJson.compactPrint
}
-}
\ No newline at end of file
+}
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
index 3a3942822a..2e53d0b79f 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
@@ -1,3 +1,6 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging }
@@ -77,6 +80,7 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
with HandleEntityRecognition
with HandleCustomEntityRecognition {
logClass(FeatureNames.AiServices.Language)
+
def this() = this(Identifiable.randomUID("AnalyzeTextLongRunningOperations"))
override private[ml] def internalServiceType: String = "textanalytics"
@@ -116,7 +120,7 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
AnalysisTaskKind.PiiEntityRecognition -> PiiEntityRecognitionJobState.schema,
AnalysisTaskKind.EntityLinking -> EntityLinkingJobState.schema,
AnalysisTaskKind.EntityRecognition -> EntityRecognitionJobState.schema,
- AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema,
+ AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema
)
override protected def responseDataType: DataType = {
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
index e23d62fe73..041e98e607 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
@@ -2,13 +2,17 @@ package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{ TestObject, TransformerFuzzing }
import com.microsoft.azure.synapse.ml.services.text.{ SentimentAssessment, TextEndpoint }
-import org.apache.commons.io.IOUtils
-import org.apache.http.entity.AbstractHttpEntity
+import com.microsoft.azure.synapse.ml.Secrets
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.functions.{ col, flatten, map }
import org.scalactic.{ Equality, TolerantNumerics }
+trait LanguageServiceEndpoint {
+ lazy val langServiceKey: String = sys.env.getOrElse("LANGUAGE_API_KEY", Secrets.LanguageApiKey)
+ lazy val langServiceLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus")
+}
+
class ExtractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
import spark.implicits._
@@ -579,47 +583,42 @@ class EntityRecognitionLROSuite extends TransformerFuzzing[AnalyzeTextLongRunnin
override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
}
-class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
+class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations]
+ with LanguageServiceEndpoint {
import spark.implicits._
implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
- def df: DataFrame = Seq(
- Seq("Microsoft was founded by Bill Gates and Paul Allen."),
- ).toDF("text")
+ def df: DataFrame =
+ Seq("Maria Sullivan with a mailing address of 334 Shinn Avenue, City of Wampum, State of Pennsylvania")
+ .toDF("text")
def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
- .setSubscriptionKey(textKey)
- .setLocation(textApiLocation)
+ .setSubscriptionKey(langServiceKey)
+ .setLocation(langServiceLocation)
.setLanguage("en")
.setTextCol("text")
.setKind(AnalysisTaskKind.CustomEntityRecognition)
.setOutputCol("response")
.setErrorCol("error")
- .setDeploymentName("test-deployment")
- .setProjectName("test-project")
-
- test("Basic request parsing") {
- import spray.json._
- import ATLROJSONFormat.CustomEntitiesJobsInputF
-
- val stringEntityOption: Option[AbstractHttpEntity] = model.prepareEntity(df.head)
- assert(stringEntityOption.isDefined)
- val stringEntity = stringEntityOption.get
- val stringJson = IOUtils.toString(stringEntity.getContent, "UTF-8")
- val inputObj = stringJson.parseJson.convertTo[CustomEntitiesJobsInput]
- assert(inputObj != null)
- assert(inputObj.tasks.length == 1)
- val task = inputObj.tasks.head
- assert(task.kind == "CustomEntityRecognition")
- assert(task.parameters.projectName == "test-project")
- assert(task.parameters.deploymentName == "test-deployment")
- assert(task.parameters.stringIndexType == "TextElements_v8")
- val doc = inputObj.analysisInput.documents.head
- assert(doc.language.getOrElse("") == "en")
- assert(doc.id.nonEmpty)
- assert(doc.text == "Microsoft was founded by Bill Gates and Paul Allen.")
+ .setDeploymentName("custom-ner-unitest-deployment")
+ .setProjectName("for-unit-test")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("entities", col("documents.entities"))
+ .collect()
+ val entities = result.head.getAs[Seq[Row]]("entities")
+ assert(entities.length == 4)
+ val resultMap: Map[String, String] = entities.map { entity =>
+ entity.getAs[String]("text") -> entity.getAs[String]("category")
+ }.toMap
+ assert(resultMap("Maria Sullivan") == "BorrowerName")
+ assert(resultMap("334 Shinn Avenue") == "BorrowerAddress")
+ assert(resultMap("Wampum") == "BorrowerCity")
+ assert(resultMap("Pennsylvania") == "BorrowerState")
}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
index 17eed8a668..ae45753e8b 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
@@ -74,4 +74,5 @@ object Secrets {
lazy val Platform: String = getSecret("synapse-platform")
lazy val AadResource: String = getSecret("synapse-internal-aad-resource")
+ lazy val LanguageApiKey: String = getSecret("language-api-key")
}
From 2227cd639564b92e0c5bc97fb51c94d54bff0189 Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Thu, 30 Jan 2025 12:10:08 -0800
Subject: [PATCH 3/8] Adding unit test and fixing style for the test.
---
.../AnalyzeTextLongRunningOperations.scala | 6 ++++--
.../language/AnalyzeTextLROSuite.scala | 21 +++++++++++++++----
.../microsoft/azure/synapse/ml/Secrets.scala | 2 +-
3 files changed, 22 insertions(+), 7 deletions(-)
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
index 2e53d0b79f..19102b3a8e 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
@@ -111,7 +111,8 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
}
}
- private val responseDataTypeSchemaMap: Map[AnalysisTaskKind.AnalysisTaskKind, StructType] = Map(
+  // This map is made package private for testing purposes
+ private[language] val responseDataTypeSchemaMap: Map[AnalysisTaskKind.AnalysisTaskKind, StructType] = Map(
AnalysisTaskKind.ExtractiveSummarization -> ExtractiveSummarizationJobState.schema,
AnalysisTaskKind.AbstractiveSummarization -> AbstractiveSummarizationJobState.schema,
AnalysisTaskKind.Healthcare -> HealthcareJobState.schema,
@@ -128,7 +129,8 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
responseDataTypeSchemaMap(taskKind)
}
- private val requestCreatorMap: Map[AnalysisTaskKind.AnalysisTaskKind,
+  // This map is made package private for testing purposes
+ private[language] val requestCreatorMap: Map[AnalysisTaskKind.AnalysisTaskKind,
(Row, MultiLanguageAnalysisInput, String, String, Boolean) => String] = Map(
AnalysisTaskKind.ExtractiveSummarization -> createExtractiveSummarizationRequest,
AnalysisTaskKind.AbstractiveSummarization -> createAbstractiveSummarizationRequest,
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
index 041e98e607..d40e759df9 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
@@ -1,3 +1,6 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{ TestObject, TransformerFuzzing }
@@ -7,10 +10,20 @@ import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.functions.{ col, flatten, map }
import org.scalactic.{ Equality, TolerantNumerics }
+import org.scalatest.funsuite.AnyFunSuiteLike
trait LanguageServiceEndpoint {
- lazy val langServiceKey: String = sys.env.getOrElse("LANGUAGE_API_KEY", Secrets.LanguageApiKey)
- lazy val langServiceLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus")
+ lazy val customNERKey: String = sys.env.getOrElse("CUSTOM_NER_KEY", Secrets.CustomNERLanguageApiKey)
+ lazy val customNERLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus")
+}
+
+class AnalyzeTextLROSuite extends AnyFunSuiteLike {
+ test("Validate that response map and creator handle same kinds") {
+ val transformer = new AnalyzeTextLongRunningOperations()
+ val responseKinds = transformer.responseDataTypeSchemaMap.keySet
+ val creatorKinds = transformer.requestCreatorMap.keySet
+ assert(responseKinds == creatorKinds)
+ }
}
class ExtractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations] with TextEndpoint {
@@ -595,8 +608,8 @@ class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRun
.toDF("text")
def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
- .setSubscriptionKey(langServiceKey)
- .setLocation(langServiceLocation)
+ .setSubscriptionKey(customNERKey)
+ .setLocation(customNERLocation)
.setLanguage("en")
.setTextCol("text")
.setKind(AnalysisTaskKind.CustomEntityRecognition)
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
index ae45753e8b..344058015b 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
@@ -74,5 +74,5 @@ object Secrets {
lazy val Platform: String = getSecret("synapse-platform")
lazy val AadResource: String = getSecret("synapse-internal-aad-resource")
- lazy val LanguageApiKey: String = getSecret("language-api-key")
+ lazy val CustomNERLanguageApiKey: String = getSecret("custom-ner-key")
}
From 214a5789c3e817b29dff8ee9b0dd7e318ee9927a Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Thu, 30 Jan 2025 18:15:00 -0800
Subject: [PATCH 4/8] Adding support for Custom MultiLabel Classification and
Single Label classification. Unit tests are added to validate that requests
and responses are correct. Also added timeout for AbstractiveSummary requests.
---
.../language/AnalyzeTextJobSchema.scala | 69 +-----------
.../language/AnalyzeTextLROTraits.scala | 100 ++++++++++++++++++
.../AnalyzeTextLongRunningOperations.scala | 16 ++-
.../language/AnalyzeTextLROSuite.scala | 65 +++++++++++-
.../microsoft/azure/synapse/ml/Secrets.scala | 2 +-
5 files changed, 175 insertions(+), 77 deletions(-)
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
index d29bc2efba..6c27ab2c6c 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextJobSchema.scala
@@ -41,38 +41,28 @@ case class ExtractedSummarySentence(text: String,
offset: Int,
length: Int)
-object ExtractedSummarySentence extends SparkBindings[ExtractedSummarySentence]
-
case class ExtractedSummaryDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
sentences: Seq[ExtractedSummarySentence])
-object ExtractedSummaryDocumentResult extends SparkBindings[ExtractedSummaryDocumentResult]
-
case class ExtractiveSummarizationResult(errors: Seq[ATError],
statistics: Option[RequestStatistics],
modelVersion: String,
documents: Seq[ExtractedSummaryDocumentResult])
-object ExtractiveSummarizationResult extends SparkBindings[ExtractiveSummarizationResult]
-
case class ExtractiveSummarizationLROResult(results: ExtractiveSummarizationResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)
-object ExtractiveSummarizationLROResult extends SparkBindings[ExtractiveSummarizationLROResult]
-
case class ExtractiveSummarizationTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[ExtractiveSummarizationLROResult]])
-object ExtractiveSummarizationTaskResult extends SparkBindings[ExtractiveSummarizationTaskResult]
-
case class ExtractiveSummarizationJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -112,8 +102,6 @@ case class AbstractiveSummarizationJobsInput(displayName: Option[String],
case class AbstractiveSummary(text: String,
contexts: Option[Seq[SummaryContext]])
-object AbstractiveSummary extends SparkBindings[AbstractiveSummary]
-
case class AbstractiveSummaryDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
@@ -126,25 +114,18 @@ case class AbstractiveSummarizationResult(errors: Seq[ATError],
modelVersion: String,
documents: Seq[AbstractiveSummaryDocumentResult])
-object AbstractiveSummarizationResult extends SparkBindings[AbstractiveSummarizationResult]
-
case class AbstractiveSummarizationLROResult(results: AbstractiveSummarizationResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)
-object AbstractiveSummarizationLROResult extends SparkBindings[AbstractiveSummarizationLROResult]
-
-
case class AbstractiveSummarizationTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[AbstractiveSummarizationLROResult]])
-object AbstractiveSummarizationTaskResult extends SparkBindings[AbstractiveSummarizationTaskResult]
-
case class AbstractiveSummarizationJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -179,8 +160,6 @@ case class HealthcareAssertion(conditionality: Option[String],
association: Option[String],
temporality: Option[String])
-object HealthcareAssertion extends SparkBindings[HealthcareAssertion]
-
case class HealthcareEntitiesDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
@@ -188,8 +167,6 @@ case class HealthcareEntitiesDocumentResult(id: String,
relations: Seq[HealthcareRelation],
fhirBundle: Option[String])
-object HealthcareEntitiesDocumentResult extends SparkBindings[HealthcareEntitiesDocumentResult]
-
case class HealthcareEntity(text: String,
category: String,
subcategory: Option[String],
@@ -200,48 +177,33 @@ case class HealthcareEntity(text: String,
name: Option[String],
links: Option[Seq[HealthcareEntityLink]])
-object HealthcareEntity extends SparkBindings[HealthcareEntity]
-
case class HealthcareEntityLink(dataSource: String,
id: String)
-object HealthcareEntityLink extends SparkBindings[HealthcareEntityLink]
-
case class HealthcareLROResult(results: HealthcareResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)
-object HealthcareLROResult extends SparkBindings[HealthcareLROResult]
-
-
case class HealthcareRelation(relationType: String,
entities: Seq[HealthcareRelationEntity],
confidenceScore: Option[Double])
-object HealthcareRelation extends SparkBindings[HealthcareRelation]
-
case class HealthcareRelationEntity(ref: String,
role: String)
-object HealthcareRelationEntity extends SparkBindings[HealthcareRelationEntity]
-
case class HealthcareResult(errors: Seq[DocumentError],
statistics: Option[RequestStatistics],
modelVersion: String,
documents: Seq[HealthcareEntitiesDocumentResult])
-object HealthcareResult extends SparkBindings[HealthcareResult]
-
case class HealthcareTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[HealthcareLROResult]])
-object HealthcareTaskResult extends SparkBindings[HealthcareTaskResult]
-
case class HealthcareJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -274,16 +236,12 @@ case class SentimentAnalysisLROResult(results: SentimentResult,
taskName: Option[String],
kind: String)
-object SentimentAnalysisLROResult extends SparkBindings[SentimentAnalysisLROResult]
-
case class SentimentAnalysisTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[SentimentAnalysisLROResult]])
-object SentimentAnalysisTaskResult extends SparkBindings[SentimentAnalysisTaskResult]
-
case class SentimentAnalysisJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -315,16 +273,12 @@ case class KeyPhraseExtractionLROResult(results: KeyPhraseExtractionResult,
taskName: Option[String],
kind: String)
-object KeyPhraseExtractionLROResult extends SparkBindings[KeyPhraseExtractionLROResult]
-
case class KeyPhraseExtractionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[KeyPhraseExtractionLROResult]])
-object KeyPhraseExtractionTaskResult extends SparkBindings[KeyPhraseExtractionTaskResult]
-
case class KeyPhraseExtractionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -355,23 +309,18 @@ case class PiiEntityRecognitionJobsInput(displayName: Option[String],
analysisInput: MultiLanguageAnalysisInput,
tasks: Seq[PiiEntityRecognitionLROTask])
-
case class PiiEntityRecognitionLROResult(results: PIIResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)
-object PiiEntityRecognitionLROResult extends SparkBindings[PiiEntityRecognitionLROResult]
-
case class PiiEntityRecognitionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[PiiEntityRecognitionLROResult]])
-object PiiEntityRecognitionTaskResult extends SparkBindings[PiiEntityRecognitionTaskResult]
-
case class PiiEntityRecognitionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -405,16 +354,12 @@ case class EntityLinkingLROResult(results: EntityLinkingResult,
taskName: Option[String],
kind: String)
-object EntityLinkingLROResult extends SparkBindings[EntityLinkingLROResult]
-
case class EntityLinkingTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[EntityLinkingLROResult]])
-object EntityLinkingTaskResult extends SparkBindings[EntityLinkingTaskResult]
-
case class EntityLinkingJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -459,16 +404,12 @@ case class EntityRecognitionLROResult(results: EntityRecognitionResult,
taskName: Option[String],
kind: String)
-object EntityRecognitionLROResult extends SparkBindings[EntityRecognitionLROResult]
-
case class EntityRecognitionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[EntityRecognitionLROResult]])
-object EntityRecognitionTaskResult extends SparkBindings[EntityRecognitionTaskResult]
-
case class EntityRecognitionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
@@ -518,9 +459,9 @@ case class CustomLabelJobsInput(displayName: Option[String],
case class ClassificationDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
- classes: Seq[ClassificationResult])
+ classifications: Seq[ClassificationResult])
-object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult]
+//object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult]
case class ClassificationResult(category: String,
confidenceScore: Double)
@@ -532,24 +473,18 @@ case class CustomLabelResult(errors: Seq[DocumentError],
modelVersion: String,
documents: Seq[ClassificationDocumentResult])
-object CustomLabelResult extends SparkBindings[CustomLabelResult]
-
case class CustomLabelLROResult(results: CustomLabelResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)
-object CustomLabelLROResult extends SparkBindings[CustomLabelLROResult]
-
case class CustomLabelTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[CustomLabelLROResult]])
-object CustomLabelTaskResult extends SparkBindings[CustomLabelTaskResult]
-
case class CustomLabelJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
index 967a2c6fd2..1c7e6f5ee9 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
@@ -3,16 +3,23 @@
package com.microsoft.azure.synapse.ml.services.language
+import com.microsoft.azure.synapse.ml.io.http.{ EntityData, HTTPResponseData }
+import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.HasServiceParams
import com.microsoft.azure.synapse.ml.services.language.ATLROJSONFormat._
import com.microsoft.azure.synapse.ml.services.language.PiiDomain.PiiDomain
import com.microsoft.azure.synapse.ml.services.language.SummaryLength.SummaryLength
+import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply
+import org.apache.commons.io.IOUtils
+import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.ml.param.ParamValidators
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
import spray.json.enrichAny
+import java.net.URI
+
object AnalysisTaskKind extends Enumeration {
type AnalysisTaskKind = Value
val SentimentAnalysis,
@@ -511,8 +518,101 @@ trait HandleCustomEntityRecognition extends HasServiceParams
}
}
+/**
+ * Trait `ModifiableAsyncReply` extends `BasicAsyncReply` and provides a mechanism to modify the HTTP response
+ * received from an asynchronous service call. This trait is designed to be mixed into classes that require
+ * custom handling of the response data.
+ *
+ * The primary purpose of this trait is to allow modification of the response before it is processed further.
+ * This is particularly useful in scenarios where the response needs to be transformed or certain fields need
+ * to be renamed to comply with specific requirements or constraints.
+ *
+ * In this implementation, the `queryForResult` method is overridden and marked as `final` to prevent further
+ * overriding. This ensures that the response modification logic is consistently applied across all subclasses.
+ *
+ * @note This trait is designed to be used with the `SynapseMLLogging` trait for consistent logging.
+ */
+trait ModifiableAsyncReply extends BasicAsyncReply {
+ self: SynapseMLLogging =>
+
+ protected def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = response
+
+ /**
+ * Queries for the result of an asynchronous service call and applies the response modification logic.
+ */
+ override final protected def queryForResult(key: Option[String],
+ client: CloseableHttpClient,
+ location: URI): Option[HTTPResponseData] = {
+ val originalResponse = super.queryForResult(key, client, location)
+ logDebug(s"Original response: ${ originalResponse }")
+ modifyResponse(originalResponse)
+ }
+}
+
+
+/**
+ * Trait `HandleCustomLabelClassification` extends `HasServiceParams` and `HasCustomLanguageModelParam` to handle
+ * custom label classification tasks. This trait provides the necessary methods to create requests for custom
+ * multi-label classification and to modify the response to comply with specific requirements.
+ *
+ * The primary purpose of this trait is to address a binding limitation: `class` is a reserved word in Scala,
+ * so Spark cannot bind a response field named "class" directly. To work around this, the response is
+ * modified to rename the "class" field to "classifications" before parsing.
+ *
+ * This trait is designed to be mixed into classes that require custom label classification functionality and
+ * response modification logic.
+ *
+ * @note This trait is designed to be used with the `ModifiableAsyncReply` and `SynapseMLLogging` traits for
+ * consistent response handling and logging.
+ */
trait HandleCustomLabelClassification extends HasServiceParams
with HasCustomLanguageModelParam {
+ self: ModifiableAsyncReply
+ with SynapseMLLogging =>
+
+ private def isCustomLabelClassification: Boolean = {
+ val kind = getKind
+ kind == AnalysisTaskKind.CustomSingleLabelClassification.toString ||
+ kind == AnalysisTaskKind.CustomMultiLabelClassification.toString
+ }
+
+ /**
+ * Modifies the entity in the HTTP response to rename the "class" field to "classifications".
+ *
+ * @param response The original HTTP response.
+ * @return The modified HTTP response with the "class" field renamed to "classifications".
+ */
+ private def modifyEntity(response: HTTPResponseData): HTTPResponseData = {
+    val modifiedEntity = response.entity.flatMap { entity =>
+      val strEntity = IOUtils.toString(entity.content, "UTF-8")
+      // Rename the reserved-word field before the payload is parsed.
+      val modifiedContent = strEntity.replace("\"class\":", "\"classifications\":")
+      logDebug(s"Original entity: $strEntity\t Modified entity: $modifiedContent")
+      val modifiedBytes = modifiedContent.getBytes("UTF-8")
+      Some(new EntityData(
+        content = modifiedBytes,
+        contentEncoding = entity.contentEncoding,
+        contentLength = Some(modifiedBytes.length),
+        contentType = entity.contentType,
+        isChunked = entity.isChunked,
+        isRepeatable = entity.isRepeatable,
+        isStreaming = entity.isStreaming
+      ))
+ }
+ new HTTPResponseData(response.headers, modifiedEntity, response.statusLine, response.locale)
+ }
+
+ /**
+ * Modifies the HTTP response if the task kind is custom label classification.
+ */
+ override def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = {
+ if (!isCustomLabelClassification) {
+ logDebug(s"Kind is not CustomSingleLabelClassification or CustomMultiLabelClassification. Kind: $getKind")
+ response
+ } else {
+ response.map(modifyEntity)
+ }
+ }
+
def getKind: String
def createCustomMultiLabelRequest(row: Row,
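
The rename is forced by the language, not the service: `class` is a reserved word in Scala, so reflection-driven binding cannot produce a case-class field of that name. A minimal illustration with a hypothetical Doc type:

    // The raw payload carries {"class": [{"category": ..., "confidenceScore": ...}]}.
    // case class Doc(class: Seq[ClassificationResult])   // does not compile: reserved word
    case class Doc(`class`: Seq[ClassificationResult])    // the escaped form compiles, but the
    // field is still literally named "class", which Spark cannot bind; hence the
    // payload rewrite to "classifications" in modifyEntity above.
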
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
index 19102b3a8e..1d1e412ee4 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
@@ -4,8 +4,8 @@
package com.microsoft.azure.synapse.ml.services.language
import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging }
-import com.microsoft.azure.synapse.ml.services.{
- CognitiveServicesBaseNoHandler, HasAPIVersion, HasCognitiveServiceInput, HasInternalJsonOutputParser, HasSetLocation }
+import com.microsoft.azure.synapse.ml.services.{ CognitiveServicesBaseNoHandler, HasAPIVersion,
+ HasCognitiveServiceInput, HasInternalJsonOutputParser, HasSetLocation }
import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch }
import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply
import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer }
@@ -78,7 +78,9 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
with HandlePiiEntityRecognition
with HandleEntityLinking
with HandleEntityRecognition
- with HandleCustomEntityRecognition {
+ with HandleCustomEntityRecognition
+ with ModifiableAsyncReply
+ with HandleCustomLabelClassification {
logClass(FeatureNames.AiServices.Language)
def this() = this(Identifiable.randomUID("AnalyzeTextLongRunningOperations"))
@@ -121,7 +123,9 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
AnalysisTaskKind.PiiEntityRecognition -> PiiEntityRecognitionJobState.schema,
AnalysisTaskKind.EntityLinking -> EntityLinkingJobState.schema,
AnalysisTaskKind.EntityRecognition -> EntityRecognitionJobState.schema,
- AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema
+ AnalysisTaskKind.CustomEntityRecognition -> EntityRecognitionJobState.schema,
+ AnalysisTaskKind.CustomSingleLabelClassification -> CustomLabelJobState.schema,
+ AnalysisTaskKind.CustomMultiLabelClassification -> CustomLabelJobState.schema
)
override protected def responseDataType: DataType = {
@@ -140,7 +144,9 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
AnalysisTaskKind.PiiEntityRecognition -> createPiiEntityRecognitionRequest,
AnalysisTaskKind.EntityLinking -> createEntityLinkingRequest,
AnalysisTaskKind.EntityRecognition -> createEntityRecognitionRequest,
- AnalysisTaskKind.CustomEntityRecognition -> createCustomEntityRecognitionRequest
+ AnalysisTaskKind.CustomEntityRecognition -> createCustomEntityRecognitionRequest,
+ AnalysisTaskKind.CustomSingleLabelClassification -> createCustomMultiLabelRequest,
+ AnalysisTaskKind.CustomMultiLabelClassification -> createCustomMultiLabelRequest
)
// This method is made package private for testing purposes
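
With both maps extended, the classification kinds parse against CustomLabelJobState and build their payloads through createCustomMultiLabelRequest. A usage sketch (key, project, and deployment values are placeholders; the setter names mirror the test suite later in this patch):

    val classifier = new AnalyzeTextLongRunningOperations()
      .setSubscriptionKey(languageApiKey)   // assumed to be defined in the caller's scope
      .setLocation("eastus")
      .setLanguage("en")
      .setTextCol("text")
      .setKind(AnalysisTaskKind.CustomMultiLabelClassification)
      .setProjectName("<your-project>")
      .setDeploymentName("<your-deployment>")
      .setOutputCol("response")
      .setErrorCol("error")
    val labeled = classifier.transform(df)  // df is any DataFrame with a "text" column
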
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
index d40e759df9..9e72a9fcb2 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
@@ -13,8 +13,8 @@ import org.scalactic.{ Equality, TolerantNumerics }
import org.scalatest.funsuite.AnyFunSuiteLike
trait LanguageServiceEndpoint {
- lazy val customNERKey: String = sys.env.getOrElse("CUSTOM_NER_KEY", Secrets.CustomNERLanguageApiKey)
- lazy val customNERLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus")
+ lazy val languageApiKey: String = sys.env.getOrElse("CUSTOM_LANGUAGE_KEY", Secrets.LanguageApiKey)
+ lazy val languageApiLocation: String = sys.env.getOrElse("LANGUAGE_API_LOCATION", "eastus")
}
class AnalyzeTextLROSuite extends AnyFunSuiteLike {
@@ -191,6 +191,8 @@ class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRu
.setKind("AbstractiveSummarization")
.setOutputCol("response")
.setErrorCol("error")
+ .setPollingDelay(5 * 1000)
+ .setMaxPollingRetries(30)
val responses = model.transform(df)
.withColumn("documents", col("response.documents"))
.withColumn("modelVersion", col("response.modelVersion"))
@@ -219,6 +221,8 @@ class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRu
.setErrorCol("error")
.setShowStats(true)
.setSummaryLength(SummaryLength.Short)
+ .setPollingDelay(5 * 1000)
+ .setMaxPollingRetries(30)
val responses = model.transform(df)
.withColumn("documents", col("response.documents"))
.withColumn("modelVersion", col("response.modelVersion"))
@@ -247,6 +251,8 @@ class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRu
.setTextCol("text")
.setLanguage("en")
.setKind("AbstractiveSummarization")
+ .setPollingDelay(5 * 1000)
+ .setMaxPollingRetries(30)
.setOutputCol("response"),
df))
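
The polling parameters bound how long the transformer waits on the async job. The 5 * 1000 literal suggests the delay is in milliseconds; under that assumption the worst-case wait is easy to budget:

    val pollingDelayMs = 5 * 1000                   // 5 seconds between status polls
    val maxRetries = 30
    val worstCaseMs = pollingDelayMs * maxRetries   // 150000 ms, i.e. 2.5 minutes
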
@@ -608,8 +614,8 @@ class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRun
.toDF("text")
def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
- .setSubscriptionKey(customNERKey)
- .setLocation(customNERLocation)
+ .setSubscriptionKey(languageApiKey)
+ .setLocation(languageApiLocation)
.setLanguage("en")
.setTextCol("text")
.setKind(AnalysisTaskKind.CustomEntityRecognition)
@@ -640,3 +646,54 @@ class CustomEntityRecognitionSuite extends TransformerFuzzing[AnalyzeTextLongRun
override def reader: MLReadable[_] = AnalyzeText
}
+
+
+class MultiLabelClassificationSuite extends TransformerFuzzing[AnalyzeTextLongRunningOperations]
+ with LanguageServiceEndpoint {
+
+ import spark.implicits._
+
+ implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-3)
+
+ def df: DataFrame = {
+    // A description of the movie Finding Nemo
+    Seq("In the depths of the ocean, a father's worst nightmare comes to life. A grieving and determined father " +
+      "must overcome his fears and navigate the treacherous waters to find his missing son. The journey is " +
+      "fraught with relentless predators, dark secrets, and the haunting realization that the ocean is a vast, " +
+      "unforgiving abyss. Will a father's unwavering resolve be enough to reunite him with his son, or will " +
+      "the shadows of the deep consume them both? Dive into the darkness and discover the lengths a parent will " +
+      "go to for their child.")
+ .toDF("text")
+ }
+
+ def model: AnalyzeTextLongRunningOperations = new AnalyzeTextLongRunningOperations()
+ .setSubscriptionKey(languageApiKey)
+ .setLocation(languageApiLocation)
+ .setLanguage("en")
+ .setTextCol("text")
+ .setKind(AnalysisTaskKind.CustomMultiLabelClassification)
+ .setOutputCol("response")
+ .setErrorCol("error")
+ .setDeploymentName("multi-class-movie-dep")
+ .setProjectName("for-unit-test-muti-class")
+
+ test("Basic usage") {
+ val result = model.transform(df)
+ .withColumn("documents", col("response.documents"))
+ .withColumn("classifications", col("documents.classifications"))
+ .collect()
+ val classifications = result.head.getAs[Seq[Row]]("classifications")
+ assert(classifications.nonEmpty)
+ assert(classifications.head.getAs[String]("category").nonEmpty)
+ assert(classifications.head.getAs[Double]("confidenceScore") > 0.0)
+ }
+
+ override def testObjects(): Seq[TestObject[AnalyzeTextLongRunningOperations]] =
+ Seq(new TestObject[AnalyzeTextLongRunningOperations](model, df))
+
+ override def reader: MLReadable[_] = AnalyzeText
+}
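
A sketch for consumers of the new kind, following the same column path the "Basic usage" test reads (assumes the response.documents shape shown there):

    import org.apache.spark.sql.functions.col

    val labeled = model.transform(df)
      .withColumn("classifications", col("response.documents.classifications"))
      .select(col("text"), col("classifications"))
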
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
index 344058015b..ae45753e8b 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala
@@ -74,5 +74,5 @@ object Secrets {
lazy val Platform: String = getSecret("synapse-platform")
lazy val AadResource: String = getSecret("synapse-internal-aad-resource")
- lazy val CustomNERLanguageApiKey: String = getSecret("custom-ner-key")
+ lazy val LanguageApiKey: String = getSecret("language-api-key")
}
From 5d54d9c220906ec5fb7c7684197ae3a43883f6d7 Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Thu, 30 Jan 2025 18:42:14 -0800
Subject: [PATCH 5/8] Fixing minor problems and documentation
---
.../services/language/AnalyzeTextLongRunningOperations.scala | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
index 1d1e412ee4..cc5d0f1ed4 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLongRunningOperations.scala
@@ -55,6 +55,8 @@ object AnalyzeTextLongRunningOperations extends ComplexParamsReadable[AnalyzeTex
* EntityLinking
* EntityRecognition
* CustomEntityRecognition
+ * CustomSingleLabelClassification
+ * CustomMultiLabelClassification
*
* Each task has its own set of parameters that can be set to control the behavior of the service and response
* schema.
@@ -62,7 +64,7 @@ object AnalyzeTextLongRunningOperations extends ComplexParamsReadable[AnalyzeTex
*/
class AnalyzeTextLongRunningOperations(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasAPIVersion
- with BasicAsyncReply
+ with ModifiableAsyncReply
with HasCognitiveServiceInput
with HasInternalJsonOutputParser
with HasSetLocation
@@ -79,7 +81,6 @@ class AnalyzeTextLongRunningOperations(override val uid: String) extends Cogniti
with HandleEntityLinking
with HandleEntityRecognition
with HandleCustomEntityRecognition
- with ModifiableAsyncReply
with HandleCustomLabelClassification {
logClass(FeatureNames.AiServices.Language)
From 02b27df93acb5689dd8ad8bd0230df64af56bd75 Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Fri, 31 Jan 2025 12:08:51 -0800
Subject: [PATCH 6/8] Fixing parameter name to reflect the name of the field
---
.../ml/services/language/AnalyzeTextLROTraits.scala | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
index 1c7e6f5ee9..c0930e8a6f 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
@@ -418,18 +418,18 @@ trait HandleEntityRecognition extends HasServiceParams {
val excludeNormalizedValues = new ServiceParam[Boolean](
this,
- name = "inferenceOptions",
+ name = "excludeNormalizedValues",
doc = "(Optional) request parameter that allows the user to provide settings for"
+ " running the inference. If set to true, the service will exclude normalized"
)
- def getInferenceOptions: Boolean = getScalarParam(excludeNormalizedValues)
+ def getExcludeNormalizedValues: Boolean = getScalarParam(excludeNormalizedValues)
- def setInferenceOptions(value: Boolean): this.type = setScalarParam(excludeNormalizedValues, value)
+ def setExcludeNormalizedValues(value: Boolean): this.type = setScalarParam(excludeNormalizedValues, value)
- def getInferenceOptionsCol: String = getVectorParam(excludeNormalizedValues)
+ def getExcludeNormalizedValuesCol: String = getVectorParam(excludeNormalizedValues)
- def setInferenceOptionsCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value)
+  def setExcludeNormalizedValuesCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value)
def createEntityRecognitionRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
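
After the rename, the accessors match the field they actually control. A usage sketch (the enum-based setKind mirrors the test suites above):

    val ner = new AnalyzeTextLongRunningOperations()
      .setKind(AnalysisTaskKind.EntityRecognition)
      .setExcludeNormalizedValues(true)   // drop normalized values from the inference output
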
From db504a0306d380af83edeb0d9c3126417434c894 Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Mon, 3 Feb 2025 13:59:42 -0800
Subject: [PATCH 7/8] Fixing failing fuzzing tests.
---
.../synapse/ml/services/language/AnalyzeTextLROSuite.scala | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
index 9e72a9fcb2..03a6afa948 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROSuite.scala
@@ -253,8 +253,9 @@ class AbstractiveSummarizationSuite extends TransformerFuzzing[AnalyzeTextLongRu
.setKind("AbstractiveSummarization")
.setPollingDelay(5 * 1000)
.setMaxPollingRetries(30)
+ .setSummaryLength(SummaryLength.Short)
.setOutputCol("response"),
- df))
+ Seq("Microsoft Azure AI Data Fabric").toDF("text")))
override def reader: MLReadable[_] = AnalyzeTextLongRunningOperations
}
From ae388292d2d4ba6ed483f563437e0cf2eee26e74 Mon Sep 17 00:00:00 2001
From: Farrukh Masud
Date: Wed, 5 Feb 2025 12:38:33 -0800
Subject: [PATCH 8/8] making traits and methods package private
---
.../language/AnalyzeTextLROTraits.scala | 50 +++++++++----------
1 file changed, 25 insertions(+), 25 deletions(-)
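
For context on the visibility change: private[language] scopes a member to the named enclosing package, so the request builders remain callable from anywhere under com.microsoft.azure.synapse.ml.services.language, including the test suites, which share that package, while disappearing from the public API. A minimal sketch:

    package com.microsoft.azure.synapse.ml.services.language

    private[language] trait RequestBuilder {
      // Visible throughout ...services.language (main and test sources alike),
      // but unresolvable from sibling packages such as ...services.vision.
      private[language] def build(text: String): String = s"""{"text":"$text"}"""
    }
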
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
index c0930e8a6f..7f33107e7b 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextLROTraits.scala
@@ -41,7 +41,7 @@ object AnalysisTaskKind extends Enumeration {
}
}
-trait HasSummarizationBaseParameter extends HasServiceParams {
+private[language] trait HasSummarizationBaseParameter extends HasServiceParams {
val sentenceCount = new ServiceParam[Int](
this,
name = "sentenceCount",
@@ -72,7 +72,7 @@ trait HasSummarizationBaseParameter extends HasServiceParams {
* the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]]
*/
-trait HandleExtractiveSummarization extends HasServiceParams
+private[language] trait HandleExtractiveSummarization extends HasServiceParams
with HasSummarizationBaseParameter {
val sortBy = new ServiceParam[String](
this,
@@ -91,7 +91,7 @@ trait HandleExtractiveSummarization extends HasServiceParams
def setSortByCol(value: String): this.type = setVectorParam(sortBy, value)
- def createExtractiveSummarizationRequest(row: Row,
+ private[language] def createExtractiveSummarizationRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -125,7 +125,7 @@ trait HandleExtractiveSummarization extends HasServiceParams
* the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/summarization/overview]]
*/
-trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizationBaseParameter {
+private[language] trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizationBaseParameter {
val summaryLength = new ServiceParam[String](
this,
name = "summaryLength",
@@ -148,7 +148,7 @@ trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizat
def setSummaryLengthCol(value: String): this.type = setVectorParam(summaryLength, value)
- def createAbstractiveSummarizationRequest(row: Row,
+ private[language] def createAbstractiveSummarizationRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -176,8 +176,8 @@ trait HandleAbstractiveSummarization extends HasServiceParams with HasSummarizat
* the parameters, please refer to the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/overview]]
*/
-trait HandleHealthcareTextAnalystics extends HasServiceParams {
- def createHealthcareTextAnalyticsRequest(row: Row,
+private[language] trait HandleHealthcareTextAnalystics extends HasServiceParams {
+ private[language] def createHealthcareTextAnalyticsRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -204,7 +204,7 @@ trait HandleHealthcareTextAnalystics extends HasServiceParams {
* to the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/sentiment-opinion-mining/overview]]
*/
-trait HandleSentimentAnalysis extends HasServiceParams {
+private[language] trait HandleSentimentAnalysis extends HasServiceParams {
val opinionMining = new ServiceParam[Boolean](
this,
name = "opinionMining",
@@ -223,7 +223,7 @@ trait HandleSentimentAnalysis extends HasServiceParams {
opinionMining -> Left(false)
)
- def createSentimentAnalysisRequest(row: Row,
+ private[language] def createSentimentAnalysisRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -251,8 +251,8 @@ trait HandleSentimentAnalysis extends HasServiceParams {
* please refer to the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/key-phrase-extraction/overview]]
*/
-trait HandleKeyPhraseExtraction extends HasServiceParams {
- def createKeyPhraseExtractionRequest(row: Row,
+private[language] trait HandleKeyPhraseExtraction extends HasServiceParams {
+ private[language] def createKeyPhraseExtractionRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
// This parameter is not used and only exists for compatibility
@@ -272,8 +272,8 @@ trait HandleKeyPhraseExtraction extends HasServiceParams {
}
}
-trait HandleEntityLinking extends HasServiceParams {
- def createEntityLinkingRequest(row: Row,
+private[language] trait HandleEntityLinking extends HasServiceParams {
+ private[language] def createEntityLinkingRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -300,7 +300,7 @@ trait HandleEntityLinking extends HasServiceParams {
* the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/personally-identifiable-information/overview]]
*/
-trait HandlePiiEntityRecognition extends HasServiceParams {
+private[language] trait HandlePiiEntityRecognition extends HasServiceParams {
val domain = new ServiceParam[String](
this,
name = "domain",
@@ -336,7 +336,7 @@ trait HandlePiiEntityRecognition extends HasServiceParams {
domain -> Left("none")
)
- def createPiiEntityRecognitionRequest(row: Row,
+ private[language] def createPiiEntityRecognitionRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
@@ -365,7 +365,7 @@ trait HandlePiiEntityRecognition extends HasServiceParams {
* about the parameters, please refer to the documentation.
* [[https://learn.microsoft.com/en-us/azure/ai-services/language-service/named-entity-recognition/overview]]
*/
-trait HandleEntityRecognition extends HasServiceParams {
+private[language] trait HandleEntityRecognition extends HasServiceParams {
val inclusionList = new ServiceParam[Seq[String]](
this,
name = "inclusionList",
@@ -431,18 +431,18 @@ trait HandleEntityRecognition extends HasServiceParams {
def setExcludeNormalizedValuesCol(value: String): this.type = setVectorParam(excludeNormalizedValues, value)
- def createEntityRecognitionRequest(row: Row,
+ private[language] def createEntityRecognitionRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
modelVersion: String,
stringIndexType: String,
loggingOptOut: Boolean): String = {
val serviceOverlapPolicy: Option[EntityOverlapPolicy] = getValueOpt(row, overlapPolicy) match {
- case Some(policy) => Some(new EntityOverlapPolicy(policy))
+ case Some(policy) => Some(EntityOverlapPolicy(policy))
case None => None
}
val inferenceOptions: Option[EntityInferenceOptions] = getValueOpt(row, excludeNormalizedValues) match {
- case Some(value) => Some(new EntityInferenceOptions(value))
+ case Some(value) => Some(EntityInferenceOptions(value))
case None => None
}
val taskParameter = EntityRecognitionLROTask(
@@ -464,7 +464,7 @@ trait HandleEntityRecognition extends HasServiceParams {
}
}
-trait HasCustomLanguageModelParam extends HasServiceParams {
+private[language] trait HasCustomLanguageModelParam extends HasServiceParams {
val projectName = new ServiceParam[String](
this,
name = "projectName",
@@ -494,10 +494,10 @@ trait HasCustomLanguageModelParam extends HasServiceParams {
def setDeploymentNameCol(value: String): this.type = setVectorParam(deploymentName, value)
}
-trait HandleCustomEntityRecognition extends HasServiceParams
+private[language] trait HandleCustomEntityRecognition extends HasServiceParams
with HasCustomLanguageModelParam {
- def createCustomEntityRecognitionRequest(row: Row,
+ private[language] def createCustomEntityRecognitionRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
// This parameter is not used and only exists for compatibility
modelVersion: String,
@@ -544,7 +544,7 @@ trait ModifiableAsyncReply extends BasicAsyncReply {
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData] = {
val originalResponse = super.queryForResult(key, client, location)
- logDebug(s"Original response: ${ originalResponse }")
+ logDebug(s"Original response: $originalResponse")
modifyResponse(originalResponse)
}
}
@@ -565,7 +565,7 @@ trait ModifiableAsyncReply extends BasicAsyncReply {
* @note This trait is designed to be used with the `ModifiableAsyncReply` and `SynapseMLLogging` traits for
* consistent response handling and logging.
*/
-trait HandleCustomLabelClassification extends HasServiceParams
+private[language] trait HandleCustomLabelClassification extends HasServiceParams
with HasCustomLanguageModelParam {
self: ModifiableAsyncReply
with SynapseMLLogging =>
@@ -615,7 +615,7 @@ trait HandleCustomLabelClassification extends HasServiceParams
def getKind: String
- def createCustomMultiLabelRequest(row: Row,
+ private[language] def createCustomMultiLabelRequest(row: Row,
analysisInput: MultiLanguageAnalysisInput,
// This parameter is not used and only exists for compatibility
modelVersion: String,