From b10340d22d9971798f0d99e52cad6d20d52a4e51 Mon Sep 17 00:00:00 2001 From: CrowleyRajapakse Date: Wed, 19 Feb 2025 15:41:48 +0530 Subject: [PATCH] fixing issues related to model based round robin --- .../envoyconf/routes_with_clusters.go | 18 +++++++-------- .../oasparser/model/adapter_internal_api.go | 8 +++---- .../operator/synchronizer/rest_api.go | 4 ++-- gateway/enforcer/internal/extproc/ext_proc.go | 22 +++++++++++++++--- .../ballerina/APIClient.bal | 23 +++++++++++++------ .../ballerina/Dependencies.toml | 3 ++- 6 files changed, 52 insertions(+), 26 deletions(-) diff --git a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go index b0bf8b6d29..69a7c990c6 100644 --- a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go +++ b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go @@ -213,7 +213,7 @@ func CreateRoutesWithClusters(adapterInternalAPI *model.AdapterInternalAPI, inte existingClusterName := getExistingClusterName(*endpoint, processedEndpoints) if existingClusterName == "" { - clusterName = getClusterName(endpoint.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, endpoint.Endpoints[0].Host) + clusterName = getClusterName(endpoint.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, endpoint.Endpoints[0].Host+endpoint.Endpoints[0].Basepath) cluster, address, err := processEndpoints(clusterName, endpoint, timeout, basePath, nil) if err != nil { logger.LoggerOasparser.ErrorC(logging.PrintError(logging.Error2239, logging.MAJOR, "Error while adding resource level endpoints for %s:%v-%v. %v", apiTitle, apiVersion, resourcePath, err.Error())) @@ -269,7 +269,7 @@ func CreateRoutesWithClusters(adapterInternalAPI *model.AdapterInternalAPI, inte //existingClusterName := getExistingClusterName(*endpoint, processedEndpoints) - clusterName = getClusterName(endpoint.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, ep.Host) + clusterName = getClusterName(endpoint.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, ep.Host+ep.Basepath) cluster, address, err := processEndpoints(clusterName, endpoint, timeout, basePath, &ep) if err != nil { logger.LoggerOasparser.ErrorC(logging.PrintError(logging.Error2239, logging.MAJOR, "Error while adding resource level endpoints for %s:%v-%v. %v", apiTitle, apiVersion, resourcePath, err.Error())) @@ -288,7 +288,7 @@ func CreateRoutesWithClusters(adapterInternalAPI *model.AdapterInternalAPI, inte existingMirrorClusterName := getExistingClusterName(*mirrorEndpointCluster, processedEndpoints) var mirrorClusterName string if existingMirrorClusterName == "" { - mirrorClusterName = getClusterName(mirrorEndpointCluster.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, mirrorEndpoint.Host) + mirrorClusterName = getClusterName(mirrorEndpointCluster.EndpointPrefix, organizationID, vHost, adapterInternalAPI.GetTitle(), apiVersion, mirrorEndpoint.Host+mirrorEndpoint.Basepath) mirrorCluster, mirrorAddress, err := processEndpoints(mirrorClusterName, mirrorEndpointCluster, timeout, mirrorBasepath, &mirrorEndpoint) if err != nil { logger.LoggerOasparser.ErrorC(logging.PrintError(logging.Error2239, logging.MAJOR, "Error while adding resource level mirror filter endpoints for %s:%v-%v. %v", apiTitle, apiVersion, resourcePath, err.Error())) @@ -1993,7 +1993,7 @@ func createInterceptorAPIClusters(adapterInternalAPI *model.AdapterInternalAPI, if apiRequestInterceptor.Enable { logger.LoggerOasparser.Debugf("API level request interceptors found for %v : %v", apiTitle, apiVersion) apiRequestInterceptor.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost, - apiTitle, apiVersion, apiRequestInterceptor.EndpointCluster.Endpoints[0].Host) + apiTitle, apiVersion, apiRequestInterceptor.EndpointCluster.Endpoints[0].Host+apiRequestInterceptor.EndpointCluster.Endpoints[0].Basepath) cluster, addresses, err := CreateLuaCluster(interceptorCerts, apiRequestInterceptor) if err != nil { apiRequestInterceptor = model.InterceptEndpoint{} @@ -2008,7 +2008,7 @@ func createInterceptorAPIClusters(adapterInternalAPI *model.AdapterInternalAPI, if apiResponseInterceptor.Enable { logger.LoggerOasparser.Debugln("API level response interceptors found for " + apiTitle) apiResponseInterceptor.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID, vHost, - apiTitle, apiVersion, apiResponseInterceptor.EndpointCluster.Endpoints[0].Host) + apiTitle, apiVersion, apiResponseInterceptor.EndpointCluster.Endpoints[0].Host+apiResponseInterceptor.EndpointCluster.Endpoints[0].Basepath) cluster, addresses, err := CreateLuaCluster(interceptorCerts, apiResponseInterceptor) if err != nil { apiResponseInterceptor = model.InterceptEndpoint{} @@ -2036,7 +2036,7 @@ func createInterceptorResourceClusters(adapterInternalAPI *model.AdapterInternal if reqInterceptorVal.Enable { logger.LoggerOasparser.Debugf("Resource level request interceptors found for %v:%v-%v", apiTitle, apiVersion, resource.GetPath()) reqInterceptorVal.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost, - apiTitle, apiVersion, reqInterceptorVal.EndpointCluster.Endpoints[0].Host) + apiTitle, apiVersion, reqInterceptorVal.EndpointCluster.Endpoints[0].Host+reqInterceptorVal.EndpointCluster.Endpoints[0].Host) cluster, addresses, err := CreateLuaCluster(interceptorCerts, reqInterceptorVal) if err != nil { logger.LoggerOasparser.ErrorC(logging.PrintError(logging.Error2244, logging.MAJOR, "Error while adding resource level request intercept external cluster for %s. %v", apiTitle, err.Error())) @@ -2054,7 +2054,7 @@ func createInterceptorResourceClusters(adapterInternalAPI *model.AdapterInternal logger.LoggerOasparser.Debugf("Operation level request interceptors found for %v:%v-%v-%v", apiTitle, apiVersion, resource.GetPath(), opI.ClusterName) opID := opI.ClusterName - opI.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.EndpointCluster.Endpoints[0].Host) + opI.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.EndpointCluster.Endpoints[0].Host+opI.EndpointCluster.Endpoints[0].Host) operationalReqInterceptors[method] = opI // since cluster name is updated cluster, addresses, err := CreateLuaCluster(interceptorCerts, opI) if err != nil { @@ -2073,7 +2073,7 @@ func createInterceptorResourceClusters(adapterInternalAPI *model.AdapterInternal if respInterceptorVal.Enable { logger.LoggerOasparser.Debugf("Resource level response interceptors found for %v:%v-%v"+apiTitle, apiVersion, resource.GetPath()) respInterceptorVal.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID, - vHost, apiTitle, apiVersion, respInterceptorVal.EndpointCluster.Endpoints[0].Host) + vHost, apiTitle, apiVersion, respInterceptorVal.EndpointCluster.Endpoints[0].Host+respInterceptorVal.EndpointCluster.Endpoints[0].Basepath) cluster, addresses, err := CreateLuaCluster(interceptorCerts, respInterceptorVal) if err != nil { logger.LoggerOasparser.ErrorC(logging.PrintError(logging.Error2246, logging.MAJOR, "Error while adding resource level response intercept external cluster for %s. %v", apiTitle, err.Error())) @@ -2092,7 +2092,7 @@ func createInterceptorResourceClusters(adapterInternalAPI *model.AdapterInternal logger.LoggerOasparser.Debugf("Operational level response interceptors found for %v:%v-%v-%v", apiTitle, apiVersion, resource.GetPath(), opI.ClusterName) opID := opI.ClusterName - opI.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.EndpointCluster.Endpoints[0].Host) + opI.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.EndpointCluster.Endpoints[0].Host+opI.EndpointCluster.Endpoints[0].Basepath) operationalRespInterceptorVal[method] = opI // since cluster name is updated cluster, addresses, err := CreateLuaCluster(interceptorCerts, opI) if err != nil { diff --git a/adapter/internal/oasparser/model/adapter_internal_api.go b/adapter/internal/oasparser/model/adapter_internal_api.go index 0c623f5690..307c0d49b5 100644 --- a/adapter/internal/oasparser/model/adapter_internal_api.go +++ b/adapter/internal/oasparser/model/adapter_internal_api.go @@ -1117,7 +1117,7 @@ func extractModelBasedRoundRobinFromPolicy(apiPolicy *dpv1alpha4.APIPolicy, back } endpoints := GetEndpoints(backendNamespacedName, backendMapping) - clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) resolvedModelWeight := InternalModelWeight{ Model: model.Model, @@ -1148,7 +1148,7 @@ func extractModelBasedRoundRobinFromPolicy(apiPolicy *dpv1alpha4.APIPolicy, back } endpoints := GetEndpoints(backendNamespacedName, backendMapping) - clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) resolvedModelWeight := InternalModelWeight{ Model: model.Model, @@ -1188,7 +1188,7 @@ func extractModelBasedRoundRobinFromPolicy(apiPolicy *dpv1alpha4.APIPolicy, back } endpoints := GetEndpoints(backendNamespacedName, backendMapping) - clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) resolvedModelWeight := InternalModelWeight{ Model: model.Model, @@ -1219,7 +1219,7 @@ func extractModelBasedRoundRobinFromPolicy(apiPolicy *dpv1alpha4.APIPolicy, back } endpoints := GetEndpoints(backendNamespacedName, backendMapping) - clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName("", adapterInternalAPI.GetOrganizationID(), vHost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) resolvedModelWeight := InternalModelWeight{ Model: model.Model, diff --git a/adapter/internal/operator/synchronizer/rest_api.go b/adapter/internal/operator/synchronizer/rest_api.go index af6e6a5dbb..1e07442827 100644 --- a/adapter/internal/operator/synchronizer/rest_api.go +++ b/adapter/internal/operator/synchronizer/rest_api.go @@ -145,7 +145,7 @@ func generateAdapterInternalAPI(apiState APIState, httpRouteState *HTTPRouteStat for _, hostName := range httpRouteState.HTTPRouteCombined.Spec.Hostnames { vhost = string(hostName) } - clusternName := getClusterName(endpointCluster.EndpointPrefix, adapterInternalAPI.GetOrganizationID(), vhost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName(endpointCluster.EndpointPrefix, adapterInternalAPI.GetOrganizationID(), vhost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) productionModels = append(productionModels, model.InternalModelWeight{ Model: aiModel.Model, EndpointClusterName: clusternName, @@ -161,7 +161,7 @@ func generateAdapterInternalAPI(apiState APIState, httpRouteState *HTTPRouteStat for _, hostName := range httpRouteState.HTTPRouteCombined.Spec.Hostnames { vhost = string(hostName) } - clusternName := getClusterName(endpointCluster.EndpointPrefix, adapterInternalAPI.GetOrganizationID(), vhost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host) + clusternName := getClusterName(endpointCluster.EndpointPrefix, adapterInternalAPI.GetOrganizationID(), vhost, adapterInternalAPI.GetTitle(), adapterInternalAPI.GetVersion(), endpoints[0].Host+endpoints[0].Basepath) sandboxModels = append(sandboxModels, model.InternalModelWeight{ Model: aiModel.Model, EndpointClusterName: clusternName, diff --git a/gateway/enforcer/internal/extproc/ext_proc.go b/gateway/enforcer/internal/extproc/ext_proc.go index ef0145059a..bac556e81c 100644 --- a/gateway/enforcer/internal/extproc/ext_proc.go +++ b/gateway/enforcer/internal/extproc/ext_proc.go @@ -735,9 +735,10 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro s.log.Sugar().Debug(fmt.Sprintf("Header Values: %v", headerValues)) remainingTokenCount := 100 remainingRequestCount := 100 + remainingCount := 100 status := 200 for _, headerValue := range headerValues { - if headerValue.Key == "x-ratelimit-remaining-tokens" || headerValue.Key == "x-ratelimit-remaining" { + if headerValue.Key == "x-ratelimit-remaining-tokens" { value, err := util.ConvertStringToInt(string(headerValue.RawValue)) if err != nil { s.log.Error(err, "Unable to retrieve remaining token count by header") @@ -757,8 +758,15 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro s.log.Error(err, "Unable to retrieve status code by header") } } + if headerValue.Key == "x-ratelimit-remaining" { + value, err := util.ConvertStringToInt(string(headerValue.RawValue)) + if err != nil { + s.log.Error(err, "Unable to retrieve remaining count by header") + } + remainingCount = value + } } - if remainingTokenCount <= 0 || remainingRequestCount <= 0 || status == 429 { // Suspend model if token/request count reaches 0 or status code is 429 + if remainingCount <= 0 || remainingTokenCount <= 0 || remainingRequestCount <= 0 || status == 429 { // Suspend model if token/request count reaches 0 or status code is 429 s.log.Sugar().Debug("Token/request are exhausted. Suspending the model") matchedResource.RouteMetadataAttributes.SuspendAIModel = "true" matchedAPI.ResourceMap[metadata.MatchedResourceIdentifier] = matchedResource @@ -775,6 +783,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro s.log.Sugar().Debug(fmt.Sprintf("Header Values: %v", headerValues)) remainingTokenCount := 100 remainingRequestCount := 100 + remainingCount := 100 status := 200 for _, headerValue := range headerValues { if headerValue.Key == "x-ratelimit-remaining-tokens" { @@ -797,8 +806,15 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro s.log.Error(err, "Unable to retrieve status code by header") } } + if headerValue.Key == "x-ratelimit-remaining" { + value, err := util.ConvertStringToInt(string(headerValue.RawValue)) + if err != nil { + s.log.Error(err, "Unable to retrieve remaining count by header") + } + remainingCount = value + } } - if remainingTokenCount <= 0 || remainingRequestCount <= 0 || status == 429 { // Suspend model if token/request count reaches 0 or status code is 429 + if remainingCount <= 0 || remainingTokenCount <= 0 || remainingRequestCount <= 0 || status == 429 { // Suspend model if token/request count reaches 0 or status code is 429 s.log.Sugar().Debug("Token/request are exhausted. Suspending the model") matchedResource.RouteMetadataAttributes.SuspendAIModel = "true" matchedAPI.ResourceMap[metadata.MatchedResourceIdentifier] = matchedResource diff --git a/runtime/config-deployer-service/ballerina/APIClient.bal b/runtime/config-deployer-service/ballerina/APIClient.bal index ef005703ce..13dfd2232b 100644 --- a/runtime/config-deployer-service/ballerina/APIClient.bal +++ b/runtime/config-deployer-service/ballerina/APIClient.bal @@ -1068,11 +1068,12 @@ public class APIClient { log:printError("Mirror filter cannot be appended as a response policy."); } string host = self.getHost(url); + string path = self.getPath(url); int|error port = self.getPort(url); if port is int { model:Backend backendService = { metadata: { - name: self.getBackendServiceUid(apkConf, apiOperation, "", host, organization), + name: self.getBackendServiceUid(apkConf, apiOperation, "", host, path, organization), labels: self.getLabels(apkConf, organization) }, spec: { @@ -1209,7 +1210,13 @@ public class APIClient { return generatedPath; } - isolated function getPath(string url) returns string { + isolated function getPath(string|K8sService endpoint) returns string { + string url; + if endpoint is string { + url = endpoint; + } else { + url = self.constructURlFromK8sService(endpoint); + } string host = ""; if url.startsWith("https://") { host = url.substring(8, url.length()); @@ -1464,7 +1471,7 @@ public class APIClient { EndpointSecurity? endpointSecurity = endpointConfig?.endpointSecurity; model:Backend backendService = { metadata: { - name: self.getBackendServiceUid(apkConf, apiOperation, endpointType, self.getHost(endpointConfig.endpoint), organization), + name: self.getBackendServiceUid(apkConf, apiOperation, endpointType, self.getHost(endpointConfig.endpoint), self.getPath(endpointConfig.endpoint), organization), labels: self.getLabels(apkConf, organization) }, spec: { @@ -1650,10 +1657,11 @@ public class APIClient { string url = model.endpoint; string host = self.getHost(url); int|error port = self.getPort(url); + string path = self.getPath(url); if port is int { model:Backend backendService = { metadata: { - name: self.getBackendServiceUid(apkConf, operations, PRODUCTION_TYPE, host, organization), + name: self.getBackendServiceUid(apkConf, operations, PRODUCTION_TYPE, host, path, organization), labels: self.getLabels(apkConf, organization) }, spec: { @@ -1686,10 +1694,11 @@ public class APIClient { string url = model.endpoint; string host = self.getHost(url); int|error port = self.getPort(url); + string path = self.getPath(url); if port is int { model:Backend backendService = { metadata: { - name: self.getBackendServiceUid(apkConf, operations, SANDBOX_TYPE, host, organization), + name: self.getBackendServiceUid(apkConf, operations, SANDBOX_TYPE, host, path, organization), labels: self.getLabels(apkConf, organization) }, spec: { @@ -2005,12 +2014,12 @@ public class APIClient { } } - public isolated function getBackendServiceUid(APKConf apkConf, APKOperations? apiOperation, string endpointType, string endpointHost, commons:Organization organization) returns string { + public isolated function getBackendServiceUid(APKConf apkConf, APKOperations? apiOperation, string endpointType, string endpointHost, string endpointPath, commons:Organization organization) returns string { string concatanatedString = uuid:createType1AsString(); if (apiOperation is APKOperations && apiOperation.endpointConfigurations is EndpointConfigurations) { return "backend-" + concatanatedString + "-resource"; } else { - concatanatedString = string:'join("-", organization.name, apkConf.name, 'apkConf.'version, endpointType, endpointHost); + concatanatedString = string:'join("-", organization.name, apkConf.name, 'apkConf.'version, endpointType, endpointHost, endpointPath); byte[] hashedValue = crypto:hashSha1(concatanatedString.toBytes()); concatanatedString = hashedValue.toBase16(); return "backend-" + concatanatedString + "-api"; diff --git a/runtime/config-deployer-service/ballerina/Dependencies.toml b/runtime/config-deployer-service/ballerina/Dependencies.toml index d4c8d76b58..ed75f94714 100644 --- a/runtime/config-deployer-service/ballerina/Dependencies.toml +++ b/runtime/config-deployer-service/ballerina/Dependencies.toml @@ -70,7 +70,7 @@ modules = [ [[package]] org = "ballerina" name = "http" -version = "2.12.4" +version = "2.12.7" dependencies = [ {org = "ballerina", name = "auth"}, {org = "ballerina", name = "cache"}, @@ -466,3 +466,4 @@ modules = [ {org = "wso2", packageName = "config_deployer_service", moduleName = "config_deployer_service.org.wso2.apk.config.model"}, {org = "wso2", packageName = "config_deployer_service", moduleName = "config_deployer_service.partitionClient"} ] +