From c605cf0ca3473fff2f6d04d59992ad0b0b3dee8c Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Wed, 16 Oct 2024 15:24:02 -0600 Subject: [PATCH 1/8] [stdlib] fix formatter lifetime bug in string.join The lifetime of the formatter object was not extended to the lifetime of the `@parameter` closure, so in certain configurations it was crashing. Closes https://github.com/modularml/mojo/issues/2751 MODULAR_ORIG_COMMIT_REV_ID: 045a508b51458e01a55f0a36d6a2cc54ca9a0e92 --- stdlib/src/collections/string.mojo | 1 + stdlib/test/collections/test_string.mojo | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/stdlib/src/collections/string.mojo b/stdlib/src/collections/string.mojo index d365365bfc..f69b1f1e5b 100644 --- a/stdlib/src/collections/string.mojo +++ b/stdlib/src/collections/string.mojo @@ -1391,6 +1391,7 @@ struct String( elems.each[add_elt]() _ = is_first + _ = formatter^ return result fn join[T: StringableCollectionElement](self, elems: List[T, *_]) -> String: diff --git a/stdlib/test/collections/test_string.mojo b/stdlib/test/collections/test_string.mojo index bbf3768c7f..ed92b61a49 100644 --- a/stdlib/test/collections/test_string.mojo +++ b/stdlib/test/collections/test_string.mojo @@ -10,10 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ===----------------------------------------------------------------------=== # -# RUN: %bare-mojo %s +# RUN: %mojo %s -# TODO: Replace %bare-mojo with %mojo -# when https://github.com/modularml/mojo/issues/2751 is fixed. from collections.string import ( _calc_initial_buffer_size_int32, _calc_initial_buffer_size_int64, From c409f01dba72f393a6a5fa9f59e8923144a01ae5 Mon Sep 17 00:00:00 2001 From: Joe Loser Date: Wed, 16 Oct 2024 16:05:43 -0600 Subject: [PATCH 2/8] [stdlib] Fix missing `benchmark.keep` calls in `bench_math.mojo` There are a few missing `benchmark.keep` calls in the benchmark functions. Add them so the compiler won't optimize away this function. Other files will be audited in a follow-up. MODULAR_ORIG_COMMIT_REV_ID: ec352ca3f588a0f15d709465d5222c4999192cc7 --- stdlib/benchmarks/math/bench_math.mojo | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stdlib/benchmarks/math/bench_math.mojo b/stdlib/benchmarks/math/bench_math.mojo index 877c2798a2..9a62d9aba2 100644 --- a/stdlib/benchmarks/math/bench_math.mojo +++ b/stdlib/benchmarks/math/bench_math.mojo @@ -52,7 +52,8 @@ fn bench_math[ @parameter fn call_fn() raises: for input in inputs: - _ = math_f1p(input[]) + var result = math_f1p(input[]) + keep(result) b.iter[call_fn]() @@ -70,7 +71,8 @@ fn bench_math3[ @parameter fn call_fn() raises: for input in inputs: - _ = math_f3p(input[], input[], input[]) + var result = math_f3p(input[], input[], input[]) + keep(result) b.iter[call_fn]() From 2c83f89bbd0d6267e4cf3cb6a597f153beb0d0ce Mon Sep 17 00:00:00 2001 From: Helehex Date: Wed, 16 Oct 2024 16:15:05 -0600 Subject: [PATCH 3/8] [External] [stdlib] Fix `/_math.mojo` examples. (#49264) [External] [stdlib] Fix `/_math.mojo` examples. FYI: `Truncable` is not in `math`, so that one still fails, not sure what's going on there. Co-authored-by: Helehex Closes modularml/mojo#3681 MODULAR_ORIG_COMMIT_REV_ID: d1fb6af747e4599fe0eefa709dad1f9f4840e5b1 --- stdlib/src/builtin/_math.mojo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/src/builtin/_math.mojo b/stdlib/src/builtin/_math.mojo index 67b818e9f4..22a309e7ed 100644 --- a/stdlib/src/builtin/_math.mojo +++ b/stdlib/src/builtin/_math.mojo @@ -41,7 +41,7 @@ trait Ceilable: var im: Float64 fn __ceil__(self) -> Self: - return Self(ceil(re), ceil(im)) + return Self(ceil(self.re), ceil(self.im)) ``` """ @@ -78,7 +78,7 @@ trait Floorable: var im: Float64 fn __floor__(self) -> Self: - return Self(floor(re), floor(im)) + return Self(floor(self.re), floor(self.im)) ``` """ From fccb1716473cd845880e380701a83c57dfc7694f Mon Sep 17 00:00:00 2001 From: Joe Loser Date: Wed, 16 Oct 2024 16:55:54 -0600 Subject: [PATCH 4/8] [stdlib] Remove `Reference` type alias MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All of the internal code has moved over to using `Pointer` — the new name, instead of `Reference`. Drop the compatibility type alias now. MODULAR_ORIG_COMMIT_REV_ID: cda036aa01a5ec512dbeae1debfaba3a6f181298 --- stdlib/src/memory/__init__.mojo | 2 +- stdlib/src/memory/pointer.mojo | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/stdlib/src/memory/__init__.mojo b/stdlib/src/memory/__init__.mojo index 9d25f91476..cc226348fd 100644 --- a/stdlib/src/memory/__init__.mojo +++ b/stdlib/src/memory/__init__.mojo @@ -15,6 +15,6 @@ from .arc import Arc from .box import Box from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation -from .pointer import AddressSpace, Pointer, Reference +from .pointer import AddressSpace, Pointer from .unsafe import bitcast from .unsafe_pointer import UnsafePointer diff --git a/stdlib/src/memory/pointer.mojo b/stdlib/src/memory/pointer.mojo index ddbfa54ffb..167feb5f5a 100644 --- a/stdlib/src/memory/pointer.mojo +++ b/stdlib/src/memory/pointer.mojo @@ -21,9 +21,6 @@ from memory import Pointer from builtin._documentation import doc_private -# TODO: This is kept for compatibility, remove this in the future. -alias Reference = Pointer[*_] - # ===----------------------------------------------------------------------===# # AddressSpace # ===----------------------------------------------------------------------===# From f95b14da37a0bc6d5ddff4d5733566f3c7b4dbb2 Mon Sep 17 00:00:00 2001 From: modularbot <116839051+modularbot@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:44:12 -0600 Subject: [PATCH 5/8] [External] [stdlib] Fix `String.format()` to use byte indexing, refactor for performance and readability, and prepare for full format spec support (#49225) [External] [stdlib] Fix `String.format()` to use byte indexing, refactor for performance and readability, and prepare for full format spec support Fix `String.format()` to use byte indexing, refactor for performance and readability, and prepare for full format spec support. Closes https://github.com/modularml/mojo/issues/3296. ORIGINAL_AUTHOR=martinvuyk <110240700+martinvuyk@users.noreply.github.com> PUBLIC_PR_LINK=modularml/mojo#3539 Co-authored-by: martinvuyk <110240700+martinvuyk@users.noreply.github.com> Closes modularml/mojo#3539 MODULAR_ORIG_COMMIT_REV_ID: 2d7ebae6f48b7d29866c089cffd084cb4405ac45 --- stdlib/src/builtin/string_literal.mojo | 48 +- stdlib/src/collections/string.mojo | 340 +---------- stdlib/src/utils/string_slice.mojo | 687 ++++++++++++++++++++++- stdlib/test/collections/test_string.mojo | 54 +- stdlib/test/hashlib/test_ahash.mojo | 8 +- 5 files changed, 768 insertions(+), 369 deletions(-) diff --git a/stdlib/src/builtin/string_literal.mojo b/stdlib/src/builtin/string_literal.mojo index 975963b749..21bdb178b1 100644 --- a/stdlib/src/builtin/string_literal.mojo +++ b/stdlib/src/builtin/string_literal.mojo @@ -24,7 +24,11 @@ from utils import StringRef, Span, StringSlice, StaticString from utils import Formattable, Formatter from utils._visualizers import lldb_formatter_wrapping_type -from collections.string import _atol, _StringSliceIter +from utils.string_slice import ( + _StringSliceIter, + _FormatCurlyEntry, + _CurlyEntryFormattable, +) # ===----------------------------------------------------------------------===# # StringLiteral @@ -206,27 +210,25 @@ struct StringLiteral( """ return len(self) != 0 + @always_inline fn __int__(self) raises -> Int: """Parses the given string as a base-10 integer and returns that value. - - For example, `int("19")` returns `19`. If the given string cannot be parsed - as an integer value, an error is raised. For example, `int("hi")` raises an - error. + If the string cannot be parsed as an int, an error is raised. Returns: An integer value that represents the string, or otherwise raises. """ - return _atol(self) + return int(self.as_string_slice()) + @always_inline fn __float__(self) raises -> Float64: - """Parses the string as a float point number and returns that value. - - If the string cannot be parsed as a float, an error is raised. + """Parses the string as a float point number and returns that value. If + the string cannot be parsed as a float, an error is raised. Returns: A float value that represents the string, or otherwise raises. """ - return atof(self) + return float(self.as_string_slice()) @no_inline fn __str__(self) -> String: @@ -414,6 +416,32 @@ struct StringLiteral( len=self.byte_length(), ) + @always_inline + fn format[*Ts: _CurlyEntryFormattable](self, *args: *Ts) raises -> String: + """Format a template with `*args`. + + Args: + args: The substitution values. + + Parameters: + Ts: The types of substitution values that implement `Representable` + and `Stringable` (to be changed and made more flexible). + + Returns: + The template with the given values substituted. + + Examples: + + ```mojo + # Manual indexing: + print("{0} {1} {0}".format("Mojo", 1.125)) # Mojo 1.125 Mojo + # Automatic indexing: + print("{} {}".format(True, "hello world")) # True hello world + ``` + . + """ + return _FormatCurlyEntry.format(self, args) + fn format_to(self, inout writer: Formatter): """ Formats this string literal to the provided formatter. diff --git a/stdlib/src/collections/string.mojo b/stdlib/src/collections/string.mojo index f69b1f1e5b..33e68615af 100644 --- a/stdlib/src/collections/string.mojo +++ b/stdlib/src/collections/string.mojo @@ -42,6 +42,8 @@ from utils.string_slice import ( _StringSliceIter, _unicode_codepoint_utf8_byte_length, _shift_unicode_to_utf8, + _FormatCurlyEntry, + _CurlyEntryFormattable, ) # ===----------------------------------------------------------------------=== # @@ -748,7 +750,7 @@ struct String( impl: The buffer. """ debug_assert( - impl[-1] == 0, + len(impl) > 0 and impl[-1] == 0, "expected last element of String buffer to be null terminator", ) # We make a backup because steal_data() will clear size and capacity. @@ -2095,22 +2097,20 @@ struct String( return self[: -suffix.byte_length()] return self + @always_inline fn __int__(self) raises -> Int: """Parses the given string as a base-10 integer and returns that value. - - For example, `int("19")` returns `19`. If the given string cannot be - parsed as an integer value, an error is raised. For example, `int("hi")` - raises an error. + If the string cannot be parsed as an int, an error is raised. Returns: An integer value that represents the string, or otherwise raises. """ return atol(self) + @always_inline fn __float__(self) raises -> Float64: - """Parses the string as a float point number and returns that value. - - If the string cannot be parsed as a float, an error is raised. + """Parses the string as a float point number and returns that value. If + the string cannot be parsed as a float, an error is raised. Returns: A float value that represents the string, or otherwise raises. @@ -2140,84 +2140,31 @@ struct String( ) return String(buf^) - fn format[*Ts: StringRepresentable](self, *args: *Ts) raises -> String: - """Format a template with *args. - - Example of manual indexing: - - ```mojo - print( - String("{0} {1} {0}").format( - "Mojo", 1.125 - ) - ) #Mojo 1.125 Mojo - ``` - - Example of automatic indexing: - - ```mojo - var x = String("{} {}").format( - True, "hello world" - ) - print(x) #True hello world - ``` + @always_inline + fn format[*Ts: _CurlyEntryFormattable](self, *args: *Ts) raises -> String: + """Format a template with `*args`. Args: args: The substitution values. Parameters: - Ts: The types of the substitution values. - Are required to implement `Stringable`. + Ts: The types of substitution values that implement `Representable` + and `Stringable` (to be changed and made more flexible). Returns: The template with the given values substituted. - """ - alias num_pos_args = len(VariadicList(Ts)) - var entries = _FormatCurlyEntry.create_entries(self, num_pos_args) - - var res: String = "" - var pos_in_self = 0 - - var current_automatic_arg_index = 0 - for e in entries: - debug_assert( - pos_in_self < self.byte_length(), - "pos_in_self >= self.byte_length()", - ) - res += self[pos_in_self : e[].first_curly] - - if e[].is_escaped_brace(): - res += "}" if e[].field[Bool] else "{" - - if e[].is_manual_indexing(): - - @parameter - for i in range(num_pos_args): - if i == e[].field[Int]: - if e[].conversion_flag == "r": - res += repr(args[i]) - else: - res += str(args[i]) - - if e[].is_automatic_indexing(): - - @parameter - for i in range(num_pos_args): - if i == current_automatic_arg_index: - if e[].conversion_flag == "r": - res += repr(args[i]) - else: - res += str(args[i]) - - current_automatic_arg_index += 1 - - pos_in_self = e[].last_curly + 1 - - if pos_in_self < self.byte_length(): - res += self[pos_in_self : self.byte_length()] + Examples: - return res^ + ```mojo + # Manual indexing: + print(String("{0} {1} {0}").format("Mojo", 1.125)) # Mojo 1.125 Mojo + # Automatic indexing: + print(String("{} {}").format(True, "hello world")) # True hello world + ``` + . + """ + return _FormatCurlyEntry.format(self, args) fn isdigit(self) -> Bool: """A string is a digit string if all characters in the string are digits @@ -2461,244 +2408,3 @@ fn _calc_format_buffer_size[type: DType]() -> Int: return 64 + 1 else: return 128 + 1 # Add 1 for the terminator - - -# ===----------------------------------------------------------------------===# -# Format method structures -# ===----------------------------------------------------------------------===# - - -trait StringRepresentable(Stringable, Representable): - """The `StringRepresentable` trait denotes a trait composition of the - `Stringable` and `Representable` traits. - - This trait is used by the `format()` method to support both `{!s}` (or `{}`) - and `{!r}` format specifiers. It allows the method to handle types that - can be formatted using both their string representation and their - more detailed representation. - - Types implementing this trait must provide both `__str__()` and `__repr__()` - methods as defined in `Stringable` and `Representable` traits respectively. - """ - - pass - - -@value -struct _FormatCurlyEntry(CollectionElement, CollectionElementNew): - """ - Internally used by the `format()` method. - - Specifically to structure fields. - - Does not contain any substitution values. - - """ - - var first_curly: Int - """The index of an opening brace around a substitution field.""" - - var last_curly: Int - """The index of an closing brace around a substitution field.""" - - var conversion_flag: String - """Store the format specifier (e.g., 'r' for repr).""" - - alias _FieldVariantType = Variant[ - String, # kwargs indexing (`{field_name}`) - Int, # args manual indexing (`{3}`) - NoneType, # args automatic indexing (`{}`) - Bool, # for escaped curlies ('{{') - ] - var field: Self._FieldVariantType - """Store the substitution field.""" - - fn __init__(inout self, *, other: Self): - self.first_curly = other.first_curly - self.last_curly = other.last_curly - self.conversion_flag = other.conversion_flag - self.field = Self._FieldVariantType(other=other.field) - - fn is_escaped_brace(ref [_]self) -> Bool: - return self.field.isa[Bool]() - - fn is_kwargs_field(ref [_]self) -> Bool: - return self.field.isa[String]() - - fn is_automatic_indexing(ref [_]self) -> Bool: - return self.field.isa[NoneType]() - - fn is_manual_indexing(ref [_]self) -> Bool: - return self.field.isa[Int]() - - @staticmethod - fn create_entries( - format_src: String, len_pos_args: Int - ) raises -> List[Self]: - """Used internally by the `format()` method. - - Args: - format_src: The "format" part provided by the user. - len_pos_args: The len of *args - - Returns: - A `List` of structured field entries. - - Purpose of the `Variant` `Self.field`: - - - `Int` for manual indexing - (value field contains `0`) - - - `NoneType` for automatic indexing - (value field contains `None`) - - - `String` for **kwargs indexing - (value field contains `foo`) - - - `Bool` for escaped curlies - (value field contains False for `{` or True for '}') - """ - var manual_indexing_count = 0 - var automatic_indexing_count = 0 - var raised_manual_index = Optional[Int](None) - var raised_automatic_index = Optional[Int](None) - var raised_kwarg_field = Optional[String](None) - alias supported_conversion_flags = ( - String("s"), # __str__ - String("r"), # __repr__ - ) - - var entries = List[Self]() - var start = Optional[Int](None) - var skip_next = False - for i in range(format_src.byte_length()): - if skip_next: - skip_next = False - continue - if format_src[i] == "{": - if start: - # already one there. - if i - start.value() == 1: - # python escapes double curlies - var current_entry = Self( - first_curly=start.value(), - last_curly=i, - field=False, - conversion_flag="", - ) - entries.append(current_entry^) - start = None - continue - raise ( - "there is a single curly { left unclosed or unescaped" - ) - else: - start = i - continue - if format_src[i] == "}": - if start: - var start_value = start.value() - var current_entry = Self( - first_curly=start_value, - last_curly=i, - field=NoneType(), - conversion_flag="", - ) - - if i - start_value != 1: - var field = format_src[start_value + 1 : i] - var exclamation_index = field.find("!") - - # TODO: Future implementation of format specifiers - # When implementing format specifiers, modify this section to handle: - # replacement_field ::= "{" [field_name] ["!" conversion] [":" format_spec] "}" - # this will involve: - # 1. finding a colon ':' after the conversion flag (if present) - # 2. extracting the format_spec if a colon is found - # 3. adjusting the field and conversion_flag parsing accordingly - - if exclamation_index != -1: - if exclamation_index + 1 < len(field): - var conversion_flag: String = field[ - exclamation_index + 1 : - ] - if ( - conversion_flag - not in supported_conversion_flags - ): - raise 'Conversion flag "' + conversion_flag + '" not recognised.' - current_entry.conversion_flag = conversion_flag - else: - raise "Empty conversion flag." - - field = field[:exclamation_index] - - if ( - field == "" - ): # an empty field, so it's automatic indexing - if automatic_indexing_count >= len_pos_args: - raised_automatic_index = ( - automatic_indexing_count - ) - break - automatic_indexing_count += 1 - else: - try: - # field is a number for manual indexing: - var number = int(field) - current_entry.field = number - if number >= len_pos_args or number < 0: - raised_manual_index = number - break - manual_indexing_count += 1 - except e: - debug_assert( - "not convertible to integer" in str(e), - "Not the expected error from atol", - ) - # field is an keyword for **kwargs: - current_entry.field = field - raised_kwarg_field = field - break - - else: - # automatic indexing - # current_entry.field is already None - if automatic_indexing_count >= len_pos_args: - raised_automatic_index = automatic_indexing_count - break - automatic_indexing_count += 1 - entries.append(current_entry^) - start = None - else: - # python escapes double curlies - if (i + 1) < format_src.byte_length(): - if format_src[i + 1] == "}": - var curren_entry = Self( - first_curly=i, - last_curly=i + 1, - field=True, - conversion_flag="", - ) - entries.append(curren_entry^) - skip_next = True - continue - # if it is not an escaped one, it is an error - raise ( - "there is a single curly } left unclosed or unescaped" - ) - - if raised_automatic_index: - raise "Automatic indexing require more args in *args" - if raised_kwarg_field: - raise "Index " + raised_kwarg_field.value() + " not in kwargs" - if manual_indexing_count and automatic_indexing_count: - raise "Cannot both use manual and automatic indexing" - if raised_manual_index: - raise ( - "Index " + str(raised_manual_index.value()) + " not in *args" - ) - if start: - raise "there is a single curly { left unclosed or unescaped" - - return entries^ diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index a860c5982e..9855f99de7 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -22,8 +22,8 @@ from utils import StringSlice from bit import count_leading_zeros from utils import Span -from collections.string import _isspace -from collections import List +from collections.string import _isspace, _atol, _atof +from collections import List, Optional from memory import memcmp, UnsafePointer from sys import simdwidthof, bitwidthof @@ -492,6 +492,26 @@ struct StringSlice[ unsafe_pointer=self.unsafe_ptr(), length=self.byte_length() ) + @always_inline + fn __int__(self) raises -> Int: + """Parses the given string as a base-10 integer and returns that value. + If the string cannot be parsed as an int, an error is raised. + + Returns: + An integer value that represents the string, or otherwise raises. + """ + return _atol(self._strref_dangerous()) + + @always_inline + fn __float__(self) raises -> Float64: + """Parses the string as a float point number and returns that value. If + the string cannot be parsed as a float, an error is raised. + + Returns: + A float value that represents the string, or otherwise raises. + """ + return _atof(self._strref_dangerous()) + # ===------------------------------------------------------------------===# # Methods # ===------------------------------------------------------------------===# @@ -585,6 +605,32 @@ struct StringSlice[ # and use something smarter. return StringSlice(unsafe_from_utf8=self._slice[abs_start:]) + @always_inline + fn format[*Ts: _CurlyEntryFormattable](self, *args: *Ts) raises -> String: + """Format a template with `*args`. + + Args: + args: The substitution values. + + Parameters: + Ts: The types of substitution values that implement `Representable` + and `Stringable` (to be changed and made more flexible). + + Returns: + The template with the given values substituted. + + Examples: + + ```mojo + # Manual indexing: + print("{0} {1} {0}".format("Mojo", 1.125)) # Mojo 1.125 Mojo + # Automatic indexing: + print("{} {}".format(True, "hello world")) # True hello world + ``` + . + """ + return _FormatCurlyEntry.format(self, args) + fn find(self, substr: StringSlice, start: Int = 0) -> Int: """Finds the offset of the first occurrence of `substr` starting at `start`. If not found, returns -1. @@ -712,3 +758,640 @@ struct StringSlice[ current_offset += eol_location + eol_length return output^ + + +# ===----------------------------------------------------------------------===# +# Utils +# ===----------------------------------------------------------------------===# + + +trait Stringlike: + """Trait intended to be used only with `String`, `StringLiteral` and + `StringSlice`.""" + + fn byte_length(self) -> Int: + """Get the string length in bytes. + + Returns: + The length of this string in bytes. + + Notes: + This does not include the trailing null terminator in the count. + """ + ... + + fn unsafe_ptr(self) -> UnsafePointer[UInt8]: + """Get raw pointer to the underlying data. + + Returns: + The raw pointer to the data. + """ + ... + + +# ===----------------------------------------------------------------------===# +# Format method structures +# ===----------------------------------------------------------------------===# + + +trait _CurlyEntryFormattable(Stringable, Representable): + """This trait is used by the `format()` method to support format specifiers. + Currently, it is a composition of both `Stringable` and `Representable` + traits i.e. a type to be formatted must implement both. In the future this + will be less constrained. + """ + + pass + + +@value +struct _FormatCurlyEntry(CollectionElement, CollectionElementNew): + """The struct that handles string-like formatting by curly braces entries. + This is internal for the types: `String`, `StringLiteral` and `StringSlice`. + """ + + var first_curly: Int + """The index of an opening brace around a substitution field.""" + var last_curly: Int + """The index of a closing brace around a substitution field.""" + # TODO: ord("a") conversion flag not supported yet + var conversion_flag: UInt8 + """The type of conversion for the entry: {ord("s"), ord("r")}.""" + var format_spec: Optional[_FormatSpec] + """The format specifier.""" + # TODO: ord("a") conversion flag not supported yet + alias supported_conversion_flags = SIMD[DType.uint8, 2](ord("s"), ord("r")) + """Currently supported conversion flags: `__str__` and `__repr__`.""" + alias _FieldVariantType = Variant[String, Int, NoneType, Bool] + """Purpose of the `Variant` `Self.field`: + + - `Int` for manual indexing: (value field contains `0`). + - `NoneType` for automatic indexing: (value field contains `None`). + - `String` for **kwargs indexing: (value field contains `foo`). + - `Bool` for escaped curlies: (value field contains False for `{` or True + for `}`). + """ + var field: Self._FieldVariantType + """Store the substitution field. See `Self._FieldVariantType` docstrings for + more details.""" + alias _args_t = VariadicPack[element_trait=_CurlyEntryFormattable, *_] + """Args types that are formattable by curly entry.""" + + fn __init__(inout self, *, other: Self): + self.first_curly = other.first_curly + self.last_curly = other.last_curly + self.conversion_flag = other.conversion_flag + self.field = Self._FieldVariantType(other=other.field) + self.format_spec = other.format_spec + + fn __init__( + inout self, + first_curly: Int, + last_curly: Int, + field: Self._FieldVariantType, + conversion_flag: UInt8 = 0, + format_spec: Optional[_FormatSpec] = None, + ): + self.first_curly = first_curly + self.last_curly = last_curly + self.field = field + self.conversion_flag = conversion_flag + self.format_spec = format_spec + + @always_inline + fn is_escaped_brace(ref [_]self) -> Bool: + return self.field.isa[Bool]() + + @always_inline + fn is_kwargs_field(ref [_]self) -> Bool: + return self.field.isa[String]() + + @always_inline + fn is_automatic_indexing(ref [_]self) -> Bool: + return self.field.isa[NoneType]() + + @always_inline + fn is_manual_indexing(ref [_]self) -> Bool: + return self.field.isa[Int]() + + @staticmethod + fn format[T: Stringlike](fmt_src: T, args: Self._args_t) raises -> String: + alias len_pos_args = __type_of(args).__len__() + entries, size_estimation = Self._create_entries(fmt_src, len_pos_args) + var fmt_len = fmt_src.byte_length() + var buf = String._buffer_type(capacity=fmt_len + size_estimation) + buf.size = 1 + buf.unsafe_set(0, 0) + var res = String(buf^) + var offset = 0 + var ptr = fmt_src.unsafe_ptr() + alias S = StringSlice[StaticConstantOrigin] + + @always_inline("nodebug") + fn _build_slice(p: UnsafePointer[UInt8], start: Int, end: Int) -> S: + return S(unsafe_from_utf8_ptr=p + start, len=end - start) + + var auto_arg_index = 0 + for e in entries: + debug_assert(offset < fmt_len, "offset >= fmt_src.byte_length()") + res += _build_slice(ptr, offset, e[].first_curly) + e[]._format_entry[len_pos_args](res, args, auto_arg_index) + offset = e[].last_curly + 1 + + res += _build_slice(ptr, offset, fmt_len) + return res^ + + @staticmethod + fn _create_entries[ + T: Stringlike + ](fmt_src: T, len_pos_args: Int) raises -> (List[Self], Int): + """Returns a list of entries and its total estimated entry byte width. + """ + var manual_indexing_count = 0 + var automatic_indexing_count = 0 + var raised_manual_index = Optional[Int](None) + var raised_automatic_index = Optional[Int](None) + var raised_kwarg_field = Optional[String](None) + alias `}` = UInt8(ord("}")) + alias `{` = UInt8(ord("{")) + alias l_err = "there is a single curly { left unclosed or unescaped" + alias r_err = "there is a single curly } left unclosed or unescaped" + + var entries = List[Self]() + var start = Optional[Int](None) + var skip_next = False + var fmt_ptr = fmt_src.unsafe_ptr() + var fmt_len = fmt_src.byte_length() + var total_estimated_entry_byte_width = 0 + + for i in range(fmt_len): + if skip_next: + skip_next = False + continue + if fmt_ptr[i] == `{`: + if not start: + start = i + continue + if i - start.value() != 1: + raise Error(l_err) + # python escapes double curlies + entries.append(Self(start.value(), i, field=False)) + start = None + continue + elif fmt_ptr[i] == `}`: + if not start and (i + 1) < fmt_len: + # python escapes double curlies + if fmt_ptr[i + 1] == `}`: + entries.append(Self(i, i + 1, field=True)) + total_estimated_entry_byte_width += 2 + skip_next = True + continue + elif not start: # if it is not an escaped one, it is an error + raise Error(r_err) + + var start_value = start.value() + var current_entry = Self(start_value, i, field=NoneType()) + + if i - start_value != 1: + if current_entry._handle_field_and_break( + fmt_src, + len_pos_args, + i, + start_value, + automatic_indexing_count, + raised_automatic_index, + manual_indexing_count, + raised_manual_index, + raised_kwarg_field, + total_estimated_entry_byte_width, + ): + break + else: # automatic indexing + if automatic_indexing_count >= len_pos_args: + raised_automatic_index = automatic_indexing_count + break + automatic_indexing_count += 1 + total_estimated_entry_byte_width += 8 # guessing + entries.append(current_entry^) + start = None + + if raised_automatic_index: + raise Error("Automatic indexing require more args in *args") + elif raised_kwarg_field: + var val = raised_kwarg_field.value() + raise Error("Index " + val + " not in kwargs") + elif manual_indexing_count and automatic_indexing_count: + raise Error("Cannot both use manual and automatic indexing") + elif raised_manual_index: + var val = str(raised_manual_index.value()) + raise Error("Index " + val + " not in *args") + elif start: + raise Error(l_err) + return entries^, total_estimated_entry_byte_width + + fn _handle_field_and_break[ + T: Stringlike + ]( + inout self, + fmt_src: T, + len_pos_args: Int, + i: Int, + start_value: Int, + inout automatic_indexing_count: Int, + inout raised_automatic_index: Optional[Int], + inout manual_indexing_count: Int, + inout raised_manual_index: Optional[Int], + inout raised_kwarg_field: Optional[String], + inout total_estimated_entry_byte_width: Int, + ) raises -> Bool: + alias S = StringSlice[StaticConstantOrigin] + + @always_inline("nodebug") + fn _build_slice(p: UnsafePointer[UInt8], start: Int, end: Int) -> S: + return S(unsafe_from_utf8_ptr=p + start, len=end - start) + + var field = _build_slice(fmt_src.unsafe_ptr(), start_value + 1, i) + var field_ptr = field.unsafe_ptr() + var field_len = i - (start_value + 1) + var exclamation_index = -1 + var idx = 0 + while idx < field_len: + if field_ptr[idx] == ord("!"): + exclamation_index = idx + break + idx += 1 + var new_idx = exclamation_index + 1 + if exclamation_index != -1: + if new_idx == field_len: + raise Error("Empty conversion flag.") + var conversion_flag = field_ptr[new_idx] + if field_len - new_idx > 1 or ( + conversion_flag not in Self.supported_conversion_flags + ): + var f = String(_build_slice(field_ptr, new_idx, field_len)) + _ = field^ + raise Error('Conversion flag "' + f + '" not recognised.') + self.conversion_flag = conversion_flag + field = _build_slice(field_ptr, 0, exclamation_index) + else: + new_idx += 1 + + var extra = int(new_idx < field_len) + var fmt_field = _build_slice(field_ptr, new_idx + extra, field_len) + self.format_spec = _FormatSpec.parse(fmt_field) + var w = int(self.format_spec.value().width) if self.format_spec else 0 + # fully guessing the byte width here to be at least 8 bytes per entry + # minus the length of the whole format specification + total_estimated_entry_byte_width += 8 * int(w > 0) + w - (field_len + 2) + + if field.byte_length() == 0: + # an empty field, so it's automatic indexing + if automatic_indexing_count >= len_pos_args: + raised_automatic_index = automatic_indexing_count + return True + automatic_indexing_count += 1 + else: + try: + # field is a number for manual indexing: + var number = int(field) + self.field = number + if number >= len_pos_args or number < 0: + raised_manual_index = number + return True + manual_indexing_count += 1 + except e: + alias unexp = "Not the expected error from atol" + debug_assert("not convertible to integer" in str(e), unexp) + # field is a keyword for **kwargs: + var f = str(field) + self.field = f + raised_kwarg_field = f + return True + return False + + fn _format_entry[ + len_pos_args: Int + ](self, inout res: String, args: Self._args_t, inout auto_idx: Int) raises: + # TODO(#3403 and/or #3252): this function should be able to use + # Formatter syntax when the type implements it, since it will give great + # performance benefits. This also needs to be able to check if the given + # args[i] conforms to the trait needed by the conversion_flag to avoid + # needing to constraint that every type needs to conform to every trait. + alias `r` = UInt8(ord("r")) + alias `s` = UInt8(ord("s")) + # alias `a` = UInt8(ord("a")) # TODO + + @parameter + fn _format(idx: Int) raises: + @parameter + for i in range(len_pos_args): + if i == idx: + var type_impls_repr = True # TODO + var type_impls_str = True # TODO + var type_impls_formatter_repr = True # TODO + var type_impls_formatter_str = True # TODO + var flag = self.conversion_flag + var empty = flag == 0 and not self.format_spec + + var data: String + if empty and type_impls_formatter_str: + data = str(args[i]) # TODO: use formatter and return + elif empty and type_impls_str: + data = str(args[i]) + elif flag == `s` and type_impls_formatter_str: + if empty: + # TODO: use formatter and return + pass + data = str(args[i]) + elif flag == `s` and type_impls_str: + data = str(args[i]) + elif flag == `r` and type_impls_formatter_repr: + if empty: + # TODO: use formatter and return + pass + data = repr(args[i]) + elif flag == `r` and type_impls_repr: + data = repr(args[i]) + elif self.format_spec: + self.format_spec.value().stringify(res, args[i]) + return + else: + alias argnum = "Argument number: " + alias does_not = " does not implement the trait " + alias needed = "needed for conversion_flag: " + var flg = String(List[UInt8](flag, 0)) + raise Error(argnum + str(i) + does_not + needed + flg) + + if self.format_spec: + self.format_spec.value().format_string(res, data) + else: + res += data + + if self.is_escaped_brace(): + res += "}" if self.field[Bool] else "{" + elif self.is_manual_indexing(): + _format(self.field[Int]) + elif self.is_automatic_indexing(): + _format(auto_idx) + auto_idx += 1 + + +@value +@register_passable("trivial") +struct _FormatSpec: + """Store every field of the format specifier in a byte (e.g., ord("+") for + sign). It is stored in a byte because every [format specifier](\ + https://docs.python.org/3/library/string.html#formatspec) is an ASCII + character. + """ + + var fill: UInt8 + """If a valid align value is specified, it can be preceded by a fill + character that can be any character and defaults to a space if omitted. + """ + var align: UInt8 + """The meaning of the various alignment options is as follows: + + | Option | Meaning| + |:-------|:-------| + |'<' | Forces the field to be left-aligned within the available space + (this is the default for most objects).| + |'>' | Forces the field to be right-aligned within the available space + (this is the default for numbers).| + |'=' | Forces the padding to be placed after the sign (if any) but before + the digits. This is used for printing fields in the form `+000000120`. This + alignment option is only valid for numeric types. It becomes the default for + numbers when `0` immediately precedes the field width.| + |'^' | Forces the field to be centered within the available space.| + """ + var sign: UInt8 + """The sign option is only valid for number types, and can be one of the + following: + + | Option | Meaning| + |:-------|:-------| + |'+' | indicates that a sign should be used for both positive as well as + negative numbers.| + |'-' | indicates that a sign should be used only for negative numbers (this + is the default behavior).| + |space | indicates that a leading space should be used on positive numbers, + and a minus sign on negative numbers.| + """ + var coerce_z: Bool + """The 'z' option coerces negative zero floating-point values to positive + zero after rounding to the format precision. This option is only valid for + floating-point presentation types. + """ + var alternate_form: Bool + """The alternate form is defined differently for different types. This + option is only valid for types that implement the trait `# TODO: define + trait`. For integers, when binary, octal, or hexadecimal output is used, + this option adds the respective prefix '0b', '0o', '0x', or '0X' to the + output value. For float and complex the alternate form causes the result of + the conversion to always contain a decimal-point character, even if no + digits follow it. + """ + var width: UInt8 + """A decimal integer defining the minimum total field width, including any + prefixes, separators, and other formatting characters. If not specified, + then the field width will be determined by the content. When no explicit + alignment is given, preceding the width field by a zero ('0') character + enables sign-aware zero-padding for numeric types. This is equivalent to a + fill character of '0' with an alignment type of '='. + """ + var grouping_option: UInt8 + """The ',' option signals the use of a comma for a thousands separator. For + a locale aware separator, use the 'n' integer presentation type instead. The + '_' option signals the use of an underscore for a thousands separator for + floating-point presentation types and for integer presentation type 'd'. For + integer presentation types 'b', 'o', 'x', and 'X', underscores will be + inserted every 4 digits. For other presentation types, specifying this + option is an error. + """ + var precision: UInt8 + """The precision is a decimal integer indicating how many digits should be + displayed after the decimal point for presentation types 'f' and 'F', or + before and after the decimal point for presentation types 'g' or 'G'. For + string presentation types the field indicates the maximum field size - in + other words, how many characters will be used from the field content. The + precision is not allowed for integer presentation types. + """ + var type: UInt8 + """Determines how the data should be presented. + + The available integer presentation types are: + + | Option | Meaning| + |:-------|:-------| + |'b' |Binary format. Outputs the number in base 2.| + |'c' |Character. Converts the integer to the corresponding unicode character + before printing.| + |'d' |Decimal Integer. Outputs the number in base 10.| + |'o' |Octal format. Outputs the number in base 8.| + |'x' |Hex format. Outputs the number in base 16, using lower-case letters + for the digits above 9.| + |'X' |Hex format. Outputs the number in base 16, using upper-case letters + for the digits above 9. In case '#' is specified, the prefix '0x' will be + upper-cased to '0X' as well.| + |'n' |Number. This is the same as 'd', except that it uses the current + locale setting to insert the appropriate number separator characters.| + |None | The same as 'd'.| + + In addition to the above presentation types, integers can be formatted with + the floating-point presentation types listed below (except 'n' and None). + When doing so, float() is used to convert the integer to a floating-point + number before formatting. + + The available presentation types for float and Decimal values are: + + | Option | Meaning| + |:-------|:-------| + |'e' |Scientific notation. For a given precision p, formats the number in + scientific notation with the letter `e` separating the coefficient from the + exponent. The coefficient has one digit before and p digits after the + decimal point, for a total of p + 1 significant digits. With no precision + given, uses a precision of 6 digits after the decimal point for float, and + shows all coefficient digits for Decimal. If no digits follow the decimal + point, the decimal point is also removed unless the # option is used.| + |'E' |Scientific notation. Same as 'e' except it uses an upper case `E` as + the separator character.| + |'f' |Fixed-point notation. For a given precision p, formats the number as a + decimal number with exactly p digits following the decimal point. With no + precision given, uses a precision of 6 digits after the decimal point for + float, and uses a precision large enough to show all coefficient digits for + Decimal. If no digits follow the decimal point, the decimal point is also + removed unless the # option is used.| + |'F' |Fixed-point notation. Same as 'f', but converts nan to NAN and inf to + INF.| + |'g' |General format. For a given precision p >= 1, this rounds the number + to p significant digits and then formats the result in either fixed-point + format or in scientific notation, depending on its magnitude. A precision of + 0 is treated as equivalent to a precision of 1. + The precise rules are as follows: suppose that the result formatted with + presentation type 'e' and precision p-1 would have exponent exp. Then, if + m <= exp < p, where m is -4 for floats and -6 for Decimals, the number is + formatted with presentation type 'f' and precision p-1-exp. Otherwise, the + number is formatted with presentation type 'e' and precision p-1. In both + cases insignificant trailing zeros are removed from the significand, and the + decimal point is also removed if there are no remaining digits following it, + unless the '#' option is used. + With no precision given, uses a precision of 6 significant digits for float. + For Decimal, the coefficient of the result is formed from the coefficient + digits of the value; scientific notation is used for values smaller than + 1e-6 in absolute value and values where the place value of the least + significant digit is larger than 1, and fixed-point notation is used + otherwise. + Positive and negative infinity, positive and negative zero, and nans, are + formatted as inf, -inf, 0, -0 and nan respectively, regardless of the + precision.| + |'G' |General format. Same as 'g' except switches to 'E' if the number gets + too large. The representations of infinity and NaN are uppercased, too.| + |'n' |Number. This is the same as 'g', except that it uses the current + locale setting to insert the appropriate number separator characters.| + |'%' |Percentage. Multiplies the number by 100 and displays in fixed ('f') + format, followed by a percent sign.| + |None |For float this is like the 'g' type, except that when fixed-point + notation is used to format the result, it always includes at least one digit + past the decimal point, and switches to the scientific notation when + exp >= p - 1. When the precision is not specified, the latter will be as + large as needed to represent the given value faithfully. + For Decimal, this is the same as either 'g' or 'G' depending on the value of + context.capitals for the current decimal context. + The overall effect is to match the output of str() as altered by the other + format modifiers.| + """ + + fn __init__( + inout self, + fill: UInt8 = ord(" "), + align: UInt8 = 0, + sign: UInt8 = ord("-"), + coerce_z: Bool = False, + alternate_form: Bool = False, + width: UInt8 = 0, + grouping_option: UInt8 = 0, + precision: UInt8 = 0, + type: UInt8 = 0, + ): + """Construct a FormatSpec instance. + + Args: + fill: Defaults to space. + align: Defaults to 0 which is adjusted to the default for the arg + type. + sign: Defaults to `-`. + coerce_z: Defaults to False. + alternate_form: Defaults to False. + width: Defaults to 0 which is adjusted to the default for the arg + type. + grouping_option: Defaults to 0 which is adjusted to the default for + the arg type. + precision: Defaults to 0 which is adjusted to the default for the + arg type. + type: Defaults to 0 which is adjusted to the default for the arg + type. + """ + self.fill = fill + self.align = align + self.sign = sign + self.coerce_z = coerce_z + self.alternate_form = alternate_form + self.width = width + self.grouping_option = grouping_option + self.precision = precision + self.type = type + + @staticmethod + fn parse(fmt_str: StringSlice) -> Optional[Self]: + """Parses the format spec string. + + Args: + fmt_str: The StringSlice with the format spec. + + Returns: + An instance of FormatSpec. + """ + var f_len = fmt_str.byte_length() + var f_ptr = fmt_str.unsafe_ptr() + var colon_idx = -1 + var idx = 0 + while idx < f_len: + if f_ptr[idx] == ord(":"): + exclamation_index = idx + break + idx += 1 + + if colon_idx == -1: + return None + + # TODO: Future implementation of format specifiers + return None + + fn stringify[ + T: _CurlyEntryFormattable + ](self, inout res: String, item: T) raises: + """Stringify a type according to its format specification. + + Args: + res: The resulting String. + item: The item to stringify. + """ + var type_implements_float = True # TODO + var type_implements_float_raising = True # TODO + var type_implements_int = True # TODO + var type_implements_int_raising = True # TODO + + # TODO: transform to int/float depending on format spec and stringify + # with hex/bin/oct etc. + res += str(item) + + fn format_string(self, inout res: String, item: String) raises: + """Transform a String according to its format specification. + + Args: + res: The resulting String. + item: The item to format. + """ + + # TODO: align, fill, etc. + res += item diff --git a/stdlib/test/collections/test_string.mojo b/stdlib/test/collections/test_string.mojo index ed92b61a49..de0a461913 100644 --- a/stdlib/test/collections/test_string.mojo +++ b/stdlib/test/collections/test_string.mojo @@ -1396,50 +1396,36 @@ def test_format_args(): with assert_raises(contains="Index first not in kwargs"): _ = String("A {first} B {second}").format(1, 2) - assert_equal( - String(" {} , {} {} !").format( - "Hello", - "Beautiful", - "World", - ), - " Hello , Beautiful World !", - ) + var s = String(" {} , {} {} !").format("Hello", "Beautiful", "World") + assert_equal(s, " Hello , Beautiful World !") - with assert_raises( - contains="there is a single curly { left unclosed or unescaped" - ): + fn curly(c: StringLiteral) -> StringLiteral: + return "there is a single curly " + c + " left unclosed or unescaped" + + with assert_raises(contains=curly("{")): _ = String("{ {}").format(1) - with assert_raises( - contains="there is a single curly { left unclosed or unescaped" - ): + with assert_raises(contains=curly("{")): _ = String("{ {0}").format(1) - with assert_raises( - contains="there is a single curly { left unclosed or unescaped" - ): + with assert_raises(contains=curly("{")): _ = String("{}{").format(1) - with assert_raises( - contains="there is a single curly } left unclosed or unescaped" - ): + with assert_raises(contains=curly("}")): _ = String("{}}").format(1) - with assert_raises( - contains="there is a single curly { left unclosed or unescaped" - ): + with assert_raises(contains=curly("{")): _ = String("{} {").format(1) - with assert_raises( - contains="there is a single curly { left unclosed or unescaped" - ): + with assert_raises(contains=curly("{")): _ = String("{").format(1) - with assert_raises( - contains="there is a single curly } left unclosed or unescaped" - ): + with assert_raises(contains=curly("}")): _ = String("}").format(1) + with assert_raises(contains=""): + _ = String("{}").format() + assert_equal(String("}}").format(), "}") assert_equal(String("{{").format(), "{") @@ -1469,13 +1455,8 @@ def test_format_args(): output = String(vinput).format() assert_equal(len(output), 0) - assert_equal( - String("{0} {1} ❤️‍🔥 {1} {0}").format( - "🔥", - "Mojo", - ), - "🔥 Mojo ❤️‍🔥 Mojo 🔥", - ) + var res = "🔥 Mojo ❤️‍🔥 Mojo 🔥" + assert_equal(String("{0} {1} ❤️‍🔥 {1} {0}").format("🔥", "Mojo"), res) assert_equal(String("{0} {1}").format(True, 1.125), "True 1.125") @@ -1501,6 +1482,7 @@ def test_format_conversion_flags(): assert_equal(String("{} {!r}").format(a, a), "Mojo 'Mojo'") assert_equal(String("{!s} {!r}").format(a, a), "Mojo 'Mojo'") assert_equal(String("{0!s} {0!r}").format(a), "Mojo 'Mojo'") + assert_equal(String("{0!s} {0!r}").format(a, "Mojo2"), "Mojo 'Mojo'") var b = 21.1 assert_true( diff --git a/stdlib/test/hashlib/test_ahash.mojo b/stdlib/test/hashlib/test_ahash.mojo index 1ec984086b..b411b857e5 100644 --- a/stdlib/test/hashlib/test_ahash.mojo +++ b/stdlib/test/hashlib/test_ahash.mojo @@ -18,7 +18,7 @@ from hashlib._ahash import AHasher from hashlib.hash import hash as old_hash from hashlib._hasher import _hash_with_hasher as hash from testing import assert_equal, assert_not_equal, assert_true -from memory import memset_zero, UnsafePointer +from memory import memset_zero, stack_allocation from time import now from utils import Span @@ -576,7 +576,7 @@ you, утра, боль, хорошие, пришёл, открой, брось, fn gen_word_pairs[words: String = words_en]() -> List[String]: var result = List[String]() try: - var list = words.split(",") + var list = words.split(", ") for w in list: var w1 = w[].strip() for w in list: @@ -636,7 +636,7 @@ def test_hash_byte_array(): def test_avalanche(): # test that values which differ just in one bit, # produce significatly different hash values - var data = UnsafePointer[UInt8].alloc(256) + var data = stack_allocation[256, UInt8]() memset_zero(data, 256) var hashes0 = List[UInt64]() var hashes1 = List[UInt64]() @@ -666,7 +666,7 @@ def test_avalanche(): def test_trailing_zeros(): # checks that a value with different amount of trailing zeros, # results in significantly different hash values - var data = UnsafePointer[UInt8].alloc(8) + var data = stack_allocation[8, UInt8]() memset_zero(data, 8) data[0] = 23 var hashes0 = List[UInt64]() From 583a0ce599120c83cc0d263e4f827a039f3186e4 Mon Sep 17 00:00:00 2001 From: Joshua James Venter Date: Wed, 16 Oct 2024 17:50:50 -0600 Subject: [PATCH 6/8] [External] [stdlib] Improve readability of `test_python_object` (#49254) [External] [stdlib] Improve readability of `test_python_object` Instead of raw strings with some Python inline for types and methods, pull them out into a separate Python module: `custom_indexable.py` and use that within the test. This allows for better maintenance and readability by having the code in a separate file. Co-authored-by: Joshua James Venter Closes modularml/mojo#3678 MODULAR_ORIG_COMMIT_REV_ID: c9a63e78e4d2f38401b6d344accd356dca89f9e6 --- stdlib/test/python/custom_indexable.py | 48 ++++++++++++++++ stdlib/test/python/test_python_object.mojo | 65 ++++++---------------- 2 files changed, 65 insertions(+), 48 deletions(-) create mode 100644 stdlib/test/python/custom_indexable.py diff --git a/stdlib/test/python/custom_indexable.py b/stdlib/test/python/custom_indexable.py new file mode 100644 index 0000000000..1164f7f5cb --- /dev/null +++ b/stdlib/test/python/custom_indexable.py @@ -0,0 +1,48 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2024, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + + +class Simple: + def __init__(self): + pass + + +class WithGetItem: + def __getitem__(self, key): + if isinstance(key, tuple): + return "Keys: {0}".format(", ".join(map(str, key))) + else: + return "Key: {0}".format(key) + + +class WithGetItemException: + def __getitem__(self, key): + raise ValueError("Custom error") + + +class With2DGetItem: + def __init__(self): + self.data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + def __getitem__(self, key): + if isinstance(key, tuple) and all(isinstance(k, slice) for k in key): + return [row[key[1]] for row in self.data[key[0]]] + elif isinstance(key, tuple): + return self.data[key[0]][key[1]] + else: + return self.data[key] + + +class Sliceable: + def __getitem__(self, key): + return key diff --git a/stdlib/test/python/test_python_object.mojo b/stdlib/test/python/test_python_object.mojo index 264eb958ee..b6e5fc38c1 100644 --- a/stdlib/test/python/test_python_object.mojo +++ b/stdlib/test/python/test_python_object.mojo @@ -410,6 +410,8 @@ fn test_none() raises: fn test_getitem_raises() raises: + custom_indexable = Python.import_module("custom_indexable") + var a = PythonObject(2) with assert_raises(contains="'int' object is not subscriptable"): _ = a[0] @@ -434,43 +436,23 @@ fn test_getitem_raises() raises: with assert_raises(contains="'NoneType' object is not subscriptable"): _ = d[0, 0] - with_get = Python.evaluate( - """type('WithGetItem', (), { - '__getitem__': lambda self, key: - 'Keys: {0}'.format(", ".join(map(str, key))) if isinstance(key, tuple) - else 'Key: {0}'.format(key) - })()""" - ) + with_get = custom_indexable.WithGetItem() assert_equal("Key: 0", str(with_get[0])) assert_equal("Keys: 0, 0", str(with_get[0, 0])) assert_equal("Keys: 0, 0, 0", str(with_get[0, 0, 0])) - var without_get = Python.evaluate( - "type('WithOutGetItem', (), {'__str__': lambda self: \"SomeString\"})()" - ) - with assert_raises(contains="'WithOutGetItem' object is not subscriptable"): + var without_get = custom_indexable.Simple() + with assert_raises(contains="'Simple' object is not subscriptable"): _ = without_get[0] - with assert_raises(contains="'WithOutGetItem' object is not subscriptable"): + with assert_raises(contains="'Simple' object is not subscriptable"): _ = without_get[0, 0] - var with_get_exception = Python.evaluate( - "type('WithGetItemException', (), {'__getitem__': lambda self, key: (_" - ' for _ in ()).throw(ValueError("Custom error")),})()' - ) - + var with_get_exception = custom_indexable.WithGetItemException() with assert_raises(contains="Custom error"): _ = with_get_exception[1] - with_2d = Python.evaluate( - """type('With2D', (), { - '__init__': lambda self: setattr(self, 'data', [[1, 2, 3], [4, 5, 6]]), - '__getitem__': lambda self, key: ( - self.data[key[0]][key[1]] if isinstance(key, tuple) - else self.data[key] - ) - })()""" - ) + with_2d = custom_indexable.With2DGetItem() assert_equal("[1, 2, 3]", str(with_2d[0])) assert_equal(2, with_2d[0, 1]) assert_equal(6, with_2d[1, 2]) @@ -479,13 +461,14 @@ fn test_getitem_raises() raises: _ = with_2d[0, 4] with assert_raises(contains="list index out of range"): - _ = with_2d[2, 0] + _ = with_2d[3, 0] with assert_raises(contains="list index out of range"): - _ = with_2d[2] + _ = with_2d[3] def test_setitem_raises(): + custom_indexable = Python.import_module("custom_indexable") t = Python.evaluate("(1,2,3)") with assert_raises( contains="'tuple' object does not support item assignment" @@ -502,15 +485,11 @@ def test_setitem_raises(): ): s[3] = "xy" - custom = Python.evaluate( - """type('Custom', (), { - '__init__': lambda self: None, - })()""" - ) + with_out = custom_indexable.Simple() with assert_raises( - contains="'Custom' object does not support item assignment" + contains="'Simple' object does not support item assignment" ): - custom[0] = 0 + with_out[0] = 0 d = Python.evaluate("{}") with assert_raises(contains="unhashable type: 'list'"): @@ -518,6 +497,7 @@ def test_setitem_raises(): fn test_py_slice() raises: + custom_indexable = Python.import_module("custom_indexable") var a = PythonObject([1, 2, 3, 4, 5]) assert_equal("[2, 3]", str(a[1:3])) assert_equal("[1, 2, 3, 4, 5]", str(a[:])) @@ -564,25 +544,14 @@ fn test_py_slice() raises: # with assert_raises(contains="slice(1, 3, None)"): # _ = d[1:3] - var custom = Python.evaluate( - "type('CustomSliceable', (), {'__getitem__': lambda self, key: key})()" - ) + var custom = custom_indexable.Sliceable() assert_equal("slice(1, 3, None)", str(custom[1:3])) var i = PythonObject(1) with assert_raises(contains="'int' object is not subscriptable"): _ = i[0:1] - with_2d = Python.evaluate( - """type('With2D', (), { - '__init__': lambda self: setattr(self, 'data', [[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - '__getitem__': lambda self, key: ( - [row[key[1]] for row in self.data[key[0]]] if isinstance(key, tuple) and all(isinstance(k, slice) for k in key) - else (self.data[key[0]][key[1]] if isinstance(key, tuple) - else self.data[key]) - ) - })()""" - ) + with_2d = custom_indexable.With2DGetItem() assert_equal("[1, 2]", str(with_2d[0, PythonObject(Slice(0, 2))])) assert_equal("[1, 2]", str(with_2d[0][0:2])) From 174f3d838d3b6c522b0a264b0f9f5794739a4e28 Mon Sep 17 00:00:00 2001 From: Helehex Date: Wed, 16 Oct 2024 17:51:05 -0600 Subject: [PATCH 7/8] [External] [stdlib] Fix example in `/floatable.mojo` (#49279) [External] [stdlib] Fix example in `/floatable.mojo` Co-authored-by: Helehex Closes modularml/mojo#3679 MODULAR_ORIG_COMMIT_REV_ID: 5204d8180f7dfe4ede4afc7392e7c61d603aaa4d --- stdlib/src/builtin/floatable.mojo | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stdlib/src/builtin/floatable.mojo b/stdlib/src/builtin/floatable.mojo index 340265d291..1510173716 100644 --- a/stdlib/src/builtin/floatable.mojo +++ b/stdlib/src/builtin/floatable.mojo @@ -69,8 +69,10 @@ trait FloatableRaising: For example: ```mojo + from utils import Variant + @value - struct MaybeFloat(FloatableRasing): + struct MaybeFloat(FloatableRaising): var value: Variant[Float64, NoneType] fn __float__(self) raises -> Float64: From f840ba2ba4fd82146b4a36a01069f78eca54937b Mon Sep 17 00:00:00 2001 From: modularbot Date: Thu, 17 Oct 2024 06:23:03 +0000 Subject: [PATCH 8/8] [stdlib] Bump compiler version to 2024.10.1705 --- stdlib/COMPATIBLE_COMPILER_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/COMPATIBLE_COMPILER_VERSION b/stdlib/COMPATIBLE_COMPILER_VERSION index e5752dc072..3ce44febae 100644 --- a/stdlib/COMPATIBLE_COMPILER_VERSION +++ b/stdlib/COMPATIBLE_COMPILER_VERSION @@ -1 +1 @@ -2024.10.1619 +2024.10.1705