Skip to content

Commit

Permalink
Fix xvsrani/xvsrlni/xvssrarn/xvsrlrn
Browse files Browse the repository at this point in the history
  • Loading branch information
jiegec committed Dec 13, 2023
1 parent 754852d commit 8ad20f3
Show file tree
Hide file tree
Showing 21 changed files with 339 additions and 71 deletions.
130 changes: 99 additions & 31 deletions code/gen_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,41 +259,109 @@
with open(
f"{prefix}s{name}rn_{width}_{double_width_signed}.h", "w"
) as f:
print(f"for (int i = 0;i < {vlen // w};i++) {{", file=f)
print(f"if (i < {vlen // 2 // w}) {{", file=f)
print(f"{shift_sign}{double_w} temp;", file=f)
print(f"if ((b.{double_m}[i] & {double_w-1}) == 0) {{", file=f)
print(
f" temp = ({shift_sign}{double_w})a.{double_m}[i];",
file=f,
)
print(f"}} else {{", file=f)
print(
f" temp = (({shift_sign}{double_w})a.{double_m}[i] >> (b.{double_m}[i] & {double_w-1})) + ((({shift_sign}{double_w})a.{double_m}[i] >> ((b.{double_m}[i] & {double_w-1}) - 1)) & 1);",
file=f,
)
print(f"}}", file=f)
print(
f" dst.{m}[i] = clamp<{shift_sign}{double_w}>(temp, {min}, {max});",
file=f,
)
print(f"}} else {{", file=f)
print(
f" dst.{m}[i] = 0;",
file=f,
)
print(f"}}", file=f)
print(f"}}", file=f)
if prefix == "v":
print(f"for (int i = 0;i < {vlen // w};i++) {{", file=f)
print(f"if (i < {vlen // 2 // w}) {{", file=f)
print(f"{shift_sign}{double_w} temp;", file=f)
print(f"if ((b.{double_m}[i] & {double_w-1}) == 0) {{", file=f)
print(
f" temp = ({shift_sign}{double_w})a.{double_m}[i];",
file=f,
)
print(f"}} else {{", file=f)
print(
f" temp = (({shift_sign}{double_w})a.{double_m}[i] >> (b.{double_m}[i] & {double_w-1})) + ((({shift_sign}{double_w})a.{double_m}[i] >> ((b.{double_m}[i] & {double_w-1}) - 1)) & 1);",
file=f,
)
print(f"}}", file=f)
print(
f" dst.{m}[i] = clamp<{shift_sign}{double_w}>(temp, {min}, {max});",
file=f,
)
print(f"}} else {{", file=f)
print(
f" dst.{m}[i] = 0;",
file=f,
)
print(f"}}", file=f)
print(f"}}", file=f)
else:
print(f"for (int i = 0;i < {vlen // 2 // w};i++) {{", file=f)
print(f"if (i < {vlen // 4 // w}) {{", file=f)
print(f"{shift_sign}{double_w} temp;", file=f)
print(f"if ((b.{double_m}[i] & {double_w-1}) == 0) {{", file=f)
print(
f" temp = ({shift_sign}{double_w})a.{double_m}[i];",
file=f,
)
print(f"}} else {{", file=f)
print(
f" temp = (({shift_sign}{double_w})a.{double_m}[i] >> (b.{double_m}[i] & {double_w-1})) + ((({shift_sign}{double_w})a.{double_m}[i] >> ((b.{double_m}[i] & {double_w-1}) - 1)) & 1);",
file=f,
)
print(f"}}", file=f)
print(
f" dst.{m}[i] = clamp<{shift_sign}{double_w}>(temp, {min}, {max});",
file=f,
)
print(f"}} else {{", file=f)
print(
f" dst.{m}[i] = 0;",
file=f,
)
print(f"}}", file=f)
print(f"}}", file=f)

print(f"for (int i = {vlen // 2 // w};i < {vlen // w};i++) {{", file=f)
print(f"if (i < {3 * vlen // 4 // w}) {{", file=f)
print(f"{shift_sign}{double_w} temp;", file=f)
print(f"if ((b.{double_m}[i - {vlen // 4 // w}] & {double_w-1}) == 0) {{", file=f)
print(
f" temp = ({shift_sign}{double_w})a.{double_m}[i - {vlen // 4 // w}];",
file=f,
)
print(f"}} else {{", file=f)
print(
f" temp = (({shift_sign}{double_w})a.{double_m}[i - {vlen // 4 // w}] >> (b.{double_m}[i - {vlen // 4 // w}] & {double_w-1})) + ((({shift_sign}{double_w})a.{double_m}[i - {vlen // 4 // w}] >> ((b.{double_m}[i - {vlen // 4 // w}] & {double_w-1}) - 1)) & 1);",
file=f,
)
print(f"}}", file=f)
print(
f" dst.{m}[i] = clamp<{shift_sign}{double_w}>(temp, {min}, {max});",
file=f,
)
print(f"}} else {{", file=f)
print(
f" dst.{m}[i] = 0;",
file=f,
)
print(f"}}", file=f)
print(f"}}", file=f)

if sign == "s":
for name, sign in [("srl", "u"), ("sra", "s")]:
with open(f"{prefix}{name}ni_{width}_{double_width}.h", "w") as f:
print(f"for (int i = 0;i < {vlen // w};i++) {{", file=f)
print(
f" dst.{m}[i] = (i < {vlen // 2 // w}) ? ({sign}{w})(({sign}{double_w})b.{double_m}[i] >> imm) : ({sign}{w})(({sign}{double_w})a.{double_m}[i - {vlen // 2 // w}] >> imm);",
file=f,
)
print(f"}}", file=f)
if prefix == "v":
print(f"for (int i = 0;i < {vlen // w};i++) {{", file=f)
print(
f" dst.{m}[i] = (i < {vlen // 2 // w}) ? ({sign}{w})(({sign}{double_w})b.{double_m}[i] >> imm) : ({sign}{w})(({sign}{double_w})a.{double_m}[i - {vlen // 2 // w}] >> imm);",
file=f,
)
print(f"}}", file=f)
else:
print(f"for (int i = 0;i < {vlen // 2 // w};i++) {{", file=f)
print(
f" dst.{m}[i] = (i < {vlen // 4 // w}) ? ({sign}{w})(({sign}{double_w})b.{double_m}[i] >> imm) : ({sign}{w})(({sign}{double_w})a.{double_m}[i - {vlen // 4 // w}] >> imm);",
file=f,
)
print(f"}}", file=f)

print(f"for (int i = {vlen // 2 // w};i < {vlen // w};i++) {{", file=f)
print(
f" dst.{m}[i] = (i < {3 * vlen // 4 // w}) ? ({sign}{w})(({sign}{double_w})b.{double_m}[i - {vlen // 4 // w}] >> imm) : ({sign}{w})(({sign}{double_w})a.{double_m}[i - {vlen // 2 // w}] >> imm);",
file=f,
)
print(f"}}", file=f)
with open(f"{prefix}{name}rni_{width}_{double_width}.h", "w") as f:
if prefix == "v":
print(f"for (int i = 0;i < {vlen // w};i++) {{", file=f)
Expand Down
8 changes: 6 additions & 2 deletions code/xvsrani_b_h.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 32; i++) {
for (int i = 0; i < 16; i++) {
dst.byte[i] =
(i < 16) ? (s8)((s16)b.half[i] >> imm) : (s8)((s16)a.half[i - 16] >> imm);
(i < 8) ? (s8)((s16)b.half[i] >> imm) : (s8)((s16)a.half[i - 8] >> imm);
}
for (int i = 16; i < 32; i++) {
dst.byte[i] = (i < 24) ? (s8)((s16)b.half[i - 8] >> imm)
: (s8)((s16)a.half[i - 16] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrani_d_q.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 4; i++) {
dst.dword[i] = (i < 2) ? (s64)((s128)b.qword[i] >> imm)
for (int i = 0; i < 2; i++) {
dst.dword[i] = (i < 1) ? (s64)((s128)b.qword[i] >> imm)
: (s64)((s128)a.qword[i - 1] >> imm);
}
for (int i = 2; i < 4; i++) {
dst.dword[i] = (i < 3) ? (s64)((s128)b.qword[i - 1] >> imm)
: (s64)((s128)a.qword[i - 2] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrani_h_w.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 16; i++) {
for (int i = 0; i < 8; i++) {
dst.half[i] =
(i < 8) ? (s16)((s32)b.word[i] >> imm) : (s16)((s32)a.word[i - 8] >> imm);
(i < 4) ? (s16)((s32)b.word[i] >> imm) : (s16)((s32)a.word[i - 4] >> imm);
}
for (int i = 8; i < 16; i++) {
dst.half[i] = (i < 12) ? (s16)((s32)b.word[i - 4] >> imm)
: (s16)((s32)a.word[i - 8] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrani_w_d.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 8; i++) {
dst.word[i] = (i < 4) ? (s32)((s64)b.dword[i] >> imm)
for (int i = 0; i < 4; i++) {
dst.word[i] = (i < 2) ? (s32)((s64)b.dword[i] >> imm)
: (s32)((s64)a.dword[i - 2] >> imm);
}
for (int i = 4; i < 8; i++) {
dst.word[i] = (i < 6) ? (s32)((s64)b.dword[i - 2] >> imm)
: (s32)((s64)a.dword[i - 4] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrlni_b_h.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 32; i++) {
for (int i = 0; i < 16; i++) {
dst.byte[i] =
(i < 16) ? (u8)((u16)b.half[i] >> imm) : (u8)((u16)a.half[i - 16] >> imm);
(i < 8) ? (u8)((u16)b.half[i] >> imm) : (u8)((u16)a.half[i - 8] >> imm);
}
for (int i = 16; i < 32; i++) {
dst.byte[i] = (i < 24) ? (u8)((u16)b.half[i - 8] >> imm)
: (u8)((u16)a.half[i - 16] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrlni_d_q.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 4; i++) {
dst.dword[i] = (i < 2) ? (u64)((u128)b.qword[i] >> imm)
for (int i = 0; i < 2; i++) {
dst.dword[i] = (i < 1) ? (u64)((u128)b.qword[i] >> imm)
: (u64)((u128)a.qword[i - 1] >> imm);
}
for (int i = 2; i < 4; i++) {
dst.dword[i] = (i < 3) ? (u64)((u128)b.qword[i - 1] >> imm)
: (u64)((u128)a.qword[i - 2] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrlni_h_w.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 16; i++) {
for (int i = 0; i < 8; i++) {
dst.half[i] =
(i < 8) ? (u16)((u32)b.word[i] >> imm) : (u16)((u32)a.word[i - 8] >> imm);
(i < 4) ? (u16)((u32)b.word[i] >> imm) : (u16)((u32)a.word[i - 4] >> imm);
}
for (int i = 8; i < 16; i++) {
dst.half[i] = (i < 12) ? (u16)((u32)b.word[i - 4] >> imm)
: (u16)((u32)a.word[i - 8] >> imm);
}
8 changes: 6 additions & 2 deletions code/xvsrlni_w_d.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
for (int i = 0; i < 8; i++) {
dst.word[i] = (i < 4) ? (u32)((u64)b.dword[i] >> imm)
for (int i = 0; i < 4; i++) {
dst.word[i] = (i < 2) ? (u32)((u64)b.dword[i] >> imm)
: (u32)((u64)a.dword[i - 2] >> imm);
}
for (int i = 4; i < 8; i++) {
dst.word[i] = (i < 6) ? (u32)((u64)b.dword[i - 2] >> imm)
: (u32)((u64)a.dword[i - 4] >> imm);
}
18 changes: 16 additions & 2 deletions code/xvssrarn_b_h.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 32; i++) {
if (i < 16) {
for (int i = 0; i < 16; i++) {
if (i < 8) {
s16 temp;
if ((b.half[i] & 15) == 0) {
temp = (s16)a.half[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 32; i++) {
dst.byte[i] = 0;
}
}
for (int i = 16; i < 32; i++) {
if (i < 24) {
s16 temp;
if ((b.half[i - 8] & 15) == 0) {
temp = (s16)a.half[i - 8];
} else {
temp = ((s16)a.half[i - 8] >> (b.half[i - 8] & 15)) +
(((s16)a.half[i - 8] >> ((b.half[i - 8] & 15) - 1)) & 1);
}
dst.byte[i] = clamp<s16>(temp, -128, 127);
} else {
dst.byte[i] = 0;
}
}
18 changes: 16 additions & 2 deletions code/xvssrarn_bu_h.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 32; i++) {
if (i < 16) {
for (int i = 0; i < 16; i++) {
if (i < 8) {
s16 temp;
if ((b.half[i] & 15) == 0) {
temp = (s16)a.half[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 32; i++) {
dst.byte[i] = 0;
}
}
for (int i = 16; i < 32; i++) {
if (i < 24) {
s16 temp;
if ((b.half[i - 8] & 15) == 0) {
temp = (s16)a.half[i - 8];
} else {
temp = ((s16)a.half[i - 8] >> (b.half[i - 8] & 15)) +
(((s16)a.half[i - 8] >> ((b.half[i - 8] & 15) - 1)) & 1);
}
dst.byte[i] = clamp<s16>(temp, 0, 255);
} else {
dst.byte[i] = 0;
}
}
18 changes: 16 additions & 2 deletions code/xvssrarn_h_w.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 16; i++) {
if (i < 8) {
for (int i = 0; i < 8; i++) {
if (i < 4) {
s32 temp;
if ((b.word[i] & 31) == 0) {
temp = (s32)a.word[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 16; i++) {
dst.half[i] = 0;
}
}
for (int i = 8; i < 16; i++) {
if (i < 12) {
s32 temp;
if ((b.word[i - 4] & 31) == 0) {
temp = (s32)a.word[i - 4];
} else {
temp = ((s32)a.word[i - 4] >> (b.word[i - 4] & 31)) +
(((s32)a.word[i - 4] >> ((b.word[i - 4] & 31) - 1)) & 1);
}
dst.half[i] = clamp<s32>(temp, -32768, 32767);
} else {
dst.half[i] = 0;
}
}
18 changes: 16 additions & 2 deletions code/xvssrarn_hu_w.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 16; i++) {
if (i < 8) {
for (int i = 0; i < 8; i++) {
if (i < 4) {
s32 temp;
if ((b.word[i] & 31) == 0) {
temp = (s32)a.word[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 16; i++) {
dst.half[i] = 0;
}
}
for (int i = 8; i < 16; i++) {
if (i < 12) {
s32 temp;
if ((b.word[i - 4] & 31) == 0) {
temp = (s32)a.word[i - 4];
} else {
temp = ((s32)a.word[i - 4] >> (b.word[i - 4] & 31)) +
(((s32)a.word[i - 4] >> ((b.word[i - 4] & 31) - 1)) & 1);
}
dst.half[i] = clamp<s32>(temp, 0, 65535);
} else {
dst.half[i] = 0;
}
}
18 changes: 16 additions & 2 deletions code/xvssrarn_w_d.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 8; i++) {
if (i < 4) {
for (int i = 0; i < 4; i++) {
if (i < 2) {
s64 temp;
if ((b.dword[i] & 63) == 0) {
temp = (s64)a.dword[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 8; i++) {
dst.word[i] = 0;
}
}
for (int i = 4; i < 8; i++) {
if (i < 6) {
s64 temp;
if ((b.dword[i - 2] & 63) == 0) {
temp = (s64)a.dword[i - 2];
} else {
temp = ((s64)a.dword[i - 2] >> (b.dword[i - 2] & 63)) +
(((s64)a.dword[i - 2] >> ((b.dword[i - 2] & 63) - 1)) & 1);
}
dst.word[i] = clamp<s64>(temp, -2147483648, 2147483647);
} else {
dst.word[i] = 0;
}
}
18 changes: 16 additions & 2 deletions code/xvssrarn_wu_d.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for (int i = 0; i < 8; i++) {
if (i < 4) {
for (int i = 0; i < 4; i++) {
if (i < 2) {
s64 temp;
if ((b.dword[i] & 63) == 0) {
temp = (s64)a.dword[i];
Expand All @@ -12,3 +12,17 @@ for (int i = 0; i < 8; i++) {
dst.word[i] = 0;
}
}
for (int i = 4; i < 8; i++) {
if (i < 6) {
s64 temp;
if ((b.dword[i - 2] & 63) == 0) {
temp = (s64)a.dword[i - 2];
} else {
temp = ((s64)a.dword[i - 2] >> (b.dword[i - 2] & 63)) +
(((s64)a.dword[i - 2] >> ((b.dword[i - 2] & 63) - 1)) & 1);
}
dst.word[i] = clamp<s64>(temp, 0, 4294967295);
} else {
dst.word[i] = 0;
}
}
Loading

0 comments on commit 8ad20f3

Please sign in to comment.