Skip to content

Commit

Permalink
Reduced int precision for texture coordinates in q_linear op, to reduce shader register pressure.
Browse files Browse the repository at this point in the history

Differential Revision: D64191093

Pull Request resolved: pytorch#6354
  • Loading branch information
trivedivivek authored Oct 22, 2024
1 parent 5a6a571 commit 5d7bd71
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ void main() {

#else // USING_TEXTURE

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
ivec3 mat1_pos = ivec3(0, out_pos.yz);
ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0);
u16vec3 mat1_pos = u16vec3(0, out_pos.yz);
u16vec3 qmat2_pos = u16vec3(0, out_pos.x * 4, 0);

VEC4_T outtex = VEC4_T(0);

const ivec3 scales_pos = ivec3(out_pos.x, 0, 0);
const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
const VEC4_T scales = load_texel(t_scales, scales_pos);

for (int i = 0; i < K; i += 4) {
Expand All @@ -104,11 +106,11 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
const VEC4_T sums = VEC4_T(
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y),
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0)) * scales.y),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z),
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0)) * scales.z),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w));
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0)) * scales.w));

outtex += sums;

Expand Down

0 comments on commit 5d7bd71

Please sign in to comment.