Skip to content

Commit

Permalink
add return_string to get_attributes_for_oid for overload
Browse files Browse the repository at this point in the history
  • Loading branch information
prauscher committed Dec 5, 2024
1 parent c87c489 commit c4a2a5f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
16 changes: 14 additions & 2 deletions docs/x509/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1574,10 +1574,16 @@ X.509 CSR (Certificate Signing Request) Builder Object
>>> x509.Name.from_rfc4514_string("[email protected]", {"E": NameOID.EMAIL_ADDRESS})
<Name([email protected])>

.. method:: get_attributes_for_oid(oid)
.. method:: get_attributes_for_oid(oid, *, return_string=False)

:param oid: An :class:`ObjectIdentifier` instance.

:param return_string: Set to True to point static type checkers that
resulting :class:`NameAttributes` will only contain strings.

:raise TypeError: If `return_string` is set to True but OID is
``NameOID.X500_UNIQUE_IDENTIFIER``.

:returns: A list of :class:`NameAttribute` instances that match the
OID provided. If nothing matches an empty list will be returned.

Expand Down Expand Up @@ -1702,10 +1708,16 @@ X.509 CSR (Certificate Signing Request) Builder Object
object is iterable to get every attribute, preserving the original order.
Passing duplicate attributes to the constructor raises ``ValueError``.

.. method:: get_attributes_for_oid(oid)
.. method:: get_attributes_for_oid(oid, *, return_string=False)

:param oid: An :class:`ObjectIdentifier` instance.

:param return_string: Set to True to point static type checkers that
resulting :class:`NameAttributes` will only contain strings.

:raise TypeError: If `return_string` is set to True but OID is
``NameOID.X500_UNIQUE_IDENTIFIER``.

:returns: A list of :class:`NameAttribute` instances that match the OID
provided. The list should contain zero or one values.

Expand Down
42 changes: 40 additions & 2 deletions src/cryptography/x509/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,29 @@ def __init__(self, attributes: typing.Iterable[NameAttribute]):
if len(self._attribute_set) != len(attributes):
raise ValueError("duplicate attributes are not allowed")

@typing.overload
def get_attributes_for_oid(
self,
oid: ObjectIdentifier,
*,
return_string: typing.Literal[False] = False,
) -> list[NameAttribute[str | bytes]]: ...

@typing.overload
def get_attributes_for_oid(
self, oid: ObjectIdentifier, *, return_string: typing.Literal[True]
) -> list[NameAttribute[str]]: ...

def get_attributes_for_oid(
self,
oid: ObjectIdentifier,
) -> list[NameAttribute[str | bytes]]:
*,
return_string: bool = False,
) -> list[NameAttribute[str | bytes]] | list[NameAttribute[str]]:
if return_string is True and oid == NameOID.X500_UNIQUE_IDENTIFIER:
raise TypeError(
"oid must not be X500_UNIQUE_IDENTIFIER with return_string=True."
)
return [i for i in self if i.oid == oid]

def rfc4514_string(
Expand Down Expand Up @@ -332,10 +351,29 @@ def rfc4514_string(
for attr in reversed(self._attributes)
)

@typing.overload
def get_attributes_for_oid(
self,
oid: ObjectIdentifier,
*,
return_string: typing.Literal[False] = False,
) -> list[NameAttribute[str | bytes]]: ...

@typing.overload
def get_attributes_for_oid(
self, oid: ObjectIdentifier, *, return_string: typing.Literal[True]
) -> list[NameAttribute[str]]: ...

def get_attributes_for_oid(
self,
oid: ObjectIdentifier,
) -> list[NameAttribute[str | bytes]]:
*,
return_string: bool = False,
) -> list[NameAttribute[str | bytes]] | list[NameAttribute[str]]:
if return_string is True and oid == NameOID.X500_UNIQUE_IDENTIFIER:
raise TypeError(
"oid must not be X500_UNIQUE_IDENTIFIER with return_string=True."
)
return [i for i in self if i.oid == oid]

@property
Expand Down
8 changes: 8 additions & 0 deletions tests/x509/test_x509.py
Original file line number Diff line number Diff line change
Expand Up @@ -6135,8 +6135,16 @@ def test_get_attributes_for_oid(self):
attr = x509.NameAttribute(oid, "value1")
rdn = x509.RelativeDistinguishedName([attr])
assert rdn.get_attributes_for_oid(oid) == [attr]
assert rdn.get_attributes_for_oid(oid, return_string=True) == [attr]
assert rdn.get_attributes_for_oid(x509.ObjectIdentifier("1.2.3")) == []

def test_get_attributes_for_oid_string_x500_unique_identifier(self):
oid = NameOID.X500_UNIQUE_IDENTIFIER
attr = x509.NameAttribute(oid, b"value1", _ASN1Type.BitString)
rdn = x509.RelativeDistinguishedName([attr])
with pytest.raises(TypeError):
rdn.get_attributes_for_oid(oid, return_string=True)


class TestObjectIdentifier:
def test_eq(self):
Expand Down

0 comments on commit c4a2a5f

Please sign in to comment.