Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stdlib] Add stol function #3178

Open
wants to merge 10 commits into
base: nightly
Choose a base branch
from
181 changes: 181 additions & 0 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,187 @@ fn ascii(value: String) -> String:
# ===----------------------------------------------------------------------=== #


@always_inline
fn _stol(str_slice: StringSlice, base: Int = 10) raises -> (Int, String):
"""Implementation of `stol` for StringSlice inputs.

Please see its docstring for details.
"""
if (base != 0) and (base < 2 or base > 36):
raise Error("Base must be >= 2 and <= 36, or 0.")

if not str_slice:
raise Error("Cannot convert empty string to integer.")

var result: Int = 0
var real_base: Int
var start: Int = 0
var is_negative: Bool = False
var has_prefix: Bool = False
var str_len = str_slice.byte_length()
var buff = str_slice.unsafe_ptr()

start, is_negative = _trim_and_handle_sign(str_slice, str_len)

if start == str_len or not _is_valid_digit(int(buff[start]), base):
return 0, String(str_slice)

if base == 0:
var real_base_new_start = _identify_base(str_slice, start)
real_base = real_base_new_start[0]

# If identify_base returns error but starts with 0, treat as base 10
if real_base == -1 and buff[start] == ord("0"):
real_base = 10
# Keep original start position for base 10
else:
# For valid prefixes, use the new start position
start = real_base_new_start[1]

has_prefix = real_base != 10
else:
start, has_prefix = _handle_base_prefix(start, str_slice, str_len, base)
real_base = base

var ord_num_max: Int
var ord_letter_max = (-1, -1)
alias ord_0 = ord("0")
var ord_letter_min = (ord("a"), ord("A"))
alias ord_underscore = ord("_")

if real_base <= 10:
ord_num_max = ord(str(real_base - 1))
else:
ord_num_max = ord("9")
ord_letter_max = (
ord("a") + (real_base - 11),
ord("A") + (real_base - 11),
)

var was_last_digit_underscore = not (real_base in (2, 8, 16) and has_prefix)
for pos in range(start, str_len):
var ord_current = int(buff[pos])
if ord_current == ord_underscore and was_last_digit_underscore:
break # Break out as apposed to raising exception
if ord_current == ord_underscore:
was_last_digit_underscore = True
continue

was_last_digit_underscore = False

var digit_value: Int
if ord_0 <= ord_current <= ord_num_max:
digit_value = ord_current - ord_0
elif ord_letter_min[0] <= ord_current <= ord_letter_max[0]:
digit_value = ord_current - ord_letter_min[0] + 10
elif ord_letter_min[1] <= ord_current <= ord_letter_max[1]:
digit_value = ord_current - ord_letter_min[1] + 10
else:
break

if digit_value >= real_base:
break

var new_result = result * real_base + digit_value
if new_result <= result and result > 0:
raise Error(
_str_to_base_error(real_base, str_slice)
+ " String expresses an integer too large to store in Int."
)
result = new_result
start = pos + 1

if is_negative:
result = -result

return result, String(
StringSlice(unsafe_from_utf8=str_slice.as_bytes()[start:])
)


fn stol(str: String, base: Int = 10) raises -> (Int, String):
"""Convert a string to a integer and return the remaining unparsed string.

Similar to `atol`, but `stol` parses only a portion of the string and returns
both the parsed integer and the remaining unparsed part. For example, `stol("32abc")` returns `(32, "abc")`.

If base is 0, the string is parsed as an [Integer literal][1], with the following considerations:
- '0b' or '0B' prefix indicates binary (base 2)
- '0o' or '0O' prefix indicates octal (base 8)
- '0x' or '0X' prefix indicates hexadecimal (base 16)
- Without a prefix, it's treated as decimal (base 10)
Notes:
This follows [Python's integer literals](\
https://docs.python.org/3/reference/lexical_analysis.html#integers)

Raises:
If the base is invalid or if the string is empty.

Args:
str: A string to be parsed as an integer in the given base.
base: Base used for conversion, value must be between 2 and 36, or 0.

Returns:
A tuple containing:
- An integer value representing the parsed part of the string.
- The remaining unparsed part of the string.

Examples:
>>> stol("19abc")
(19, "abc")
>>> stol("0xFF hello", 16)
(255, " hello")
>>> stol("0x123ghi", 0)
(291, "ghi")
>>> stol("0b1010 binary", 0)
(10, " binary")
>>> stol("0o123 octal", 0)
(83, " octal")

See Also:
`atol`: A similar function that parses the entire string and returns an integer.

[1]: https://docs.python.org/3/reference/lexical_analysis.html#integers.

"""
var result: Int
var remaining: String
result, remaining = _stol(str.as_string_slice(), base)

return result, remaining


@always_inline
fn _is_valid_digit(char: UInt8, base: Int) -> Bool:
"""Checks if a character is a valid digit for the given base.

Args:
char: The character to check, as a UInt8.
base: The numeric base (0-36, where 0 is special case).

Returns:
True if the character is a valid digit for the given base, False otherwise.
"""
if base == 0:
# For base 0, we need to allow 0-9 and a-f/A-F for potential hex numbers
if char >= ord("0") and char <= ord("9"):
return True
var upper_char = char & ~32 # Convert to uppercase
return upper_char >= ord("A") and upper_char <= ord("F")

if char == ord("_"):
return True

if char >= ord("0") and char <= ord("9"):
return (char - ord("0")) < base
if base <= 10:
return False
var upper_char = char & ~32 # Convert to uppercase
if upper_char >= ord("A") and upper_char <= ord("Z"):
return (upper_char - ord("A") + 10) < base
return False


fn _atol(str_slice: StringSlice, base: Int = 10) raises -> Int:
"""Implementation of `atol` for StringSlice inputs.

Expand Down
1 change: 1 addition & 0 deletions stdlib/src/prelude/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from collections.string import (
ascii,
atof,
atol,
stol,
chr,
isdigit,
islower,
Expand Down
162 changes: 162 additions & 0 deletions stdlib/test/collections/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,166 @@ def test_string_indexing():
assert_equal("H", str[-50::50])


def test_stol():
var result: Int
var remaining: String

# base 10
result, remaining = stol(String("375 ABC"))
assert_equal(375, result)
assert_equal(" ABC", remaining)
result, remaining = stol(String(" 005"))
assert_equal(5, result)
assert_equal("", remaining)
result, remaining = stol(String(" 013 "))
assert_equal(13, result)
assert_equal(" ", remaining)
result, remaining = stol(String("-89"))
assert_equal(-89, result)
assert_equal("", remaining)
result, remaining = stol(String(" -52"))
assert_equal(-52, result)
assert_equal("", remaining)

# other bases
result, remaining = stol(" FF", 16)
assert_equal(255, result)
assert_equal("", remaining)
result, remaining = stol(" 0xff ", 16)
assert_equal(255, result)
assert_equal(" ", remaining)
result, remaining = stol("10010eighteen18", 2)
assert_equal(18, result)
assert_equal("eighteen18", remaining)
result, remaining = stol("0b10010", 2)
assert_equal(18, result)
result, remaining = stol("0b_10010", 2)
assert_equal(18, result)
result, remaining = stol("0b_0010010", 2)
assert_equal(18, result)
result, remaining = stol("0b0000_0_010010", 2)
assert_equal(18, result)
assert_equal("", remaining)
result, remaining = stol("0o12", 8)
assert_equal(10, result)
result, remaining = stol("0o_12", 8)
assert_equal(10, result)
result, remaining = stol("0o_012", 8)
assert_equal(10, result)
result, remaining = stol("0o0000_0_0012", 8)
assert_equal(10, result)
assert_equal("", remaining)
result, remaining = stol("Z", 36)
assert_equal(35, result)
assert_equal("", remaining)

# test with trailing characters
result, remaining = stol("123abc")
assert_equal(123, result)
assert_equal("abc", remaining)
result, remaining = stol("-45def")
assert_equal(-45, result)
assert_equal("def", remaining)
result, remaining = stol("0xffghi", 0)
assert_equal(255, result)
result, remaining = stol("0x_ffghi", 0)
assert_equal(255, result)
result, remaining = stol("0x_0ffghi", 0)
assert_equal(255, result)
result, remaining = stol("0x0000_0_00ffghi", 0)
assert_equal(255, result)
assert_equal("ghi", remaining)

result, remaining = stol(" ")
assert_equal(0, result)
assert_equal(" ", remaining)

result, remaining = stol("123.456", 10)
assert_equal(123, result)
assert_equal(".456", remaining)
result, remaining = stol("--123", 10)
assert_equal(0, result)
assert_equal("--123", remaining)

result, remaining = stol("12a34", 10)
assert_equal(12, result)
assert_equal("a34", remaining)
result, remaining = stol("1G5", 16)
assert_equal(1, result)
assert_equal("G5", remaining)

result, remaining = stol("-1A", 16)
assert_equal(-26, result)
assert_equal("", remaining)
result, remaining = stol("-110", 2)
assert_equal(-6, result)
assert_equal("", remaining)

result, remaining = stol("Mojo!")
assert_equal(0, result)
assert_equal("Mojo!", remaining)

# Negative Cases
with assert_raises(contains="Cannot convert empty string to integer."):
_ = stol("")

with assert_raises(contains="Base must be >= 2 and <= 36, or 0."):
_ = stol("Bad Base", 42)

with assert_raises(
contains="String expresses an integer too large to store in Int."
):
_ = stol(String("9223372036854775832"), 10)


def test_stol_base_0():
var result: Int
var remaining: String

result, remaining = stol("155_155", 0)
assert_equal(155155, result)
assert_equal("", remaining)
result, remaining = stol("1_2_3_4_5", 0)
assert_equal(12345, result)
assert_equal("", remaining)
result, remaining = stol("1_2_3_4_5_", 0)
assert_equal(12345, result)
assert_equal("_", remaining)
result, remaining = stol("0b1_0_1_0", 0)
assert_equal(10, result)
assert_equal("", remaining)
result, remaining = stol("0o1_2_3", 0)
assert_equal(83, result)
assert_equal("", remaining)
result, remaining = stol("0x1_A_B", 0)
assert_equal(427, result)
assert_equal("", remaining)
result, remaining = stol("123_", 0)
assert_equal(123, result)
assert_equal("_", remaining)
result, remaining = stol("_123", 0)
assert_equal(0, result)
assert_equal("_123", remaining)
result, remaining = stol("123__456", 0)
assert_equal(123, result)
assert_equal("__456", remaining)
result, remaining = stol("0x1_23", 0)
assert_equal(291, result)
assert_equal("", remaining)
result, remaining = stol("0_123", 0)
assert_equal(123, result)
assert_equal("", remaining)
result, remaining = stol("0z123", 0)
assert_equal(0, result)
assert_equal("z123", remaining)
result, remaining = stol("Mojo!", 0)
assert_equal(0, result)
assert_equal("Mojo!", remaining)
result, remaining = stol("0o123 octal", 0)
assert_equal(83, result)
assert_equal(" octal", remaining)


def test_atol():
# base 10
assert_equal(375, atol(String("375")))
Expand Down Expand Up @@ -1606,6 +1766,8 @@ def main():
test_ord()
test_chr()
test_string_indexing()
test_stol()
test_stol_base_0()
test_atol()
test_atol_base_0()
test_atof()
Expand Down
3 changes: 2 additions & 1 deletion stdlib/test/python/my_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, bar):

class AbstractPerson(ABC):
@abstractmethod
def method(self): ...
def method(self):
...


def my_function(name):
Expand Down