Skip to content

Commit

Permalink
remove unneeded requests, keep target state
Browse files Browse the repository at this point in the history
  • Loading branch information
keyn4 committed May 24, 2024
1 parent fd57fd4 commit 3316167
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 103 deletions.
8 changes: 6 additions & 2 deletions target_salesforce_v3/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def check_salesforce_limits(self, response):
percent_used_from_total = (remaining / allotted) * 100

if percent_used_from_total > quota_percent_total:
self._target.hit_rate_limit = True
total_message = (
"Salesforce has reported {}/{} ({:3.2f}%) total REST quota "
"used across all Salesforce Applications. Terminating "
Expand Down Expand Up @@ -233,8 +234,11 @@ def sf_fields(self, object_type=None):
sobject = self.request_api("GET", f"sobjects/{object_type}/describe/")
return [f for f in sobject.json()["fields"]]

def sf_fields_description(self, object_type=None):
fld = self.sf_fields(object_type=object_type)
def sf_fields_description(self, object_type=None, object_fields=None):
if not object_fields:
fld = self.sf_fields(object_type=object_type)
fld = object_fields

fields = {}
fields["createable"] = [
f["name"] for f in fld if f["createable"] and not f["custom"]
Expand Down
198 changes: 97 additions & 101 deletions target_salesforce_v3/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dateutil.parser import parse
from datetime import datetime
from singer_sdk.exceptions import FatalAPIError, RetriableAPIError
from target_salesforce_v3.client import TargetSalesforceQuotaExceededException


class MissingObjectInSalesforceError(Exception):
Expand Down Expand Up @@ -710,136 +711,131 @@ class FallbackSink(SalesforceV3Sink):
def name(self):
return self.stream_name

def get_fields_for_object(self, object_type, objects_list=None):
    """Return the describe metadata of a Salesforce object, keyed by field name.

    Args:
        object_type: Value compared against each sobject's ``name``,
            ``label`` and ``labelPlural``.
        objects_list: Already-fetched sobject entries to search. When
            omitted, the list is requested from the ``sobjects`` endpoint,
            so callers written against the one-argument signature keep
            working.

    Returns:
        Dict mapping each field's ``name`` to its entry in the describe
        response of the matching sobject.

    Raises:
        MissingObjectInSalesforceError: No sobject matches ``object_type``.
    """
    if objects_list is None:
        # Fallback for callers that do not hand in a cached listing.
        listing = self.request_api("GET", "sobjects")
        objects_list = listing.json().get("sobjects", [])

    for sobject in objects_list:
        if object_type in (sobject["name"], sobject["label"], sobject["labelPlural"]):
            # Describe is always addressed by API name, even when the match
            # was made on a label.
            described = self.request_api(
                "GET", endpoint=f"sobjects/{sobject['name']}/describe"
            ).json()
            return {field["name"]: field for field in described.get("fields", [])}

    raise MissingObjectInSalesforceError(f"Object type {object_type} not found in Salesforce.")

def preprocess_record(self, record, context):
    """Resolve the stream's Salesforce object and tag the record for upsert.

    The sobject listing, the describe fields of the current sink and the
    rate-limit flag are kept on ``self._target`` so that the listing and
    describe requests are not repeated for every record.

    Args:
        record: Incoming record; it is mutated in place.
        context: Unused here.

    Returns:
        * the record, with ``"object_type"`` added and ``"Id"`` added when an
          email search finds an existing instance;
        * ``{}`` when the stream matches no Salesforce object;
        * ``{"error": message}`` when the rate limit was already hit or a
          quota exception is raised while processing.
    """
    # Do not process records once the target has hit the API rate limit.
    if self._target.hit_rate_limit:
        return {"error": "Unprocessed record due to requests exceeded API rate limits"}

    try:
        # Get the list of objects; fetched once, then reused from the target.
        if not self._target.sobjects:
            req = self.request_api("GET", "sobjects")
            self._target.sobjects = req.json().get("sobjects", [])
        objects_list = self._target.sobjects

        # Find the sobject for this stream by API name, label or plural label.
        object_type = None
        for sobject in objects_list:
            if self.stream_name in (sobject["name"], sobject["label"], sobject["labelPlural"]):
                self.logger.info(f"Processing record for type {self.stream_name}. Using fallback sink.")
                object_type = sobject["name"]
                break

        if not object_type:
            self.logger.info(f"Record doesn't exist on Salesforce {self.stream_name} was not found on Salesforce.")
            return {}

        # Get the record fields. The sink name must be compared BEFORE it is
        # overwritten, otherwise a change of sink is never detected and the
        # previous stream's fields are reused.
        try:
            if not self._target.current_fields or self.name != self._target.current_sink_name:
                self._target.current_fields = self.get_fields_for_object(object_type, objects_list)
                # Recorded only after a successful fetch so that a failure is
                # retried on the next record.
                self._target.current_sink_name = self.name
        except MissingObjectInSalesforceError:
            self.logger.info("Skipping record, because it was not found on Salesforce.")
            return {}
        record["object_type"] = object_type

        # Try to find an existing object instance using any email on the record.
        email_fields = ["Email", "npe01__AlternateEmail__c", "npe01__HomeEmail__c", "npe01__Preferred_Email__c", "npe01__WorkEmail__c"]
        email_values = [record.get(email_field) for email_field in email_fields if record.get(email_field)]

        # Backslash-escape the SOSL reserved characters in one pass (so the
        # backslashes that are inserted are not escaped again). The email is
        # record data and is embedded inside FIND {...}.
        reserved = '?&|!{}[]()^~*:\\"\'+-'
        escape_table = str.maketrans({char: "\\" + char for char in reserved})

        for email_to_check in email_values:
            search_term = email_to_check.translate(escape_table)
            query = "".join(["FIND {", search_term, "} ", f" IN ALL FIELDS RETURNING {object_type}(id)"])
            req = self.request_api("GET", "search/", params={"q": query})

            matches = req.json().get("searchRecords")
            if matches:
                record["Id"] = matches[0]["Id"]
                break

        return record
    except TargetSalesforceQuotaExceededException as e:
        return {"error": str(e)}


def upsert_record(self, record, context):
    """Update or create the Salesforce object described by a preprocessed record.

    Attempts, in order: PATCH by ``Id``/``id``; PATCH by each external-id
    field present in the record; POST to create. A failed PATCH is logged
    and the next strategy is tried.

    Args:
        record: Output of ``preprocess_record``; it is mutated in place
            (``object_type`` and the id keys are removed from the payload).
        context: Unused here.

    Returns:
        ``(id, success, state_updates)``. ``(None, False, {})`` when there is
        nothing to send, and ``(None, False, record)`` when the record is an
        error marker produced by preprocessing.

    Raises:
        Exception: Whatever the final create request raises, after logging.
    """
    # Nothing to send: preprocessing returned {} or there is no record at all.
    if not record:
        return None, False, {}

    # Do not send error markers. preprocess_record returns {"error": ...}
    # without an "object_type" when a quota exception was raised, so such a
    # record is skipped even if the rate-limit flag was not set.
    if record.get("error") and (self._target.hit_rate_limit or "object_type" not in record):
        return None, False, record

    state_updates = dict()

    # Build the endpoint for this object type.
    object_type = record.pop("object_type", None)
    endpoint = f"sobjects/{object_type}"
    self.logger.info(f"Processing record for type {self.stream_name}. Using fallback sink.")

    if not record:
        self.logger.info(f"Processing record for type {self.stream_name} failed. Check logs.")
        return None, False, {}

    # Check the payload against the describe fields cached on the target, so
    # no extra describe request is needed here.
    fields = self._target.current_fields
    fields_desc = self.sf_fields_description(object_type, list(fields.values()))

    possible_update_fields = [field for field in fields_desc["external_ids"] if field in record]

    for field in record:
        if field not in fields:
            self.logger.info(f"Field {field} doesn't exist on Salesforce.")

    missing_fields = list(set(fields) - set(record))
    if len(missing_fields) > 0.5 * len(fields):
        self.logger.info(f"This record may require more fields to be mapped. Missing fields: {missing_fields}")

    # 1) Update by record id. Both spellings are removed from the payload;
    # pop() with a default avoids a KeyError when only one of them is present.
    id_upper = record.pop("Id", None)
    id_lower = record.pop("id", None)
    object_id = id_upper or id_lower
    if object_id:
        url = "/".join([endpoint, object_id])
        try:
            response = self.request_api("PATCH", endpoint=url, request_data=record)
            if response.status_code == 204:
                # 204 carries no body, so report the id that was addressed.
                self.logger.info(f"{object_type} updated with id: {object_id}")
                return object_id, True, state_updates
            updated_id = response.json().get("id")
            self.logger.info(f"{object_type} updated with id: {updated_id}")
            return updated_id, True, state_updates
        except Exception as e:
            # Best effort: fall through to the next strategy.
            self.logger.exception(f"Could not PATCH to {url}: {e}")

    # 2) Update by external id.
    for id_field in possible_update_fields:
        id_value = record.get(id_field)
        if id_value is None:
            continue
        # Built outside the try so the handler below can always format it.
        url = "/".join([endpoint, id_field, str(id_value)])
        payload = {key: value for key, value in record.items() if key != id_field}
        try:
            response = self.request_api("PATCH", endpoint=url, request_data=payload)
            updated_id = response.json().get("id")
            self.logger.info(f"{object_type} updated with id: {updated_id}")
            return updated_id, True, state_updates
        except Exception as e:
            self.logger.exception(f"Could not PATCH to {url}: {e}")

    # 3) Create.
    if possible_update_fields:
        self.logger.info("Failed to find updatable entity, trying to create it.")

    try:
        response = self.request_api("POST", endpoint=endpoint, request_data=record)
        created_id = response.json().get("id")
        self.logger.info(f"{object_type} created with id: {created_id}")
        return created_id, True, state_updates
    except Exception:
        self.logger.exception(f"Error encountered while creating {object_type}")
        raise
6 changes: 6 additions & 0 deletions target_salesforce_v3/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from singer_sdk import typing as th
from target_hotglue.target import TargetHotglue
import copy

from target_salesforce_v3.sinks import (
FallbackSink,
Expand Down Expand Up @@ -34,6 +35,11 @@ class TargetSalesforceV3(TargetHotglue):
name = "target-salesforce-v3"
MAX_PARALLELISM = 10
SINK_TYPES = SINK_TYPES
sobjects = {}
current_sink_name = None
current_fields = {}
hit_rate_limit = False

def get_sink_class(self, stream_name: str):
"""Get sink for a stream."""
for sink_class in SINK_TYPES:
Expand Down

0 comments on commit 3316167

Please sign in to comment.