Commit c73a2c0

fix: ggml-hexagon: matmul fp16xfp32 support non-contiguous src0
1 parent 2c3f20d commit c73a2c0
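
The f16 x f32 matmul kernel previously addressed rows with packed row sizes computed from element counts, which silently assumed contiguous src0 and src1. This commit switches row addressing in matmul_f16_f32 to the tensors' own byte strides (nb01 for src0 rows, nb11/nb12/nb13 for src1) and replaces the HVX tail path in vec_dot_f16_f32 with a scalar remainder loop accumulated into r_sum_scalar, which is correct for any tail length, so non-contiguous src0 views produce correct results.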

1 file changed: +28 / -20 lines

ggml/src/ggml-hexagon/htp/matmul-ops.c

@@ -917,7 +917,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
 
     // for some reason we need volatile here so that the compiler doesn't try anything funky
     volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
-
+    float r_sum_scalar = 0.0f;
     uint32_t i = 0;
 
     for (i = 0; i < nv0; i++) {
@@ -936,23 +936,32 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
     }
 
     if (nv1) {
-        HVX_VectorPair yp = vy[i];
+        // HVX_VectorPair yp = vy[i];
 
-        HVX_Vector x = vx[i];
-        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+        // HVX_Vector x = vx[i];
+        // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
 
-        if (nv1 >= 32) {
-            HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
-            rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
-            nv1 -= 32;
-        }
+        // if (nv1 >= 32) {
+        //     volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+        //     rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
+        //     nv1 -= 32;
+        // }
+
+        // rsum = hvx_vec_qf32_reduce_sum(rsum);
 
+        // if (nv1) {
+        //     volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+        //     HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
+        //     rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        // }
+
+        // process the remainder using scalar loop
         rsum = hvx_vec_qf32_reduce_sum(rsum);
+        const __fp16 * restrict sx = (const __fp16 * restrict) x;
+        const float * restrict sy = (const float * restrict) y;
 
-        if (nv1) {
-            HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
-            HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
-            rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        for (uint32_t i = nv0 * 64; i < n; i++) {
+            r_sum_scalar += (float) sx[i] * sy[i];
         }
 
         // hvx_vec_dump_fp16("X", x);
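
The hunk above replaces the masked HVX tail handling with a scalar remainder loop. Reduced to plain C, the control flow it implements looks like the sketch below: a vector body over full chunks, a scalar tail, and the two partial sums combined at the end. The 64-element chunk width is inferred from the nv0 * 64 loop bound; the type alias and helper name are illustrative, not the actual HTP sources.

    /*
     * Plain-C stand-in for the fixed vec_dot_f16_f32 control flow.
     * Assumptions (not from the actual HTP sources): each HVX iteration
     * consumes 64 fp16 elements, as the nv0 * 64 loop bound in the diff
     * suggests; fp16_t stands in for __fp16 on the Hexagon toolchain.
     */
    #include <stddef.h>

    typedef _Float16 fp16_t; /* needs a compiler with _Float16 support */

    static float dot_f16_f32(size_t n, const fp16_t *restrict x, const float *restrict y) {
        const size_t chunk = 64;        /* elements per full vector iteration */
        const size_t nv0   = n / chunk; /* number of full chunks */

        /* vector body: stands in for the Q6_Wqf32_vmpy/vadd loop that
         * accumulates into the HVX register rsum */
        float vec_sum = 0.0f;
        for (size_t i = 0; i < nv0 * chunk; i++) {
            vec_sum += (float) x[i] * y[i];
        }

        /* scalar remainder: mirrors the new r_sum_scalar loop, correct
         * for any tail length without masked vector loads */
        float r_sum_scalar = 0.0f;
        for (size_t i = nv0 * chunk; i < n; i++) {
            r_sum_scalar += (float) x[i] * y[i];
        }

        /* matches *s = hvx_vec_get_fp32(...) + r_sum_scalar */
        return vec_sum + r_sum_scalar;
    }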
@@ -963,7 +972,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
         rsum = hvx_vec_qf32_reduce_sum(rsum);
     }
 
-    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
+    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
 
 #ifdef HTP_DEBUG
     {
@@ -1500,9 +1509,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = sizeof(__fp16) * ne00;
-    const size_t src1_row_size = sizeof(float) * ne10;
-
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
 
@@ -1561,14 +1567,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
             const uint32_t i2 = i12;
             const uint32_t i3 = i13;
 
-            const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+            const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
             const uint8_t * restrict src1_col =
-                (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
+                (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
             float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
 
             const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
             for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
-                vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
+                // Use nb01 stride for non-contiguous src0 support
+                const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+                vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
             }
 
             hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
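
What makes the new addressing work is ggml's stride convention: every tensor carries byte strides (nb01 for rows, nb02/nb03 for planes and batches), so a row address is formed by multiplying indices by strides instead of by a packed row size. A minimal sketch of that convention follows; the struct and helper names are illustrative, not the real struct htp_tensor.

    /*
     * Sketch of ggml-style strided row addressing. Field names follow ggml
     * conventions (ne = element counts, nb = byte strides); the struct and
     * helper here are illustrative.
     */
    #include <stddef.h>
    #include <stdint.h>

    struct tensor_view {
        const uint8_t *data;
        size_t ne00;             /* elements per row                */
        size_t nb01, nb02, nb03; /* byte strides: row, plane, batch */
    };

    /* Address of row ir0 in plane (i02, i03). Works for contiguous and
     * non-contiguous layouts alike, because nb01 carries the true row pitch. */
    static const uint8_t *row_ptr(const struct tensor_view *t,
                                  size_t ir0, size_t i02, size_t i03) {
        return t->data + ir0 * t->nb01 + i02 * t->nb02 + i03 * t->nb03;
    }

    /* The removed code effectively hard-coded
     *     nb01 == sizeof(__fp16) * ne00
     * i.e. a packed fp16 row. For a non-contiguous view nb01 differs from
     * the packed size, so src0_row + ir0 * src0_row_size walked off the
     * real rows; src0_base + ir0 * nb01 does not.
     */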
