diff --git a/stdlib/src/collections/string.mojo b/stdlib/src/collections/string.mojo index 1b38236b10..b896b534e8 100644 --- a/stdlib/src/collections/string.mojo +++ b/stdlib/src/collections/string.mojo @@ -233,6 +233,187 @@ fn ascii(value: String) -> String: # ===----------------------------------------------------------------------=== # +@always_inline +fn _stol(str_slice: StringSlice, base: Int = 10) raises -> (Int, String): + """Implementation of `stol` for StringSlice inputs. + + Please see its docstring for details. + """ + if (base != 0) and (base < 2 or base > 36): + raise Error("Base must be >= 2 and <= 36, or 0.") + + if not str_slice: + raise Error("Cannot convert empty string to integer.") + + var result: Int = 0 + var real_base: Int + var start: Int = 0 + var is_negative: Bool = False + var has_prefix: Bool = False + var str_len = str_slice.byte_length() + var buff = str_slice.unsafe_ptr() + + start, is_negative = _trim_and_handle_sign(str_slice, str_len) + + if start == str_len or not _is_valid_digit(int(buff[start]), base): + return 0, String(str_slice) + + if base == 0: + var real_base_new_start = _identify_base(str_slice, start) + real_base = real_base_new_start[0] + + # If identify_base returns error but starts with 0, treat as base 10 + if real_base == -1 and buff[start] == ord("0"): + real_base = 10 + # Keep original start position for base 10 + else: + # For valid prefixes, use the new start position + start = real_base_new_start[1] + + has_prefix = real_base != 10 + else: + start, has_prefix = _handle_base_prefix(start, str_slice, str_len, base) + real_base = base + + var ord_num_max: Int + var ord_letter_max = (-1, -1) + alias ord_0 = ord("0") + var ord_letter_min = (ord("a"), ord("A")) + alias ord_underscore = ord("_") + + if real_base <= 10: + ord_num_max = ord(str(real_base - 1)) + else: + ord_num_max = ord("9") + ord_letter_max = ( + ord("a") + (real_base - 11), + ord("A") + (real_base - 11), + ) + + var was_last_digit_underscore = not (real_base in (2, 8, 16) and has_prefix) + for pos in range(start, str_len): + var ord_current = int(buff[pos]) + if ord_current == ord_underscore and was_last_digit_underscore: + break # Break out as apposed to raising exception + if ord_current == ord_underscore: + was_last_digit_underscore = True + continue + + was_last_digit_underscore = False + + var digit_value: Int + if ord_0 <= ord_current <= ord_num_max: + digit_value = ord_current - ord_0 + elif ord_letter_min[0] <= ord_current <= ord_letter_max[0]: + digit_value = ord_current - ord_letter_min[0] + 10 + elif ord_letter_min[1] <= ord_current <= ord_letter_max[1]: + digit_value = ord_current - ord_letter_min[1] + 10 + else: + break + + if digit_value >= real_base: + break + + var new_result = result * real_base + digit_value + if new_result <= result and result > 0: + raise Error( + _str_to_base_error(real_base, str_slice) + + " String expresses an integer too large to store in Int." + ) + result = new_result + start = pos + 1 + + if is_negative: + result = -result + + return result, String( + StringSlice(unsafe_from_utf8=str_slice.as_bytes()[start:]) + ) + + +fn stol(str: String, base: Int = 10) raises -> (Int, String): + """Convert a string to a integer and return the remaining unparsed string. + + Similar to `atol`, but `stol` parses only a portion of the string and returns + both the parsed integer and the remaining unparsed part. For example, `stol("32abc")` returns `(32, "abc")`. + + If base is 0, the string is parsed as an [Integer literal][1], with the following considerations: + - '0b' or '0B' prefix indicates binary (base 2) + - '0o' or '0O' prefix indicates octal (base 8) + - '0x' or '0X' prefix indicates hexadecimal (base 16) + - Without a prefix, it's treated as decimal (base 10) + Notes: + This follows [Python's integer literals](\ + https://docs.python.org/3/reference/lexical_analysis.html#integers) + + Raises: + If the base is invalid or if the string is empty. + + Args: + str: A string to be parsed as an integer in the given base. + base: Base used for conversion, value must be between 2 and 36, or 0. + + Returns: + A tuple containing: + - An integer value representing the parsed part of the string. + - The remaining unparsed part of the string. + + Examples: + >>> stol("19abc") + (19, "abc") + >>> stol("0xFF hello", 16) + (255, " hello") + >>> stol("0x123ghi", 0) + (291, "ghi") + >>> stol("0b1010 binary", 0) + (10, " binary") + >>> stol("0o123 octal", 0) + (83, " octal") + + See Also: + `atol`: A similar function that parses the entire string and returns an integer. + + [1]: https://docs.python.org/3/reference/lexical_analysis.html#integers. + + """ + var result: Int + var remaining: String + result, remaining = _stol(str.as_string_slice(), base) + + return result, remaining + + +@always_inline +fn _is_valid_digit(char: UInt8, base: Int) -> Bool: + """Checks if a character is a valid digit for the given base. + + Args: + char: The character to check, as a UInt8. + base: The numeric base (0-36, where 0 is special case). + + Returns: + True if the character is a valid digit for the given base, False otherwise. + """ + if base == 0: + # For base 0, we need to allow 0-9 and a-f/A-F for potential hex numbers + if char >= ord("0") and char <= ord("9"): + return True + var upper_char = char & ~32 # Convert to uppercase + return upper_char >= ord("A") and upper_char <= ord("F") + + if char == ord("_"): + return True + + if char >= ord("0") and char <= ord("9"): + return (char - ord("0")) < base + if base <= 10: + return False + var upper_char = char & ~32 # Convert to uppercase + if upper_char >= ord("A") and upper_char <= ord("Z"): + return (upper_char - ord("A") + 10) < base + return False + + fn _atol(str_slice: StringSlice, base: Int = 10) raises -> Int: """Implementation of `atol` for StringSlice inputs. diff --git a/stdlib/src/prelude/__init__.mojo b/stdlib/src/prelude/__init__.mojo index 5551b826ee..8100ce2857 100644 --- a/stdlib/src/prelude/__init__.mojo +++ b/stdlib/src/prelude/__init__.mojo @@ -20,6 +20,7 @@ from collections.string import ( ascii, atof, atol, + stol, chr, isdigit, islower, diff --git a/stdlib/test/collections/test_string.mojo b/stdlib/test/collections/test_string.mojo index c70ff3a50c..d9904eaaf2 100644 --- a/stdlib/test/collections/test_string.mojo +++ b/stdlib/test/collections/test_string.mojo @@ -361,6 +361,166 @@ def test_string_indexing(): assert_equal("H", str[-50::50]) +def test_stol(): + var result: Int + var remaining: String + + # base 10 + result, remaining = stol(String("375 ABC")) + assert_equal(375, result) + assert_equal(" ABC", remaining) + result, remaining = stol(String(" 005")) + assert_equal(5, result) + assert_equal("", remaining) + result, remaining = stol(String(" 013 ")) + assert_equal(13, result) + assert_equal(" ", remaining) + result, remaining = stol(String("-89")) + assert_equal(-89, result) + assert_equal("", remaining) + result, remaining = stol(String(" -52")) + assert_equal(-52, result) + assert_equal("", remaining) + + # other bases + result, remaining = stol(" FF", 16) + assert_equal(255, result) + assert_equal("", remaining) + result, remaining = stol(" 0xff ", 16) + assert_equal(255, result) + assert_equal(" ", remaining) + result, remaining = stol("10010eighteen18", 2) + assert_equal(18, result) + assert_equal("eighteen18", remaining) + result, remaining = stol("0b10010", 2) + assert_equal(18, result) + result, remaining = stol("0b_10010", 2) + assert_equal(18, result) + result, remaining = stol("0b_0010010", 2) + assert_equal(18, result) + result, remaining = stol("0b0000_0_010010", 2) + assert_equal(18, result) + assert_equal("", remaining) + result, remaining = stol("0o12", 8) + assert_equal(10, result) + result, remaining = stol("0o_12", 8) + assert_equal(10, result) + result, remaining = stol("0o_012", 8) + assert_equal(10, result) + result, remaining = stol("0o0000_0_0012", 8) + assert_equal(10, result) + assert_equal("", remaining) + result, remaining = stol("Z", 36) + assert_equal(35, result) + assert_equal("", remaining) + + # test with trailing characters + result, remaining = stol("123abc") + assert_equal(123, result) + assert_equal("abc", remaining) + result, remaining = stol("-45def") + assert_equal(-45, result) + assert_equal("def", remaining) + result, remaining = stol("0xffghi", 0) + assert_equal(255, result) + result, remaining = stol("0x_ffghi", 0) + assert_equal(255, result) + result, remaining = stol("0x_0ffghi", 0) + assert_equal(255, result) + result, remaining = stol("0x0000_0_00ffghi", 0) + assert_equal(255, result) + assert_equal("ghi", remaining) + + result, remaining = stol(" ") + assert_equal(0, result) + assert_equal(" ", remaining) + + result, remaining = stol("123.456", 10) + assert_equal(123, result) + assert_equal(".456", remaining) + result, remaining = stol("--123", 10) + assert_equal(0, result) + assert_equal("--123", remaining) + + result, remaining = stol("12a34", 10) + assert_equal(12, result) + assert_equal("a34", remaining) + result, remaining = stol("1G5", 16) + assert_equal(1, result) + assert_equal("G5", remaining) + + result, remaining = stol("-1A", 16) + assert_equal(-26, result) + assert_equal("", remaining) + result, remaining = stol("-110", 2) + assert_equal(-6, result) + assert_equal("", remaining) + + result, remaining = stol("Mojo!") + assert_equal(0, result) + assert_equal("Mojo!", remaining) + + # Negative Cases + with assert_raises(contains="Cannot convert empty string to integer."): + _ = stol("") + + with assert_raises(contains="Base must be >= 2 and <= 36, or 0."): + _ = stol("Bad Base", 42) + + with assert_raises( + contains="String expresses an integer too large to store in Int." + ): + _ = stol(String("9223372036854775832"), 10) + + +def test_stol_base_0(): + var result: Int + var remaining: String + + result, remaining = stol("155_155", 0) + assert_equal(155155, result) + assert_equal("", remaining) + result, remaining = stol("1_2_3_4_5", 0) + assert_equal(12345, result) + assert_equal("", remaining) + result, remaining = stol("1_2_3_4_5_", 0) + assert_equal(12345, result) + assert_equal("_", remaining) + result, remaining = stol("0b1_0_1_0", 0) + assert_equal(10, result) + assert_equal("", remaining) + result, remaining = stol("0o1_2_3", 0) + assert_equal(83, result) + assert_equal("", remaining) + result, remaining = stol("0x1_A_B", 0) + assert_equal(427, result) + assert_equal("", remaining) + result, remaining = stol("123_", 0) + assert_equal(123, result) + assert_equal("_", remaining) + result, remaining = stol("_123", 0) + assert_equal(0, result) + assert_equal("_123", remaining) + result, remaining = stol("123__456", 0) + assert_equal(123, result) + assert_equal("__456", remaining) + result, remaining = stol("0x1_23", 0) + assert_equal(291, result) + assert_equal("", remaining) + result, remaining = stol("0_123", 0) + assert_equal(123, result) + assert_equal("", remaining) + result, remaining = stol("0z123", 0) + assert_equal(0, result) + assert_equal("z123", remaining) + result, remaining = stol("Mojo!", 0) + assert_equal(0, result) + assert_equal("Mojo!", remaining) + result, remaining = stol("0o123 octal", 0) + assert_equal(83, result) + assert_equal(" octal", remaining) + + def test_atol(): # base 10 assert_equal(375, atol(String("375"))) @@ -1606,6 +1766,8 @@ def main(): test_ord() test_chr() test_string_indexing() + test_stol() + test_stol_base_0() test_atol() test_atol_base_0() test_atof() diff --git a/stdlib/test/python/my_module.py b/stdlib/test/python/my_module.py index 8147b0a382..c78c39556e 100644 --- a/stdlib/test/python/my_module.py +++ b/stdlib/test/python/my_module.py @@ -25,7 +25,8 @@ def __init__(self, bar): class AbstractPerson(ABC): @abstractmethod - def method(self): ... + def method(self): + ... def my_function(name):