Improve simplify_algebra to find more horizontal fusion opportunities #3432

Open

kahmed10 opened this issue Sep 10, 2024 · 5 comments · May be fixed by #3478

@kahmed10 (Collaborator)

In the SD CLIP model, there is an opportunity to fuse all of the add kernels:

@15 = gpu::code_object[code_object=7632,symbol_name=mlir_dot_add,global=133632,local=256,](@13,@12,@5,@14) -> half_type, {24, 77, 2304}, {177408, 2304, 1}: 0.0934304ms, 2%
@16 = hip::hip_copy_literal[id=main:@literal:78] -> half_type, {768}, {1}: 0.00109522ms, 1%
@17 = hip::hip_copy_literal[id=main:@literal:59] -> half_type, {768}, {1}: 0.00108192ms, 1%
@18 = slice[axes={2},starts={768},ends={1536}](@15) -> half_type, {24, 77, 768}, {177408, 2304, 1}: 0.00165542ms, 1%
@19 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@17) -> half_type, {24, 77, 768}, {0, 0, 1}: 0.00094074ms, 1%
@20 = load[offset=18184320,end=21022848](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.00076536ms, 1%
**@21 = gpu::code_object[code_object=5128,symbol_name=add_kernel,global=354816,local=1024,](@19,@18,@20) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.0211362ms, 1%**
@22 = load[offset=11354112,end=14192640](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.00099472ms, 1%
@23 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@16) -> half_type, {24, 77, 768}, {0, 0, 1}: 0.00182424ms, 1%
@24 = slice[axes={2},starts={0},ends={768}](@15) -> half_type, {24, 77, 768}, {177408, 2304, 1}: 0.00103286ms, 1%
**@25 = gpu::code_object[code_object=5136,symbol_name=mul_add_kernel,global=354816,local=1024,](@24,@23,@22) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.0413997ms, 1%**
@26 = load[offset=14769216,end=18184320](@1) -> half_type, {24, 12, 77, 77}, {71148, 5929, 77, 1}: 0.00105ms, 1%
@27 = gpu::code_object[code_object=6736,symbol_name=mlir_reshape_transpose_reshape_transpose_dot,global=73728,local=256,](@25,@21,@26) -> half_type, {24, 12, 77, 77}, {71148, 5929, 77, 1}: 0.0248955ms, 1%
...
@32 = load[offset=14769216,end=17607744](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}
@33 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@31) -> half_type, {24, 77, 768}, {0, 0, 1}
@34 = slice[axes={2},starts={1536},ends={2304}](@14) -> half_type, {24, 77, 768}, {177408, 2304, 1}
**@35 = gpu::code_object[code_object=5128,symbol_name=add_kernel,global=354816,local=1024,](@33,@34,@32) -> half_type, {24, 77, 768}, {59136, 768, 1}**

Here the mul_add kernel is actually a scalar multiply + add:

module: "main:pointwise10"
main:pointwise10:x1 = @param:x1 -> half_type, {1}, {0}
main:pointwise10:x0 = @param:x0 -> half_type, {1}, {0}
main:pointwise10:@2 = @literal{0.125} -> half_type, {1}, {0}
main:pointwise10:@3 = mul(main:pointwise10:@2,main:pointwise10:x0) -> half_type, {1}, {0}
main:pointwise10:@4 = add(main:pointwise10:@3,main:pointwise10:x1) -> half_type, {1}, {0}
main:pointwise10:@5 = @return(main:pointwise10:@4)

One possible solution would be to improve simplify_algebra to use two loops: the first checks for horizontal fusion opportunities, and the second rewrites expressions.

The scalar multiply may be left standalone after this, so find_unary_shape_transforms would also need to be tweaked to support that case.

We may also need to add an exception to find_mul_add to skip the rewrite when the input is a scalar and feeds into a gemm or convolution.
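
As a sanity check on that last point, here is a minimal NumPy sketch (illustrative only, not MIGraphX code) of the identity that makes the exception worthwhile: a scalar multiply commutes with a gemm, so a scalar that feeds a dot can be folded into the dot's constant operand instead of being distributed by find_mul_add:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((77, 768)).astype(np.float32)
w = rng.standard_normal((768, 768)).astype(np.float32)
b = rng.standard_normal(768).astype(np.float32)
s = np.float32(0.125)  # the scalar literal from main:pointwise10

# Scalar multiply applied before the gemm ...
lhs = (s * (x + b)) @ w
# ... equals the gemm with the scalar folded into its constant weight operand.
rhs = (x + b) @ (s * w)
assert np.allclose(lhs, rhs, atol=1e-3)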

@aarushjain29 aarushjain29 linked a pull request Sep 18, 2024 that will close this issue
@aarushjain29 aarushjain29 linked a pull request Sep 25, 2024 that will close this issue
@CharlieL7 (Collaborator)

Note for clarity: the fusion opportunity exists because the add and mul_add kernels slice contiguous regions from the same input, @15, and their other inputs are literals. Both results are then used by the gemm instruction @27.

@pfultz2 (Collaborator) commented Nov 6, 2024

Here is a standalone case that demonstrates the issue:

import migraphx

p = migraphx.program()
m = p.get_main_module()
x_0 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 0))
x_1 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 1))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 2))
x_3 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 3))
x_4 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 4))
x_5 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 5))
x_6 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 6))
x_7 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 7))
x_8 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 8))
x_9 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 9))
x_10 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 10))
x_11 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[64, 512]), 11))
x_12 = m.add_literal(migraphx.create_argument(migraphx.shape(type="float_type", lens=[1]), [0.125]))
p_x = m.add_parameter("x", migraphx.shape(type="float_type", lens=[64, 64]))
x_14 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[64,512]), [x_12]) # migraphx.shape(type="float_type", lens=[64, 512], strides=[0, 0])
x_15 = m.add_instruction(migraphx.op("dot"), [p_x, x_11]) # migraphx.shape(type="float_type", lens=[64, 512])
x_16 = m.add_instruction(migraphx.op("add"), [x_15, x_10]) # migraphx.shape(type="float_type", lens=[64, 512])
x_17 = m.add_instruction(migraphx.op("add"), [x_16, x_9]) # migraphx.shape(type="float_type", lens=[64, 512])
x_18 = m.add_instruction(migraphx.op("add"), [x_17, x_8]) # migraphx.shape(type="float_type", lens=[64, 512])
x_19 = m.add_instruction(migraphx.op("dot"), [p_x, x_7]) # migraphx.shape(type="float_type", lens=[64, 512])
x_20 = m.add_instruction(migraphx.op("add"), [x_19, x_6]) # migraphx.shape(type="float_type", lens=[64, 512])
x_21 = m.add_instruction(migraphx.op("add"), [x_20, x_5]) # migraphx.shape(type="float_type", lens=[64, 512])
x_22 = m.add_instruction(migraphx.op("add"), [x_21, x_4]) # migraphx.shape(type="float_type", lens=[64, 512])
x_23 = m.add_instruction(migraphx.op("dot"), [p_x, x_3]) # migraphx.shape(type="float_type", lens=[64, 512])
x_24 = m.add_instruction(migraphx.op("add"), [x_23, x_2]) # migraphx.shape(type="float_type", lens=[64, 512])
x_25 = m.add_instruction(migraphx.op("add"), [x_24, x_1]) # migraphx.shape(type="float_type", lens=[64, 512])
x_26 = m.add_instruction(migraphx.op("add"), [x_25, x_0]) # migraphx.shape(type="float_type", lens=[64, 512])
x_27 = m.add_instruction(migraphx.op("mul"), [x_26, x_14]) # migraphx.shape(type="float_type", lens=[64, 512])
m.add_return([x_18, x_22, x_27])
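
To get the compiled IR shown below, the program can be compiled for the GPU target and printed. The exact invocation is my assumption (offload_copy=False keeps the outputs as parameters, which matches the listing):

target = migraphx.get_target("gpu")
p.compile(target, offload_copy=False)
print(p)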

Which produces this:

@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {393216}, {1},id=main:scratch] -> int8_type, {393216}, {1}
@2 = hip::hip_copy_literal[id=main:@literal:2] -> float_type, {64, 512}, {512, 1}
@3 = hip::hip_copy_literal[id=main:@literal:4] -> float_type, {64, 512}, {512, 1}
@4 = hip::hip_copy_literal[id=main:@literal:3] -> float_type, {64, 512}, {512, 1}
@5 = hip::hip_copy_literal[id=main:@literal:1] -> float_type, {64, 1536}, {1536, 1}
@6 = hip::hip_copy_literal[id=main:@literal:0] -> float_type, {64, 1536}, {1536, 1}
@7 = load[offset=0,end=393216](@1) -> float_type, {64, 1536}, {1536, 1}
x = @param:x -> float_type, {64, 64}, {64, 1}
@9 = gpu::code_object[code_object=6400,symbol_name=mlir_dot_add,global=6144,local=256,](@5,x,@6,@7) -> float_type, {64, 1536}, {1536, 1}
main:#output_2 = @param:main:#output_2 -> float_type, {64, 512}, {512, 1}
@11 = slice[axes={1},starts={1024},ends={1536}](@9) -> float_type, {64, 512}, {1536, 1}
@12 = gpu::code_object[code_object=9168,symbol_name=mul_add_kernel,global=16384,local=1024,](@11,@2,main:#output_2) -> float_type, {64, 512}, {512, 1}
main:#output_1 = @param:main:#output_1 -> float_type, {64, 512}, {512, 1}
@14 = slice[axes={1},starts={512},ends={1024}](@9) -> float_type, {64, 512}, {1536, 1}
@15 = gpu::code_object[code_object=9160,symbol_name=add_kernel,global=16384,local=1024,](@14,@4,main:#output_1) -> float_type, {64, 512}, {512, 1}
main:#output_0 = @param:main:#output_0 -> float_type, {64, 512}, {512, 1}
@17 = slice[axes={1},starts={0},ends={512}](@9) -> float_type, {64, 512}, {1536, 1}
@18 = gpu::code_object[code_object=9160,symbol_name=add_kernel,global=16384,local=1024,](@17,@3,main:#output_0) -> float_type, {64, 512}, {512, 1}
@19 = @return(@18,@15,@12)

@aarushjain29 (Contributor)

Before the horizontal fusion

@298 = multibroadcast[out_lens={1, 768, 768},out_dyn_dims={}](@151) -> float_type, {1, 768, 768}, {0, 768, 1}
@299 = dot(@297,@298) -> float_type, {1, 77, 768}, {59136, 768, 1}
@300 = multibroadcast[out_lens={1, 77, 768},out_dyn_dims={}](@271) -> float_type, {1, 77, 768}, {0, 0, 1}
@301 = add(@300,@299) -> float_type, {1, 77, 768}, {59136, 768, 1}
@302 = multibroadcast[out_lens={1, 77, 768},out_dyn_dims={}](@74) -> float_type, {1, 77, 768}, {0, 0, 0}
@303 = mul(@301,@302) -> float_type, {1, 77, 768}, {59136, 768, 1}
@304 = multibroadcast[out_lens={1, 768, 768},out_dyn_dims={}](@150) -> float_type, {1, 768, 768}, {0, 768, 1}
@305 = dot(@297,@304) -> float_type, {1, 77, 768}, {59136, 768, 1}
@306 = multibroadcast[out_lens={1, 77, 768},out_dyn_dims={}](@273) -> float_type, {1, 77, 768}, {0, 0, 1}
@307 = add(@306,@305) -> float_type, {1, 77, 768}, {59136, 768, 1}
@308 = multibroadcast[out_lens={1, 768, 768},out_dyn_dims={}](@149) -> float_type, {1, 768, 768}, {0, 768, 1}
@309 = dot(@297,@308) -> float_type, {1, 77, 768}, {59136, 768, 1}
@310 = multibroadcast[out_lens={1, 77, 768},out_dyn_dims={}](@272) -> float_type, {1, 77, 768}, {0, 0, 1}
@311 = add(@310,@309) -> float_type, {1, 77, 768}, {59136, 768, 1}
@312 = reshape[dims={1, 77, 12, 64}](@303) -> float_type, {1, 77, 12, 64}, {59136, 768, 64, 1}

There is an add whose output is used by three dot operations. The first dot is followed by an add and then a mul; the second and third dots are each followed by an add.
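
For illustration, here is a NumPy sketch (not MIGraphX code) of the horizontal fusion this pattern permits: the three dots that share the same input can be computed as a single dot against concatenated weights, with each original result recovered as a contiguous slice:

import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((77, 768)).astype(np.float32)  # shared input (@297)
ws = [rng.standard_normal((768, 768)).astype(np.float32) for _ in range(3)]
bs = [rng.standard_normal(768).astype(np.float32) for _ in range(3)]

# Separate dot + add per branch, as in the IR above.
separate = [x @ w + b for w, b in zip(ws, bs)]

# Horizontally fused: one dot against the concatenated weights and biases,
# with contiguous slices recovering each branch.
fused = x @ np.concatenate(ws, axis=1) + np.concatenate(bs)
for i, ref in enumerate(separate):
    assert np.allclose(fused[:, i * 768:(i + 1) * 768], ref, atol=1e-3)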

@lakhinderwalia (Contributor)

(Quoting @pfultz2's standalone case and its compiled IR from the comment above.)

@pfultz2, I'm trying to understand the proposed optimization. The above IR is presented in a rough picture below:

  1. Should all three bottom add operations now be fused into the add in mlir_dot_add, or should they just be combined into a new add_kernel placed just below mlir_dot_add?
  2. Will the mul in mul_add_kernel become a standalone mul_kernel as a result of fusing the add?
  3. The now-separated mul operation at the bottom left will need a slice of the output of the newly fused add operation; I assume a slice will be required for all three final outputs (see the sketch after the picture below).

Thanks.

[Image: pfultz_slice_example_ir — rough diagram of the IR above]
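
To make question 3 concrete, here is a NumPy sketch of one possible outcome (purely an assumption on my part, not a confirmed design): the three bottom adds are combined into a single add_kernel below mlir_dot_add, and the leftover scalar mul then reads a slice of that fused result:

import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal((64, 64)).astype(np.float32)
w_cat = rng.standard_normal((64, 1536)).astype(np.float32)   # concatenated dot weights (as in @5)
b0_cat = rng.standard_normal((64, 1536)).astype(np.float32)  # literal already fused into mlir_dot_add (as in @6)
b1_cat = rng.standard_normal((64, 1536)).astype(np.float32)  # hypothetical concatenation of the three bottom-add literals
s = np.float32(0.125)

dot_add = x @ w_cat + b0_cat         # mlir_dot_add (@9)
fused_add = dot_add + b1_cat         # one new add_kernel just below mlir_dot_add (question 1, second option)
out_0 = fused_add[:, 0:512]          # slice for what was add_kernel @18
out_1 = fused_add[:, 512:1024]       # slice for what was add_kernel @15
out_2 = s * fused_add[:, 1024:1536]  # standalone mul on a slice (was mul_add_kernel @12; questions 2 and 3)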

@aarushjain29 (Contributor)

Before - (the same IR listing as in the issue description above, @16 through @35)

Output -

@135 = gpu::code_object[code_object=6016,symbol_name=mlir_dot_add,global=23040,local=64,](@64,@133,@56,@134) -> half_type, {1, 77, 2304}, {177408, 2304, 1}
@136 = gpu::precompile_op[op=gpu::mlir_op[op=dot],additional_args=1,ignore_modules=0,output_shape=nullopt](@64,@133,@56,@134), [mlir_main:pointwise12] -> half_type, {1, 77, 2304}, {177408, 2304, 1}
@137 = hip::allocate[shape=half_type, {12, 77, 77}, {5929, 77, 1}] -> half_type, {12, 77, 77}, {5929, 77, 1}
@138 = gpu::code_object[code_object=5248,symbol_name=mlir_slice_mul_reshape_transpose_squeeze_slice_reshape_transpose_squeeze_dot,global=9600,local=32,](@135,@137) -> half_type, {12, 77, 77}, {5929, 77, 1}
