diff --git a/third_party/triton/temporary/numpy_type_promotion.patch b/third_party/triton/temporary/numpy_type_promotion.patch new file mode 100644 index 0000000000000..e41638db8fcaf --- /dev/null +++ b/third_party/triton/temporary/numpy_type_promotion.patch @@ -0,0 +1,12 @@ +--- a/python/test/unit/language/test_core.py ++++ b/python/test/unit/language/test_core.py +@@ -363,8 +363,7 @@ def _test_binary(dtype_x, dtype_y, expr, + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) +- with promotion_numpy_2_0(): +- z_ref = eval(scalar_expr) ++ z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 4fa55269e3323..0348fe0cbb87f 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,5 +14,6 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton:temporary/numpy_type_promotion.patch", # Add new patches just above this line ]