Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry Kodden committed Nov 9, 2023
1 parent b3ad04b commit 8209281
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 43 deletions.
28 changes: 16 additions & 12 deletions plsc_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import util

from sldap import SLdap
from util import escape_dn_chars

# import ipdb
# ipdb.set_trace()
Expand Down Expand Up @@ -39,17 +40,17 @@ def create(src, dst):
for co_id in cos:
logging.debug(f"- co: {co_id}")

src_dn = src.rfind(f"o={co_id},dc=ordered,dc={service}", '(ObjectClass=organization)')
src_co = src_dn.get(f"o={co_id},dc=ordered,dc={service},{src.basedn}", {})
src_dn = src.rfind(f"o={co_id},dc=ordered,dc={escape_dn_chars(service)}", '(ObjectClass=organization)')
src_co = src_dn.get(f"o={co_id},dc=ordered,dc={escape_dn_chars(service)},{src.basedn}", {})
src_mail = src_co.get('mail', [])
logging.debug(f"src_mail: {src_mail}")

co_dn = f"dc=flat,dc={service},{dst.basedn}"
co_dn = f"dc=flat,dc={escape_dn_chars(service)},{dst.basedn}"

# Create flat dn if it doesn't exist
flat_dns = dst.rfind(f"dc={service}", "(&(objectClass=dcObject)(dc=flat))")
flat_dns = dst.rfind(f"dc={escape_dn_chars(service)}", "(&(objectClass=dcObject)(dc=flat))")
if len(flat_dns) == 0:
flat_dn = f"dc=flat,dc={service},{dst.basedn}"
flat_dn = f"dc=flat,dc={escape_dn_chars(service)},{dst.basedn}"
flat_entry = {'objectClass': ['dcObject', 'organizationalUnit'], 'dc': ['flat'], 'ou': ['flat']}
dst.add(flat_dn, flat_entry)
for ou in ['Groups', 'People']:
Expand All @@ -58,7 +59,7 @@ def create(src, dst):
dst.add(ou_dn, ou_entry)

logging.debug(" - People")
src_dns = src.rfind(f"ou=People,o={co_id},dc=ordered,dc={service}", '(ObjectClass=person)')
src_dns = src.rfind(f"ou=People,o={co_id},dc=ordered,dc={escape_dn_chars(service)}", '(ObjectClass=person)')

for src_dn, src_entry in src_dns.items():
logging.debug(" - srcdn: {}".format(src_dn))
Expand All @@ -73,7 +74,10 @@ def create(src, dst):
dst_entries[dst_dn] = src_entry

logging.debug(" - Groups")
grp_dns = src.rfind(f"ou=Groups,o={co_id},dc=ordered,dc={service}", '(objectClass=groupOfMembers)')
grp_dns = src.rfind(
f"ou=Groups,o={co_id},dc=ordered,dc={escape_dn_chars(service)}",
'(objectClass=groupOfMembers)'
)

for grp_dn, grp_entry in grp_dns.items():
logging.debug(" - group_dn: {}".format(grp_dn))
Expand Down Expand Up @@ -126,21 +130,21 @@ def cleanup(src, dst):
logging.debug("service: {}".format(service))

logging.debug(" - People")
dst_dns = dst.rfind(f"ou=People,dc=flat,dc={service}", "(objectClass=person)")
dst_dns = dst.rfind(f"ou=People,dc=flat,dc={escape_dn_chars(service)}", "(objectClass=person)")
for dst_dn, dst_entry in dst_dns.items():
#logging.debug(" - dstdn: {}".format(dst_dn))
#logging.debug(" entry: {}".format(dst_entry))

if dst_entry.get('uid', None):
src_uid = dst_entry['uid'][0]
src_dns = src.rfind(f"dc=ordered,dc={service}", f"(uid={src_uid})")
src_dns = src.rfind(f"dc=ordered,dc={escape_dn_chars(service)}", f"(uid={src_uid})")
if len(src_dns) == 0:
logging.debug(" - dstdn: {}".format(dst_dn))
logging.debug(" srcdn not found, deleting {}".format(dst_dn))
dst.delete(dst_dn)

logging.debug(" - Groups")
dst_dns = dst.rfind(f"ou=Groups,dc=flat,dc={service}", "(objectClass=groupOfMembers)")
dst_dns = dst.rfind(f"ou=Groups,dc=flat,dc={escape_dn_chars(service)}", "(objectClass=groupOfMembers)")
for dst_dn, dst_entry in dst_dns.items():
#logging.debug(" - dstdn: {}".format(dst_dn))
#logging.debug(" entry: {}".format(dst_entry))
Expand All @@ -154,7 +158,7 @@ def cleanup(src, dst):
# If not, remove this object.
logging.debug(f"CHECKING CO : {org}.{co}...")
src_dns = src.rfind(
f"dc=ordered,dc={service}",
f"dc=ordered,dc={escape_dn_chars(service)}",
f"(&(objectClass=organization)(o={org}.{co}))")
if len(src_dns) == 0:
logging.debug(" - dstdn: {}".format(dst_dn))
Expand All @@ -163,7 +167,7 @@ def cleanup(src, dst):
else:
# Verify that group still valid group within referenced CO...
src_dns = src.rfind(
f"o={org}.{co},dc=ordered,dc={service}",
f"o={org}.{co},dc=ordered,dc={escape_dn_chars(service)}",
f"(&(objectClass=groupOfMembers)(cn={src_cn}))")
#if len(src_dns):
# for src_dn, src_entry in src_dns.items():
Expand Down
74 changes: 50 additions & 24 deletions plsc_ordered.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sbs import SBS

from typing import Tuple, List, Dict, Union, Optional
from util import escape_dn_chars

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

Expand Down Expand Up @@ -106,7 +107,7 @@ def create(src, dst):
logging.debug("service: {}".format(service))

# check if service exists and create it if necessary
service_dn = f"dc={service},{dst.basedn}"
service_dn = f"dc={escape_dn_chars(service)},{dst.basedn}"
admin_dn = 'cn=admin,' + service_dn

# find existing services
Expand Down Expand Up @@ -180,9 +181,9 @@ def create(src, dst):
dst.modify(admin_dn, list(current_admin.values())[0], new_admin)

# check if dc=ordered subtree exists and create it if necessary
ordered_dns = dst.rfind(f"dc={service}", "(&(objectClass=dcObject)(dc=ordered))")
ordered_dns = dst.rfind(f"dc={escape_dn_chars(service)}", "(&(objectClass=dcObject)(dc=ordered))")
if len(ordered_dns) == 0:
ordered_dn = f"dc=ordered,dc={service},{dst.basedn}"
ordered_dn = f"dc=ordered,dc={escape_dn_chars(service)},{dst.basedn}"
ordered_entry = {'objectClass': ['dcObject', 'organizationalUnit'], 'dc': ['ordered'], 'ou': ['ordered']}
dst.store(ordered_dn, ordered_entry)

Expand All @@ -202,7 +203,7 @@ def create(src, dst):
scope=co['organisation']['short_name'])

# Create CO if necessary
co_dn = f"o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
co_dn = f"o={co_identifier},dc=ordered,dc={escape_dn_chars(service)},{dst.basedn}"
co_entry = {
'objectClass': ['top', 'organization', 'extensibleObject'],
'o': [co_identifier],
Expand All @@ -225,7 +226,10 @@ def create(src, dst):
scope=co['organisation']['short_name'])
co_entry['mail'] = list(set(admin.get('email') for admin in co.get('admins')))

co_dns = dst.rfind(f"dc=ordered,dc={service}", f"(&(objectClass=organization)(o={co_identifier}))")
co_dns = dst.rfind(
f"dc=ordered,dc={escape_dn_chars(service)}",
f"(&(objectClass=organization)(o={co_identifier}))"
)
if len(co_dns) == 0:
dst.add(co_dn, co_entry)
for ou in ['Groups', 'People']:
Expand Down Expand Up @@ -261,8 +265,11 @@ def create(src, dst):
logging.debug(" - grp: {}/{}".format(group['id'], grp_urn))
vc[service][co_identifier].setdefault('groups', []).append(grp_name)

grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={service}",
grp_dn = f"cn={grp_name},ou=Groups,"\
f"o={co_identifier},dc=ordered,"\
f"dc={escape_dn_chars(service)},{dst.basedn}"

grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={escape_dn_chars(service)}",
f"(&(objectClass=groupOfMembers)(cn={grp_name}))")
if len(grp_dns) == 1:
old_dn, old_entry = list(grp_dns.items())[0]
Expand All @@ -274,7 +281,8 @@ def create(src, dst):
raise Exception(f"Found multiple groups for dn={grp_dn}")

#if not gidNumber:
# gidNumber = dst.get_sequence(f"cn=gidNumberSequence,ou=Sequence,dc={service},{dst.basedn}")
# gidNumber = dst.get_sequence(f"cn=gidNumberSequence,"\
# f"ou=Sequence,dc={escape_dn_chars(service)},{dst.basedn}")

# Here's the magic: Build the new group entry
grp_entry = {
Expand Down Expand Up @@ -306,7 +314,9 @@ def create(src, dst):

# convert the BS data to an LDAP record
dst_rdn, dst_entry = sbs2ldap_record(src_uid, src_user)
dst_dn = f"{dst_rdn},ou=People,o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
dst_dn = f"{dst_rdn},ou=People,"\
f"o={co_identifier},dc=ordered,"\
f"dc={escape_dn_chars(service)},{dst.basedn}"

# Pivotal #181218689
accepted_aups = src_user.get("accepted_aups", [])
Expand Down Expand Up @@ -353,10 +363,15 @@ def create(src, dst):
#vc[service][co_identifier].setdefault('groups', []).append(grp_urn)
vc[service][co_identifier].setdefault('groups', []).append(grp_name)

grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={service}",
grp_dn = f"cn={grp_name},ou=Groups,"\
f"o={co_identifier},dc=ordered,"\
f"dc={escape_dn_chars(service)},{dst.basedn}"

grp_dns = dst.rfind(f"ou=Groups,o={co_identifier},dc=ordered,dc={escape_dn_chars(service)}",
f"(&(objectClass=groupOfMembers)(cn={grp_name}))")

logging.info(grp_dns)

# ipdb.set_trace()
if len(grp_dns) == 1:
old_dn, old_entry = list(grp_dns.items())[0]
Expand Down Expand Up @@ -384,7 +399,7 @@ def create(src, dst):

# TODO: Why are we always updating? Shouldn't this be conditional on an actual change happening?
ldif = dst.store(grp_dn, grp_entry)
logging.debug(" - store: {}".format(ldif))
logging.info(" - store: {}".format(ldif))

if details['enabled']:
logging.debug(" - Group all")
Expand All @@ -395,7 +410,9 @@ def create(src, dst):
logging.debug(" - grp: {}".format(grp_name))
vc[service][co_identifier].setdefault('groups', []).append(grp_name)

grp_dn = f"cn={grp_name},ou=Groups,o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
grp_dn = f"cn={grp_name},ou=Groups,"\
f"o={co_identifier},dc=ordered,"\
f"dc={escape_dn_chars(service)},{dst.basedn}"

members = []
for src_id, src_detail in users.items():
Expand All @@ -408,12 +425,18 @@ def create(src, dst):
logging.debug(f"User {dst_rdn} is not participating @ALL group because of expiration !")
continue

dst_dn = f"{dst_rdn},ou=People,o={co_identifier},dc=ordered,dc={service},{dst.basedn}"
dst_dn = f"{dst_rdn},ou=People,"\
f"o={co_identifier},"\
f"dc=ordered,dc={escape_dn_chars(service)},"\
f"{dst.basedn}"

members.append(dst_dn)
vc[service][co_identifier]['roles'].setdefault(grp_id, []).append(dst_dn)

#if not gidNumber:
# gidNumber = dst.get_sequence(f"cn=gidNumberSequence,ou=Sequence,dc={service},{dst.basedn}")
# gidNumber = dst.get_sequence(
# f"cn=gidNumberSequence,ou=Sequence,dc={escape_dn_chars(service)},{dst.basedn}"
# )

# Here's the magic: Build the new group entry
grp_entry = {
Expand Down Expand Up @@ -454,10 +477,10 @@ def cleanup(dst):
logging.debug(f"service: {service}")
if vc.get(service, None) is None:
logging.debug(f"- {service} not found in our services, cleaning up")
dst.rdelete(f"{service_dn}")
dst.rdelete(f"{escape_dn_chars(service_dn)}")
continue

organizations = dst.rfind(f"dc=ordered,dc={service}",
organizations = dst.rfind(f"dc=ordered,dc={escape_dn_chars(service)}",
'(&(objectClass=organization)(objectClass=extensibleObject))')
for o_dn, o_entry in organizations.items():
if o_entry.get('o'):
Expand All @@ -472,11 +495,14 @@ def cleanup(dst):
dst.rdelete(o_dn)
continue

logging.debug(" - People")
logging.info(" - People")
src_members = vc.get(dc, {}).get(co, {}).get('members', [])
dst_dns = dst.rfind("ou=people,o={},dc=ordered,dc={}".format(co, service), '(objectClass=person)')
dst_dns = dst.rfind(
"ou=people,o={},dc=ordered,dc={}".format(co, escape_dn_chars(service)),
'(objectClass=person)'
)
for dst_dn, dst_entry in dst_dns.items():
logging.debug(" - dest_dn: {}".format(dst_dn))
logging.info(" - dest_dn: {}".format(dst_dn))
if dst_entry.get('eduPersonUniqueId', None):
dst_uid = dst_entry['eduPersonUniqueId'][0]
if dst_uid not in src_members:
Expand All @@ -487,8 +513,8 @@ def cleanup(dst):
if dst_dn not in registered_users:
dst.delete(dst_dn)

logging.debug(" - Groups")
dst_dns = dst.rfind("ou=Groups,o={},dc=ordered,dc={}".format(co, service),
logging.info(" - Groups")
dst_dns = dst.rfind("ou=Groups,o={},dc=ordered,dc={}".format(co, escape_dn_chars(service)),
'(objectClass=groupOfMembers)')
for dst_dn, dst_entry in dst_dns.items():
grp_name = dst_entry['cn'][0]
Expand All @@ -498,7 +524,7 @@ def cleanup(dst):
continue

#grp_urn = dst_entry['labeledURI'][0]
logging.debug(" - dest_dn: {}".format(dst_dn))
logging.info(" - dest_dn: {}".format(dst_dn))
# TODO: rework this to use the short_name uri-like cn attribute instead of the sbs id
src_id = dst_entry.get('uniqueIdentifier')
if src_id is None:
Expand All @@ -515,7 +541,7 @@ def cleanup(dst):
for dst_member in dst_members:
dst_rdn = util.dn2rdns(dst_member)["uid"][0]
#dst_rdn = util.dn2rdns(dst_member)['cn'][0]
logging.debug(" - dst_member: {}".format(dst_rdn))
logging.info(" - dst_member: {}".format(dst_rdn))
if dst_member not in src_members:
logging.debug(" dst_member not found, deleting {}".format(dst_rdn))
dst_members.remove(dst_member)
Expand Down
21 changes: 14 additions & 7 deletions tests/test_all.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from tests.base_test import BaseTest
from util import escape_dn_chars

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,7 +166,9 @@ def group_name_flat(g):

if s['ldap_enabled']:
check_ldap(
f"o={org_sname}.{c['short_name']},dc=ordered,dc={s['entity_id']},{self.dst_conf['basedn']}",
f"o={org_sname}.{c['short_name']},"
f"dc=ordered,dc={escape_dn_chars(s['entity_id'])},"
f"{self.dst_conf['basedn']}",
detail['collaboration_memberships'],
detail['groups'],
group_name_ordered,
Expand All @@ -176,31 +179,35 @@ def group_name_flat(g):
)

check_ldap(
f"dc=flat,dc={s['entity_id']},{self.dst_conf['basedn']}",
f"dc=flat,dc={escape_dn_chars(s['entity_id'])},"
f"{self.dst_conf['basedn']}",
detail['collaboration_memberships'],
detail['groups'],
group_name_flat
)
elif object_count(f"dc={s['entity_id']},{self.dst_conf['basedn']}") > 0:
elif object_count(f"dc={escape_dn_chars(s['entity_id'])},{self.dst_conf['basedn']}") > 0:
# in case the service 'exists' in LDAP but is not enabled, make sure
# people and group are 'empty'
check_ldap(
f"o={org_sname}.{c['short_name']},dc=ordered,dc={s['entity_id']},{self.dst_conf['basedn']}",
f"o={org_sname}.{c['short_name']},"
f"dc=ordered,dc={escape_dn_chars(s['entity_id'])},"
f"{self.dst_conf['basedn']}",
[],
[],
group_name_ordered
)
check_ldap(
f"dc=flat,dc={s['entity_id']},{self.dst_conf['basedn']}",
f"dc=flat,dc={escape_dn_chars(s['entity_id'])},"
f"{self.dst_conf['basedn']}",
[],
[],
group_name_flat
)

if object_count(f"dc={s['entity_id']},{self.dst_conf['basedn']}") > 0:
if object_count(f"dc={escape_dn_chars(s['entity_id'])},{self.dst_conf['basedn']}") > 0:
logger.info(f"*** Checking Admin account: {s['entity_id']}")
self.assertTrue('ldap_password' in s)
admin_object = check_object(f"cn=admin,dc={s['entity_id']},"
admin_object = check_object(f"cn=admin,dc={escape_dn_chars(s['entity_id'])},"
f"{self.dst_conf['basedn']}", expected_count=1)
ldap_password = s['ldap_password']
if ldap_password:
Expand Down
24 changes: 24 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
import json
import ldap

import logging

logger = logging.getLogger()


def make_secret(password):
import passlib.hash
crypted = passlib.hash.sha512_crypt.hash(password)
return '{SSHA}' + crypted.decode('ascii')


def escape_dn_chars(s):
"""
Escape dn characters to prevent injection according to RFC 4514.
Refer: https://ldapwiki.com/wiki/Wiki.jsp?page=DN%20Escape%20Values
"""

s = s.replace('\\', r'\5C')
s = s.replace(r',', r'\2C')
s = s.replace(r'#', r'\23')
s = s.replace(r'+', r'\2B')
s = s.replace(r'<', r'\3C')
s = s.replace(r'>', r'\3E')
s = s.replace(r';', r'\3B')
s = s.replace(r'"', r'\22')
s = s.replace(r'=', r'\3D')
s = s.replace('\x00', r'\00')

return s


def dn2rdns(dn):
rdns = {}
r = ldap.dn.str2dn(dn)
Expand Down

0 comments on commit 8209281

Please sign in to comment.