diff --git a/plsc_flat.py b/plsc_flat.py index b9d7fc1..e139520 100755 --- a/plsc_flat.py +++ b/plsc_flat.py @@ -10,6 +10,7 @@ import util from sldap import SLdap +from util import escape_dn_chars # import ipdb # ipdb.set_trace() @@ -39,17 +40,17 @@ def create(src, dst): for co_id in cos: logging.debug(f"- co: {co_id}") - src_dn = src.rfind(f"o={co_id},dc=ordered,dc={service}", '(ObjectClass=organization)') - src_co = src_dn.get(f"o={co_id},dc=ordered,dc={service},{src.basedn}", {}) + src_dn = src.rfind(f"o={co_id},dc=ordered,dc={escape_dn_chars(service)}", '(ObjectClass=organization)') + src_co = src_dn.get(f"o={co_id},dc=ordered,dc={escape_dn_chars(service)},{src.basedn}", {}) src_mail = src_co.get('mail', []) logging.debug(f"src_mail: {src_mail}") - co_dn = f"dc=flat,dc={service},{dst.basedn}" + co_dn = f"dc=flat,dc={escape_dn_chars(service)},{dst.basedn}" # Create flat dn if it doesn't exist - flat_dns = dst.rfind(f"dc={service}", "(&(objectClass=dcObject)(dc=flat))") + flat_dns = dst.rfind(f"dc={escape_dn_chars(service)}", "(&(objectClass=dcObject)(dc=flat))") if len(flat_dns) == 0: - flat_dn = f"dc=flat,dc={service},{dst.basedn}" + flat_dn = f"dc=flat,dc={escape_dn_chars(service)},{dst.basedn}" flat_entry = {'objectClass': ['dcObject', 'organizationalUnit'], 'dc': ['flat'], 'ou': ['flat']} dst.add(flat_dn, flat_entry) for ou in ['Groups', 'People']: @@ -58,7 +59,7 @@ def create(src, dst): dst.add(ou_dn, ou_entry) logging.debug(" - People") - src_dns = src.rfind(f"ou=People,o={co_id},dc=ordered,dc={service}", '(ObjectClass=person)') + src_dns = src.rfind(f"ou=People,o={co_id},dc=ordered,dc={escape_dn_chars(service)}", '(ObjectClass=person)') for src_dn, src_entry in src_dns.items(): logging.debug(" - srcdn: {}".format(src_dn)) @@ -73,7 +74,10 @@ def create(src, dst): dst_entries[dst_dn] = src_entry logging.debug(" - Groups") - grp_dns = src.rfind(f"ou=Groups,o={co_id},dc=ordered,dc={service}", '(objectClass=groupOfMembers)') + grp_dns = src.rfind( + f"ou=Groups,o={co_id},dc=ordered,dc={escape_dn_chars(service)}", + '(objectClass=groupOfMembers)' + ) for grp_dn, grp_entry in grp_dns.items(): logging.debug(" - group_dn: {}".format(grp_dn)) @@ -126,21 +130,21 @@ def cleanup(src, dst): logging.debug("service: {}".format(service)) logging.debug(" - People") - dst_dns = dst.rfind(f"ou=People,dc=flat,dc={service}", "(objectClass=person)") + dst_dns = dst.rfind(f"ou=People,dc=flat,dc={escape_dn_chars(service)}", "(objectClass=person)") for dst_dn, dst_entry in dst_dns.items(): #logging.debug(" - dstdn: {}".format(dst_dn)) #logging.debug(" entry: {}".format(dst_entry)) if dst_entry.get('uid', None): src_uid = dst_entry['uid'][0] - src_dns = src.rfind(f"dc=ordered,dc={service}", f"(uid={src_uid})") + src_dns = src.rfind(f"dc=ordered,dc={escape_dn_chars(service)}", f"(uid={src_uid})") if len(src_dns) == 0: logging.debug(" - dstdn: {}".format(dst_dn)) logging.debug(" srcdn not found, deleting {}".format(dst_dn)) dst.delete(dst_dn) logging.debug(" - Groups") - dst_dns = dst.rfind(f"ou=Groups,dc=flat,dc={service}", "(objectClass=groupOfMembers)") + dst_dns = dst.rfind(f"ou=Groups,dc=flat,dc={escape_dn_chars(service)}", "(objectClass=groupOfMembers)") for dst_dn, dst_entry in dst_dns.items(): #logging.debug(" - dstdn: {}".format(dst_dn)) #logging.debug(" entry: {}".format(dst_entry)) @@ -154,7 +158,7 @@ def cleanup(src, dst): # If not, remove this object. logging.debug(f"CHECKING CO : {org}.{co}...") src_dns = src.rfind( - f"dc=ordered,dc={service}", + f"dc=ordered,dc={escape_dn_chars(service)}", f"(&(objectClass=organization)(o={org}.{co}))") if len(src_dns) == 0: logging.debug(" - dstdn: {}".format(dst_dn)) @@ -163,7 +167,7 @@ def cleanup(src, dst): else: # Verify that group still valid group within referenced CO... src_dns = src.rfind( - f"o={org}.{co},dc=ordered,dc={service}", + f"o={org}.{co},dc=ordered,dc={escape_dn_chars(service)}", f"(&(objectClass=groupOfMembers)(cn={src_cn}))") #if len(src_dns): # for src_dn, src_entry in src_dns.items(): diff --git a/plsc_ordered.py b/plsc_ordered.py index 5511fd3..b1471c3 100755 --- a/plsc_ordered.py +++ b/plsc_ordered.py @@ -14,6 +14,7 @@ from sbs import SBS from typing import Tuple, List, Dict, Union, Optional +from util import escape_dn_chars logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) @@ -106,7 +107,7 @@ def create(src, dst): logging.debug("service: {}".format(service)) # check if service exists and create it if necessary - service_dn = f"dc={service},{dst.basedn}" + service_dn = f"dc={escape_dn_chars(service)},{dst.basedn}" admin_dn = 'cn=admin,' + service_dn # find existing services @@ -180,9 +181,9 @@ def create(src, dst): dst.modify(admin_dn, list(current_admin.values())[0], new_admin) # check if dc=ordered subtree exists and create it if necessary - ordered_dns = dst.rfind(f"dc={service}", "(&(objectClass=dcObject)(dc=ordered))") + ordered_dns = dst.rfind(f"dc={escape_dn_chars(service)}", "(&(objectClass=dcObject)(dc=ordered))") if len(ordered_dns) == 0: - ordered_dn = f"dc=ordered,dc={service},{dst.basedn}" + ordered_dn = f"dc=ordered,dc={escape_dn_chars(service)},{dst.basedn}" ordered_entry = {'objectClass': ['dcObject', 'organizationalUnit'], 'dc': ['ordered'], 'ou': ['ordered']} dst.store(ordered_dn, ordered_entry) @@ -202,7 +203,7 @@ def create(src, dst): scope=co['organisation']['short_name']) # Create CO if necessary - co_dn = f"o={co_identifier},dc=ordered,dc={service},{dst.basedn}" + co_dn = f"o={co_identifier},dc=ordered,dc={escape_dn_chars(service)},{dst.basedn}" co_entry = { 'objectClass': ['top', 'organization', 'extensibleObject'], 'o': [co_identifier], @@ -225,7 +226,10 @@ def create(src, dst): scope=co['organisation']['short_name']) co_entry['mail'] = list(set(admin.get('email') for admin in co.get('admins'))) - co_dns = dst.rfind(f"dc=ordered,dc={service}", f"(&(objectClass=organization)(o={co_identifier}))") + co_dns = dst.rfind( + f"dc=ordered,dc={escape_dn_chars(service)}", + f"(&(objectClass=organization)(o={co_identifier}))" + ) if len(co_dns) == 0: dst.add(co_dn, co_entry) for ou in ['Groups', 'People']: @@ -261,8 +265,11 @@ def create(src, dst): logging.debug(" - grp: {}/{}".format(group['id'], grp_urn)) vc[service][co_identifier].setdefault('groups', []).append(grp_name) - grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}" - grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={service}", + grp_dn = f"cn={grp_name},ou=Groups,"\ + f"o={co_identifier},dc=ordered,"\ + f"dc={escape_dn_chars(service)},{dst.basedn}" + + grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={escape_dn_chars(service)}", f"(&(objectClass=groupOfMembers)(cn={grp_name}))") if len(grp_dns) == 1: old_dn, old_entry = list(grp_dns.items())[0] @@ -274,7 +281,8 @@ def create(src, dst): raise Exception(f"Found multiple groups for dn={grp_dn}") #if not gidNumber: - # gidNumber = dst.get_sequence(f"cn=gidNumberSequence,ou=Sequence,dc={service},{dst.basedn}") + # gidNumber = dst.get_sequence(f"cn=gidNumberSequence,"\ + # f"ou=Sequence,dc={escape_dn_chars(service)},{dst.basedn}") # Here's the magic: Build the new group entry grp_entry = { @@ -306,7 +314,9 @@ def create(src, dst): # convert the BS data to an LDAP record dst_rdn, dst_entry = sbs2ldap_record(src_uid, src_user) - dst_dn = f"{dst_rdn},ou=People,o={co_identifier},dc=ordered,dc={service},{dst.basedn}" + dst_dn = f"{dst_rdn},ou=People,"\ + f"o={co_identifier},dc=ordered,"\ + f"dc={escape_dn_chars(service)},{dst.basedn}" # Pivotal #181218689 accepted_aups = src_user.get("accepted_aups", []) @@ -353,10 +363,15 @@ def create(src, dst): #vc[service][co_identifier].setdefault('groups', []).append(grp_urn) vc[service][co_identifier].setdefault('groups', []).append(grp_name) - grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}" - grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={service}", + grp_dn = f"cn={grp_name},ou=Groups,"\ + f"o={co_identifier},dc=ordered,"\ + f"dc={escape_dn_chars(service)},{dst.basedn}" + + grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={escape_dn_chars(service)}", f"(&(objectClass=groupOfMembers)(cn={grp_name}))") + logging.info(grp_dns) + # ipdb.set_trace() if len(grp_dns) == 1: old_dn, old_entry = list(grp_dns.items())[0] @@ -384,7 +399,7 @@ def create(src, dst): # TODO: Why are we always updating? Shouldn't this be conditional on an actual change happening? ldif = dst.store(grp_dn, grp_entry) - logging.debug(" - store: {}".format(ldif)) + logging.info(" - store: {}".format(ldif)) if details['enabled']: logging.debug(" - Group all") @@ -395,7 +410,9 @@ def create(src, dst): logging.debug(" - grp: {}".format(grp_name)) vc[service][co_identifier].setdefault('groups', []).append(grp_name) - grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}" + grp_dn = f"cn={grp_name},ou=Groups,"\ + f"o={co_identifier},dc=ordered,"\ + f"dc={escape_dn_chars(service)},{dst.basedn}" members = [] for src_id, src_detail in users.items(): @@ -408,12 +425,18 @@ def create(src, dst): logging.debug(f"User {dst_rdn} is not participating @ALL group because of expiration !") continue - dst_dn = f"{dst_rdn},ou=People,o={co_identifier},dc=ordered,dc={service},{dst.basedn}" + dst_dn = f"{dst_rdn},ou=People,"\ + f"o={co_identifier},"\ + f"dc=ordered,dc={escape_dn_chars(service)},"\ + f"{dst.basedn}" + members.append(dst_dn) vc[service][co_identifier]['roles'].setdefault(grp_id, []).append(dst_dn) #if not gidNumber: - # gidNumber = dst.get_sequence(f"cn=gidNumberSequence,ou=Sequence,dc={service},{dst.basedn}") + # gidNumber = dst.get_sequence( + # f"cn=gidNumberSequence,ou=Sequence,dc={escape_dn_chars(service)},{dst.basedn}" + # ) # Here's the magic: Build the new group entry grp_entry = { @@ -454,10 +477,10 @@ def cleanup(dst): logging.debug(f"service: {service}") if vc.get(service, None) is None: logging.debug(f"- {service} not found in our services, cleaning up") - dst.rdelete(f"{service_dn}") + dst.rdelete(f"{escape_dn_chars(service_dn)}") continue - organizations = dst.rfind(f"dc=ordered,dc={service}", + organizations = dst.rfind(f"dc=ordered,dc={escape_dn_chars(service)}", '(&(objectClass=organization)(objectClass=extensibleObject))') for o_dn, o_entry in organizations.items(): if o_entry.get('o'): @@ -472,11 +495,14 @@ def cleanup(dst): dst.rdelete(o_dn) continue - logging.debug(" - People") + logging.info(" - People") src_members = vc.get(dc, {}).get(co, {}).get('members', []) - dst_dns = dst.rfind("ou=people,o={},dc=ordered,dc={}".format(co, service), '(objectClass=person)') + dst_dns = dst.rfind( + "ou=people,o={},dc=ordered,dc={}".format(co, escape_dn_chars(service)), + '(objectClass=person)' + ) for dst_dn, dst_entry in dst_dns.items(): - logging.debug(" - dest_dn: {}".format(dst_dn)) + logging.info(" - dest_dn: {}".format(dst_dn)) if dst_entry.get('eduPersonUniqueId', None): dst_uid = dst_entry['eduPersonUniqueId'][0] if dst_uid not in src_members: @@ -487,8 +513,8 @@ def cleanup(dst): if dst_dn not in registered_users: dst.delete(dst_dn) - logging.debug(" - Groups") - dst_dns = dst.rfind("ou=Groups,o={},dc=ordered,dc={}".format(co, service), + logging.info(" - Groups") + dst_dns = dst.rfind("ou=Groups,o={},dc=ordered,dc={}".format(co, escape_dn_chars(service)), '(objectClass=groupOfMembers)') for dst_dn, dst_entry in dst_dns.items(): grp_name = dst_entry['cn'][0] @@ -498,7 +524,7 @@ def cleanup(dst): continue #grp_urn = dst_entry['labeledURI'][0] - logging.debug(" - dest_dn: {}".format(dst_dn)) + logging.info(" - dest_dn: {}".format(dst_dn)) # TODO: rework this to use the short_name uri-like cn attribute instead of the sbs id src_id = dst_entry.get('uniqueIdentifier') if src_id is None: @@ -515,7 +541,7 @@ def cleanup(dst): for dst_member in dst_members: dst_rdn = util.dn2rdns(dst_member)["uid"][0] #dst_rdn = util.dn2rdns(dst_member)['cn'][0] - logging.debug(" - dst_member: {}".format(dst_rdn)) + logging.info(" - dst_member: {}".format(dst_rdn)) if dst_member not in src_members: logging.debug(" dst_member not found, deleting {}".format(dst_rdn)) dst_members.remove(dst_member) diff --git a/tests/test_all.py b/tests/test_all.py index 7ec50d4..9610717 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,6 +1,7 @@ import logging from tests.base_test import BaseTest +from util import escape_dn_chars logger = logging.getLogger(__name__) @@ -165,7 +166,9 @@ def group_name_flat(g): if s['ldap_enabled']: check_ldap( - f"o={org_sname}.{c['short_name']},dc=ordered,dc={s['entity_id']},{self.dst_conf['basedn']}", + f"o={org_sname}.{c['short_name']}," + f"dc=ordered,dc={escape_dn_chars(s['entity_id'])}," + f"{self.dst_conf['basedn']}", detail['collaboration_memberships'], detail['groups'], group_name_ordered, @@ -176,31 +179,35 @@ def group_name_flat(g): ) check_ldap( - f"dc=flat,dc={s['entity_id']},{self.dst_conf['basedn']}", + f"dc=flat,dc={escape_dn_chars(s['entity_id'])}," + f"{self.dst_conf['basedn']}", detail['collaboration_memberships'], detail['groups'], group_name_flat ) - elif object_count(f"dc={s['entity_id']},{self.dst_conf['basedn']}") > 0: + elif object_count(f"dc={escape_dn_chars(s['entity_id'])},{self.dst_conf['basedn']}") > 0: # in case the service 'exists' in LDAP but is not enabled, make sure # people and group are 'empty' check_ldap( - f"o={org_sname}.{c['short_name']},dc=ordered,dc={s['entity_id']},{self.dst_conf['basedn']}", + f"o={org_sname}.{c['short_name']}," + f"dc=ordered,dc={escape_dn_chars(s['entity_id'])}," + f"{self.dst_conf['basedn']}", [], [], group_name_ordered ) check_ldap( - f"dc=flat,dc={s['entity_id']},{self.dst_conf['basedn']}", + f"dc=flat,dc={escape_dn_chars(s['entity_id'])}," + f"{self.dst_conf['basedn']}", [], [], group_name_flat ) - if object_count(f"dc={s['entity_id']},{self.dst_conf['basedn']}") > 0: + if object_count(f"dc={escape_dn_chars(s['entity_id'])},{self.dst_conf['basedn']}") > 0: logger.info(f"*** Checking Admin account: {s['entity_id']}") self.assertTrue('ldap_password' in s) - admin_object = check_object(f"cn=admin,dc={s['entity_id']}," + admin_object = check_object(f"cn=admin,dc={escape_dn_chars(s['entity_id'])}," f"{self.dst_conf['basedn']}", expected_count=1) ldap_password = s['ldap_password'] if ldap_password: diff --git a/util.py b/util.py index 1724e62..5b44615 100644 --- a/util.py +++ b/util.py @@ -1,6 +1,10 @@ import json import ldap +import logging + +logger = logging.getLogger() + def make_secret(password): import passlib.hash @@ -8,6 +12,26 @@ def make_secret(password): return '{SSHA}' + crypted.decode('ascii') +def escape_dn_chars(s): + """ + Escape dn characters to prevent injection according to RFC 4514. + Refer: https://ldapwiki.com/wiki/Wiki.jsp?page=DN%20Escape%20Values + """ + + s = s.replace('\\', r'\5C') + s = s.replace(r',', r'\2C') + s = s.replace(r'#', r'\23') + s = s.replace(r'+', r'\2B') + s = s.replace(r'<', r'\3C') + s = s.replace(r'>', r'\3E') + s = s.replace(r';', r'\3B') + s = s.replace(r'"', r'\22') + s = s.replace(r'=', r'\3D') + s = s.replace('\x00', r'\00') + + return s + + def dn2rdns(dn): rdns = {} r = ldap.dn.str2dn(dn)