Improve `simplify_algebra` to find more horizontal fusion opportunities #3432
Note for clarity: the fusion opportunity is because the [image omitted]
Here is a standalone case that demonstrates the issue:

```python
import migraphx

p = migraphx.program()
m = p.get_main_module()
x_0 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 0))
x_1 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 1))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 2))
x_3 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 3))
x_4 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 4))
x_5 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 5))
x_6 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 6))
x_7 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 7))
x_8 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 8))
x_9 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 9))
x_10 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 10))
x_11 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 11))
x_12 = m.add_literal(migraphx.create_argument(migraphx.shape(type="float_type", lens=[1]), [0.125]))
p_x = m.add_parameter("x", migraphx.shape(type="float_type", lens=[64, 64]))
x_14 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[64,512]), [x_12]) # migraphx.shape(type="float_type", lens=[64, 512], strides=[0, 0])
x_15 = m.add_instruction(migraphx.op("dot"), [p_x, x_11]) # migraphx.shape(type="float_type", lens=[64, 512])
x_16 = m.add_instruction(migraphx.op("add"), [x_15, x_10]) # migraphx.shape(type="float_type", lens=[64, 512])
x_17 = m.add_instruction(migraphx.op("add"), [x_16, x_9]) # migraphx.shape(type="float_type", lens=[64, 512])
x_18 = m.add_instruction(migraphx.op("add"), [x_17, x_8]) # migraphx.shape(type="float_type", lens=[64, 512])
x_19 = m.add_instruction(migraphx.op("dot"), [p_x, x_7]) # migraphx.shape(type="float_type", lens=[64, 512])
x_20 = m.add_instruction(migraphx.op("add"), [x_19, x_6]) # migraphx.shape(type="float_type", lens=[64, 512])
x_21 = m.add_instruction(migraphx.op("add"), [x_20, x_5]) # migraphx.shape(type="float_type", lens=[64, 512])
x_22 = m.add_instruction(migraphx.op("add"), [x_21, x_4]) # migraphx.shape(type="float_type", lens=[64, 512])
x_23 = m.add_instruction(migraphx.op("dot"), [p_x, x_3]) # migraphx.shape(type="float_type", lens=[64, 512])
x_24 = m.add_instruction(migraphx.op("add"), [x_23, x_2]) # migraphx.shape(type="float_type", lens=[64, 512])
x_25 = m.add_instruction(migraphx.op("add"), [x_24, x_1]) # migraphx.shape(type="float_type", lens=[64, 512])
x_26 = m.add_instruction(migraphx.op("add"), [x_25, x_0]) # migraphx.shape(type="float_type", lens=[64, 512])
x_27 = m.add_instruction(migraphx.op("mul"), [x_26, x_14]) # migraphx.shape(type="float_type", lens=[64, 512])
m.add_return([x_18, x_22, x_27])
```

Which produces this:

[IR printout omitted]
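(Side note, not from the original report: with a stock MIGraphX Python build, a quick way to see the effect of the passes is to print the program before and after compilation; `get_target` and `compile` are the standard entry points, though target availability depends on the build.)

```python
# Print the IR as constructed, then compile (which runs simplify_algebra
# among other passes) and print again to see what did or did not get
# fused horizontally.
print(p)                               # unoptimized IR
p.compile(migraphx.get_target("gpu"))  # or "cpu"/"ref" depending on the build
print(p)                               # IR after the optimization passes
```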
Before the horizontal fusion, there's an [image omitted]
@pfultz2, trying to understand the proposed optimization. The above IR is presented in a hacky sort of picture below. Thanks.

Before: [diagram omitted]

Output: [diagram omitted]
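To make the fusion concrete, here is a minimal numpy sketch of the idea (illustrative only, not MIGraphX code): all three `dot`s read the same input `x`, so their weight literals can be concatenated along the output dimension and the three GEMMs replaced by one GEMM plus slices.

```python
import numpy as np

x = np.random.rand(64, 64)
w1, w2, w3 = (np.random.rand(64, 512) for _ in range(3))

# Unfused: three separate GEMMs over the same input x.
y1, y2, y3 = x @ w1, x @ w2, x @ w3

# Horizontally fused: one [64, 1536] GEMM against the concatenated
# weights, then slice the result back into three [64, 512] outputs.
y = x @ np.concatenate([w1, w2, w3], axis=1)
f1, f2, f3 = y[:, :512], y[:, 512:1024], y[:, 1024:]

assert all(np.allclose(a, b) for a, b in [(y1, f1), (y2, f2), (y3, f3)])
```

The same trick applies to the bias `add`s that follow each `dot`, which is where the extra horizontal-fusion loop in `simplify_algebra` would come in.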
In SD CLIP, there is an opportunity to fuse all the add kernels:

[image omitted]

Here the `mul_add` kernel is actually a scalar multiply + add:

[image omitted]

One possible solution would be to improve `simplify_algebra` to add two loops: the first checks for horizontal fusions, and the second rewrites expressions. The scalar multiply may be standalone after this, so `find_unary_shape_transforms` would need to be tweaked to support this as well. And we may need to add an exception to `find_mul_add` to skip the rewrite if the input is scalar and feeds into a GEMM or convolution.